Causal-JEPA

Learning World Models through Object-Level Latent Interventions

¹Brown University, ²New York University, ³Mila, ⁴Université de Montréal
* Equal contribution

Summary

World models require robust relational understanding to support prediction, reasoning, and control. While object-centric representations provide a useful abstraction, they are not sufficient to capture interaction-dependent dynamics. We therefore propose C-JEPA, a simple and flexible object-centric world model that extends masked joint embedding prediction from image patches to object-centric representations. By applying object-level masking that requires an object's state to be inferred from other objects, C-JEPA induces latent interventions with counterfactual-like effects and prevents shortcut solutions, making interaction reasoning essential. Empirically, C-JEPA leads to consistent gains in visual question answering, with an absolute improvement of about 20% in counterfactual reasoning compared to the same architecture without object-level masking. On agent control tasks, C-JEPA enables substantially more efficient planning by using only 1% of the total latent input features required by patch-based world models, while achieving comparable performance. Finally, we provide a formal analysis demonstrating that object-level masking induces a causal inductive bias via latent interventions.

Methodology

C-JEPA Training Architecture

Figure 1. C-JEPA Training Pipeline. Selective object-level masking induces interaction reasoning by requiring masked history slots to be inferred from other objects.

C-JEPA treats object masking as a structured latent intervention. By masking an object's trajectory throughout the history window (except for a minimal identity anchor), the model is forced to rely on interactions with other entities to minimize prediction error.

Masked tokens are constructed using a linear projection \(\Phi\), an identity anchor \(z_{t_0}^i\), and learnable temporal embeddings \(e_{\tau}\): $$\overline{z}_{\tau}^{i} = \Phi(z_{t_{0}}^{i}) + e_{\tau}$$ The training objective minimizes the \(\ell_2\) distance between predicted and target masked latents across both history recovery and future prediction: $$\mathcal{L}_{\text{mask}} = \underbrace{\mathbb{E}\big[\|\hat{Z}_{\tau} - Z_{\tau}\|_{2}^{2} \,\big|\, \tau \le t\big]}_{\mathcal{L}_{\text{history}}} + \underbrace{\mathbb{E}\big[\|\hat{Z}_{\tau} - Z_{\tau}\|_{2}^{2} \,\big|\, \tau > t\big]}_{\mathcal{L}_{\text{future}}}$$
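To make this concrete, below is a minimal PyTorch sketch of the masked-token construction and loss. The module and variable names (`MaskedTokenBuilder`, `phi`, `temporal_emb`, `masked_loss`), the slot tensor layout, and the single-object masking scheme are illustrative assumptions, not the released implementation.

```python
import torch
import torch.nn as nn


class MaskedTokenBuilder(nn.Module):
    """Builds masked history tokens: z_bar[tau, i] = Phi(z[t0, i]) + e[tau].

    Assumed layout (illustrative): slots z have shape (T, N, D) for T
    timesteps, N object slots, and slot dimension D.
    """

    def __init__(self, slot_dim: int, num_steps: int):
        super().__init__()
        self.phi = nn.Linear(slot_dim, slot_dim)                # linear projection Phi
        self.temporal_emb = nn.Embedding(num_steps, slot_dim)   # learnable e_tau

    def forward(self, z: torch.Tensor, masked_obj: int) -> torch.Tensor:
        T, N, D = z.shape
        z_masked = z.clone()
        anchor = z[0, masked_obj]                 # identity anchor z_{t0}^i
        taus = torch.arange(T, device=z.device)
        # Replace the masked object's entire history with anchor + temporal code,
        # so its dynamics must be inferred from the other slots.
        z_masked[:, masked_obj] = self.phi(anchor) + self.temporal_emb(taus)
        return z_masked


def masked_loss(z_hat: torch.Tensor, z_tgt: torch.Tensor, t: int) -> torch.Tensor:
    """L_mask = history term (tau <= t) + future term (tau > t), both squared l2.

    Assumes the trajectory contains at least one step after index t.
    """
    l_history = (z_hat[: t + 1] - z_tgt[: t + 1]).pow(2).mean()
    l_future = (z_hat[t + 1 :] - z_tgt[t + 1 :]).pow(2).mean()
    return l_history + l_future
```

In training one would sample which object to mask per trajectory, so the predictor cannot memorize a fixed slot; the anchor keeps object identity recoverable while the object's dynamics must be reconstructed from its interactions.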

Experimental Results

1. Efficient Control (Push-T)

C-JEPA reconciles the efficiency of object-centric models with the performance of high-dimensional patch-based approaches. It matches DINO-WM while using only 1.02% of the total latent features, enabling 8× faster planning.

| Model | Input Features (Tokens) | Success Rate (%) |
|---|---|---|
| DINO-WM (Patch-based) | 196 × 384 | 91.33 |
| OC-DINO-WM | 6 × 128 | 60.67 |
| OC-JEPA | 6 × 128 | 76.00 |
| C-JEPA (Ours) | 6 × 128 (≈1%) | 88.67 |
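For reference, the ≈1% figure follows directly from the token budgets in the table: $$\frac{6 \times 128}{196 \times 384} = \frac{768}{75{,}264} \approx 1.02\%$$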

2. Visual Reasoning (CLEVRER)

C-JEPA achieves the strongest performance without relying on pixel-level reconstruction. Removing reconstruction leads to severe degradation in baselines like SlotFormer, whereas C-JEPA's masking-based objective remains highly effective.

| Model | Recon. | Avg. per Question (%) | Counterfactual per Question (%) |
|---|---|---|---|
| SlotFormer | Yes | 79.44 | 47.29 |
| SlotFormer (− recon) | No | 44.94 | 11.10 |
| OCVP-Seq | Yes | 83.11 | 56.06 |
| OCVP-Seq (− recon) | No | 80.09 | 43.00 |
| OC-JEPA (Baseline) | No | 77.28 | 41.10 |
| C-JEPA (Ours) | No | 83.88 | 60.19 |

*Results use the SAVi encoder for a fair comparison with baselines.

Theoretical Analysis

We formally demonstrate that C-JEPA's objective induces a causal inductive bias. Masked history prediction forces the predictor to identify an influence neighborhood: the minimal sufficient subset of context variables required to recover an object's state. This process aligns the model's attention with interaction-stable relational structures, making interaction reasoning functionally necessary for minimizing the training loss. More details are provided in the paper!
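As a rough sketch of this idea (the notation below is ours, not the paper's formal statement), the influence neighborhood of object \(i\) can be written as the smallest subset of context slots whose latents suffice to recover the masked state: $$\mathcal{N}(i) = \underset{S \subseteq \{1,\dots,N\} \setminus \{i\}}{\arg\min} \; |S| \quad \text{s.t.} \quad \mathbb{E}\big[\|\hat{z}_{\tau}^{i}(Z^{S}) - z_{\tau}^{i}\|_{2}^{2}\big] \le \epsilon$$ where \(\hat{z}_{\tau}^{i}(Z^{S})\) denotes the predictor's reconstruction of the masked slot given only the context slots in \(S\), and \(\epsilon\) is a recovery tolerance. Under object-level masking, minimizing \(\mathcal{L}_{\text{history}}\) pushes the predictor's attention toward exactly such subsets.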