LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels 论文翻译
摘要联合嵌入预测架构JEPA为在紧凑隐空间中学习世界模型提供了极具潜力的框架但现有方法仍不稳定需依赖复杂的多损失项、指数移动平均、预训练编码器或辅助监督来避免表征坍缩。本文提出LeWorldModelLeWM首个仅用两项损失即可从原始像素端到端稳定训练的 JEPA下一时刻嵌入预测损失以及强制隐嵌入服从高斯分布的正则项。与现有唯一端到端方案相比可训练损失超参数从 6 个降至 1 个。LeWM 仅 1500 万参数单 GPU 数小时即可完成训练规划速度比基于大模型的世界模型快48 倍且在各类 2D/3D 控制任务上性能相当。除控制任务外通过物理量探测实验证明LeWM 的隐空间能编码有意义的物理结构意外值评估则验证模型可可靠检测物理上不合理的事件。1 引言人工智能的核心目标之一是让智能体通过单一统一学习范式习得多任务、多环境技能 —— 直接从环境感官输入学习无需人工设计状态表征或领域特定校准。视觉输入尤为适合相机成本低、易扩展从像素学习可实现从原始感官到动作的全端到端训练。世界模型World Models, WMs是一类强大方法可学习预测动作在环境中的后果。成功的世界模型能让智能体仅在 “想象空间” 中规划与自我提升这在离线学习场景中尤为重要智能体无需与环境交互仅从固定数据集学习利用模型生成虚拟经验并评估反事实动作序列。近期流行的世界模型学习方法是联合嵌入预测架构JEPA。JEPA 不试图建模环境所有细节而是聚焦预测未来状态所需的关键特征。具体而言JEPA 将观测编码为紧凑低维隐空间并通过预测未来观测的隐表征建模时序动态。但 JEPA 概念虽简洁现有方法极易坍缩模型将所有输入映射为几乎相同的表征以简单满足时序预测目标导致表征失效。因此防止坍缩是 JEPA 训练的核心挑战。诸多工作提出解决方案但通常依赖启发式正则、多目标损失、外部信息或预训练编码器等架构简化。实际中这些策略常引入额外不稳定性或显著提升训练复杂度。为突破上述局限本文提出 LeWM首个无需启发式、原理清晰且简洁从原始像素端到端稳定学习 JEPA 的方法。LeWM 可在单 GPU 上训练降低研究门槛。本文在 2D/3D 环境的操作、导航、运动任务上全面评估 LeWM并通过针对性探测与意外值量化实验检验其隐空间的直观物理理解能力。本文核心贡献提出一种端到端 JEPA 方法可从原始像素在单 GPU 上学习隐世界模型依赖简洁稳定的双项目标在不同架构与超参数下保持鲁棒支持高效对数时间超参数搜索。LeWM 以 1500 万参数的紧凑模型在各类 2D/3D 控制任务上取得强劲性能超越现有端到端 JEPA 方法性能与基于大模型的世界模型相当成本大幅降低规划速度最高提升48 倍。通过物理量探测与预期违背测试评估隐空间的物理理解能力验证模型可检测非物理轨迹。2 相关工作世界模型世界模型旨在从数据中学习环境动态的预测模型让智能体在想象中推理未来状态。主流世界模型包含生成式方法在像素空间显式建模环境动态基于过去状态与动作生成未来观测可作为学习型模拟器已成功用于 Minecraft、CS、Crafter 等游戏环境提升强化学习策略的样本效率。但多数生成式世界模型需奖励信号联合建模动态与价值相关信息。本文聚焦无奖励场景与 JEPA 系列工作一致仅从观测数据学习通用、任务无关的世界模型不依赖奖励监督。传统生成式世界模型Generative World Models 通常是面向强化学习RL设计的。 它们不只预测“下一帧画面会是什么样”dynamics还必须同时预测奖励reward 和 结束信号done。 为什么因为它们的最终目标是“训练一个能最大化累计奖励的策略policy”。 所以模型的输入是 (当前画面, 当前动作)输出是 (下一画面, 下一奖励, 是否结束)。 典型例子DreamerV3、IRIS、DIAMOND 等。它们在训练时会用奖励信号来更新价值函数和策略所以数据集里必须包含奖励标签。LeWorldModelLeWM走的完全是另一条路JEPA 路线 完全不依赖奖励只用 (,) 这两个东西训练。 它的目标是学一个通用的、任务无关的世界模型而不是直接学一个“为了某个具体任务拿高分”的策略。为什么 LeWM 可以没有奖励因为它只做下一状态预测next-embedding prediction编码器把每一帧画面变成一个低维向量。预测器根据动作预测。损失函数只有两项①预测误差MSE||ẑ_{t1} - z_{t1}||² ②SIGReg 正则项让 latent 分布接近高斯避免 collapse。没有奖励它是怎么“工作”和“得到反馈”的训练时完全自监督self-supervised。模型自己看数据里的 ()用“预测得准不准”作为唯一信号来更新参数。SIGReg 再强迫 latent 不能全部坍缩成一个点。使用时规划/控制 给模型一个目标画面比如你想让机器人把积木推到某个位置模型把也编码成。 然后在 latent 空间里用 CEMCross-Entropy Method优化动作序列让预测的最终状态尽量接近。 → 这里不需要任何奖励函数只需要“目标画面长什么样”就够了goal-conditioned planning。JEPAJEPA 在紧凑低维隐空间中预测系统动态演化。自 LeCun 提出后JEPA 主要沿两条路线发展自监督表征学习预测掩码输入块的隐嵌入如 I-JEPA图像、V-JEPA视频、Echo-JEPA/Brain-JEPA医疗数据。这类方法通常用 ** 目标编码器指数移动平均EMA与停止梯度SG** 稳定训练、防止坍缩但 EMA 与 SG 缺乏理论支撑通常不对应明确定义的目标最小化。基于动作的隐世界建模部分方法依赖预训练编码器获取表征虽避免坍缩但表征表达力受限于预训练编码器。PLDM 则用 VICReg 加额外正则项端到端学习表征但存在训练不稳定、可扩展性受限问题。本文提出稳定训练方案直接从像素端到端训练 JEPA仅用两项损失—— 未来嵌入预测目标与强制嵌入服从高斯分布的正则项。基于隐动态的规划世界模型开创性地从高维观测的紧凑隐表征直接学习策略。近期工作则在测试时直接在隐空间做规划MPC用世界模型在线预测候选动作序列结果并迭代优化模型保留在控制回路中支持自适应决策但计算需求更高。3 方法LeWorldModel本节介绍LeWorldModelLeWM。首先描述从离线数据中学习隐式世界模型的简化训练流程包括数据集、模型架构和训练目标随后说明如何通过 ** 模型预测控制MPC** 在隐空间进行规划从而利用训练好的模型做决策。3.1 学习隐式世界模型离线数据集本文采用完全离线、无奖励的学习设置。LeWorldModel 仅从未标注的观测与动作轨迹训练不使用奖励信号或任务指定信息。这一设定与 JEPA 系列工作一致目标是从观测数据中学习通用、与任务无关的世界模型而非针对特定任务优化行为。完全离线fully offline的含义训练阶段完全不和真实环境交互。数据集是事先收集好的固定轨迹trajectories只包含 ()没有奖励也没有任务标签。训练完以后模型就固定了参数不再更新。后面做规划planning或控制时也不用再去环境里采样新数据全部在模型的 latent 空间里“想象”完成imagination / latent planning。训练数据由长度为 T 的轨迹构成包含原始像素观测 o1:T 与对应的动作 a1:T。轨迹从任意行为策略中离线采集无需满足最优性只需充分覆盖环境动态即可。模型架构LeWM 由两个核心模块构成编码器Encoder与预测器Predictor。编码器将单帧观测 ot 映射为紧凑、低维的隐式表示 zt。预测器在隐空间建模环境动态给定当前隐嵌入 zt 与动作 at预测下一帧观测的嵌入。计算公式编码器实现为视觉 TransformerViT默认采用 tiny 配置约 500 万参数patch 大小 1412 层、3 个注意力头隐藏维度 192。观测嵌入 zt 取自最后一层的[CLS]token再经过单层 MLP 批归一化做投影。这一步是必要的因为 ViT 最后一层使用层归一化会阻碍防坍缩目标的有效优化。预测器为 6 层 Transformer16 个注意力头dropout 率 10%约 1000 万参数。动作通过自适应层归一化AdaLN融入每一层AdaLN 参数初始化为 0以稳定训练并让动作条件逐步影响预测器学习。预测器接收 N 帧历史表示带时序因果掩码自回归预测下一帧表示。预测器后同样接有与编码器结构一致的投影网络。世界模型的所有组件联合端到端学习。训练目标学习目标是得到对未来预测有用、能建模环境动态的隐式表示。LeWM 的训练损失由两项相加构成预测损失与正则化损失。预测损失教师强制计算相邻时刻预测嵌入与真实嵌入的均方误差预测损失会激励编码器学习对预测器 “可预测” 的表示。但仅用该损失会导致表征坍缩编码器把所有输入映射成相同常量。SIGReg 防坍缩正则项为避免坍缩、提升嵌入空间的特征多样性采用SIGRegSketched-Isotropic-Gaussian Regularizer强制隐嵌入服从各向同性高斯分布。高维空间直接检验正态性很困难SIGReg 通过将嵌入投影到 M 个随机单位方向对每一维投影做单变量Epps–Pulley 正态性检验再取平均。根据Cramér–Wold 定理匹配所有一维边缘分布等价于匹配完整联合分布。Z 一堆隐向量把它们随机投影到很多个一维直线上就像从不同角度看对每个方向检查是不是像高斯分布正态分布把所有方向的结果取平均总结强迫模型的输出像正态分布不要所有输出都变成一样的数。这就是防止坍缩的核心最终 LeWM 训练目标该方法仅引入两个训练超参数SIGReg 的随机投影数量 M、正则权重 λ。默认设置 M1024、λ0.1。实验表明 M 对下游性能几乎无影响因此只有 λ 需要调参可通过二分搜索高效完成对数复杂度。LeWM不使用停止梯度、指数移动平均EMA或其他稳定启发式所有损失全程回传梯度所有参数联合端到端优化流程简洁、易实现。3.2 隐空间规划推理阶段在世界模型的隐空间中做轨迹优化。给定初始观测 o1随机初始化候选动作序列迭代滚动预测隐状态至规划时域 H。模型按以下方式预测隐状态转移规划目标是最小化终端隐空间与目标的匹配误差其中 z^H 是滚动最后一步的预测隐状态zg 是目标观测 og 的编码。世界模型参数在规划时固定。这是一个有限时域最优控制问题本文使用 ** 交叉熵法CEM** 求解这是一种采样优化方法迭代选出最优策略并更新采样分布。为缓解自回归滚动带来的误差累积采用 ** 模型预测控制MPC** 策略仅执行前 K 个规划动作然后根据新观测重新规划推理 / 规划阶段latent planning对应 Figure 4给定当前画面 o₁ 和目标画面 o_g不需要奖励。编码z₁ enc(o₁)z_g enc(o_g)。用 CEMCross-Entropy Method在 latent 空间优化动作序列 a_{1:H}目标是让预测的最终状态尽量接近目标采用 MPCModel Predictive Control只执行前 K 步动作然后重新从新画面开始规划。为什么预测损失 SIGReg 就能逼出物理规律想象一下模型在训练时“看到”的数据以 Push-T 为例每一条轨迹都是真实物理模拟器生成的小球推 T 形积木时积木只会按照真实物理移动不能穿墙、会旋转、会滑行、有摩擦……。模型每次都看到 (当前画面 o_t, 动作 a_t, 下一画面 o_{t1})。训练过程可以简化成一句话“我要把压成把压成然后我要让 Predictor 能从 (z_t a_t) 准确猜出 z_{t1}。如果猜不对我就罚它”如果 Encoder 把位置、速度、角度、物体关系这些真正能决定下一帧的物理量 好好地编进 z 里Predictor 就很容易预测损失就小。如果 Encoder 把这些物理量扔掉、或者把所有画面都压成差不多一样的向量collapsePredictor 就完全预测不准损失就会很大。所以预测损失像一个“物理老师”它只奖励那些最容易预测未来的表示方式。 而最容易预测未来的表示恰恰就是低维的物理状态变量agent 位置、block 位置、block 角度、速度……。SIGReg 再加一道保险 它强迫所有 z 必须铺满一个均匀的高斯球面不能都挤成一个点。 这样就逼着 Encoder 必须把不同物理状态区分开不能偷懒。结果经过几万条真实物理轨迹的训练潜空间 z 就自动变成了一个压缩版的物理世界——里面天然包含了位置、速度、碰撞关系等信息。4 隐式规划性能4.1 规划评估设置环境我们在一系列多样化任务上对 LeWM 进行评估包括导航、运动规划与操作覆盖二维与三维环境所有环境均为连续动作空间。Two-Room简单的 2D 导航任务PushT经典 2D 机器人操作任务OGBench-Cube视觉更丰富的 3D 操作任务Reacher2D 双关节机械臂到达任务基线方法我们将 LeWM 与以下基线对比DINO-WM、PLDM当前最先进的基于 JEPA 的方法GCBC目标条件行为克隆策略GCIVL、GCIQL目标条件离线强化学习算法其中PLDM与本文设置最接近同样直接从像素观测端到端学习世界模型但依赖基于 VICReg 的七项训练目标导致训练不稳定、超参调优复杂。DINO-WM使用冻结的 DINOv2 作为特征编码器以缓解表征坍缩但并非端到端学习。为公平对比实验中 DINO-WM 不使用本体感受信息。所有方法在所有环境上均使用固定超参数。4.2 面向高效世界模型规划实验结果表明LeWM 在更具挑战性的规划任务上显著优于 PLDM在 PushT 任务上成功率高出 18%并与 DINO-WM 性能相当。值得注意的是在 PushT 上仅使用像素的 LeWM 甚至超过了带有额外本体感受信息的 DINO-WM证明 LeWM 能够有效捕捉任务相关的核心物理量。在最简单的 Two-Room 环境中LeWM 表现略差原因是该数据集多样性低、内在维度小SIGReg 正则在高维隐空间强制高斯分布可能导致隐表征结构不够理想。规划速度LeWM 的规划速度最高提升 48 倍完整规划可在1 秒内完成在各环境上表现稳定大幅接近实时控制的要求。4.3 面向世界模型的稳定训练消融实验我们对 LeWM 的关键设计进行消融分析SIGReg 内部参数随机投影数量 M、积分节点数对性能几乎无影响说明无需精细调优λ 是唯一有效超参数。超参搜索效率LeWM 仅需调 1 个超参可用二分搜索O (log n)高效完成PLDM 需调 6 个超参只能多项式搜索O (n⁶)。嵌入维度维度需足够大才能保证性能但超过阈值后快速饱和说明方法对编码器容量不敏感。编码器架构将默认 ViT 替换为 ResNet-18 仍能达到有竞争力的性能说明 LeWM 对视觉编码器不敏感。训练曲线LeWM两项损失目标收敛平滑、单调预测损失稳步下降SIGReg 损失在训练初期快速下降后趋于平稳隐分布快速逼近各向同性高斯。PLDM七项损失目标噪声大、非单调多个正则项相互竞争梯度难以平衡。结果充分体现 LeWM 的核心优势训练目标极简、过程高度稳定。5 量化 LeWM 中的物理理解能力本章将通过从隐表示中提取物理量、测量世界模型对物理变化的检测能力这两种方式评估 LeWM 隐空间所捕获的动态过程质量。Figure 7:Predictor rollouts on PushT and OGBench-Cube. We visualize decoded latent plans produced by LeWM given a context and an action sequence. Each rollout uses three image observations as context, which are encoded into latent representations. Conditioned on the action sequence, the predictor autoregressively generates future latent states in an open-loop manner. All predicted latents are decoded into images using a decoder that was not used during training. The resulting imagined rollouts closely match the real observations, demonstrating that the latent representation effectively captures the overall scene structure and essential environment dynamics. Some finer details, however, are not fully captured by LeWM; for instance, the angle of the end-effector in OGBench-Cube. Additional rollouts are provided in Fig. 11.Figure 8:Decoder visualization during training. As training progresses, the latent representation increasingly captures the information required to reconstruct the visual scene, even though no reconstruction loss is used during training. Early in training, the decoded images correspond to slow features, a phenomenon previously reported [49].Figure 9:Visualization of the latent space obtained with LeWM for the PushT environment. On the left, the grid of states is obtained by moving the agent and the block in the x-y plane. On the right, the embeddings of these states are visualized using a t-SNE.5.1 隐空间的物理结构物理量探测Probing physical quantities作为衡量物理理解的首要指标我们评估可以从 LeWM 的隐表示中恢复出哪些物理量。我们分别训练线性探针与非线性探针从给定的隐嵌入中预测目标物理量。在 Push‑T 环境上的结果表 1显示我们的方法持续优于 PLDM与 DINOv2 这类大规模预训练模型得到的表示相比性能相当DINO‑WM 在部分物理属性上的强劲表现可能来自其基座模型的大规模预训练约 1.24 亿张图像使其嵌入天然包含部分物理属性。隐空间解码为进一步评估隐表示所捕获的信息我们训练了一个解码器用于从单个隐嵌入192 维重建像素观测。尽管训练过程从未使用重建损失解码器依然能从学习到的表示中恢复出视觉场景证明低维紧凑的隐空间保留了足够的底层物理状态信息。隐空间可视化我们使用 t‑SNE 对隐空间结构做可视化。结果表明学习到的表示捕获了环境的空间结构在隐空间中保留了邻域关系与相对位置。时序隐路径拉直Temporal Latent Path Straightening受神经科学中的时序拉直假说启发我们测量训练过程中连续隐速度向量之间的余弦相似度。结果发现LeWM 的隐轨迹在训练中自然变得越来越平直且没有任何显式正则项鼓励这一行为即使 PLDM 使用了专门的时序平滑正则项LeWM 依然实现了更高的时序平直度这一现象是自发涌现的对下游规划任务有益。5.2 预期违背VoE评估框架另一种量化物理理解的方式是检测与所学世界模型相矛盾的事件的能力。本文采用发展心理学中常用的预期违背VoE范式评估模型是否会对违背物理规律的事件赋予更高的意外值surprise。我们用预测值与真实未来观测之间的偏差来量化意外值。实验在三个环境中进行Two‑Room、PushT、OGBench‑Cube。对每个环境我们引入两类扰动视觉扰动物体颜色突然改变物理扰动物体被瞬间传送到随机位置违背场景的物理连续性结果图 10清晰表明LeWM 对包含物理违背的帧 consistently 赋予显著更高的意外值对未扰动轨迹意外值保持低基线对颜色这类视觉扰动意外值上升微弱且不显著这说明模型对物理扰动远比对视觉扰动更敏感其隐空间真正学习到了环境的物理规则而非仅记忆视觉外观6 结论本文提出了LeWorldModelLeWM一种用于学习环境隐式世界模型的稳定端到端方法。LeWM 属于联合嵌入预测架构JEPA它使用编码器将图像观测映射到隐空间并通过预测器在嵌入空间中以动作为条件建模时序动态。在仅使用原始像素输入的多种连续控制环境中LeWM 在数据效率、规划速度、训练时长与稳定性上均超越了已有方法同时保持具有竞争力的最终任务性能。训练的稳定性与简洁性来自于显式地约束隐嵌入服从各向同性高斯分布以避免表征坍缩。总体而言LeWM 为现有隐式世界模型方法提供了一种可扩展的替代方案兼具原理清晰的训练动态以及可解释、可涌现的表示特性。局限与未来工作尽管取得了上述可喜成果本文仍存在一些局限也指明了重要的研究方向长时序规划能力有限当前基于隐式世界模型的规划仍局限于较短的时域。层级世界建模是解决长时序推理与规划的一个很有前景的方向。对数据多样性有依赖本方法仍依赖具有足够交互覆盖度的离线数据集这类数据的采集成本较高。在极低复杂度、内在维度很小的环境中SIGReg 正则在高维隐空间中强制匹配高斯先验会变得困难从而影响效果。在大规模、多样化的自然视频数据集上进行预训练有望提供更强的表示先验降低对领域特定数据的依赖。依赖动作标注现有端到端隐式世界模型需要动作标签来预测未来状态这类标注同样获取成本较高。一个有前景的方向是通过逆动态建模学习未来动作表示从而减少对显式动作标注的依赖。一些看完论文之后的小疑问这些任务的数据集里有没有前置设置的物理规则有但不是直接写在数据集文件里而是隐含在“生成数据集的模拟器”里。所有四个任务的数据集都是完全离线的里面只有两样东西原始像素画面序列对应的动作序列没有任何奖励、任务标签、物理参数、规则描述。但是这些数据不是随机生成的而是在一个固定物理模拟器里用策略heuristic 或 RL 策略跑出来的真实轨迹。