Skip to main content

JEPA Overview

JEPA、つまり Joint Embedding Predictive Architecture は、入力の一部から別の一部を pixel ではなく embedding space で予測する self-supervised learning framework です。Yann LeCun が提唱する energy-based / predictive learning の流れにあり、world model の representation learning と強く関係します。

JEPA architecture map

自作概念図。JEPA は context の embedding から target の embedding を予測し、pixel-level reconstruction を避けます。

基本 idea

JEPA では、入力を context と target に分けます。Context は model に見せる部分、target は隠す部分または未来の部分です。

重要なのは、target を pixel そのものとして復元するのではなく、target encoder が出した abstract representation を予測することです。

なぜ pixel reconstruction を避けるのか

Pixel-level reconstruction は、background texture や照明の細かい変化など、意味理解に不要な detail まで再現しようとします。JEPA は、embedding space で予測することで、semantic / physical に重要な情報に focus します。

Approach予測対象特徴
Masked AutoencoderPixels / patchesLow-level detail まで復元しやすい
Contrastive learningPositive / negative pairNegative sample 設計が必要
JEPATarget embeddingAbstract な representation を予測

Collapse をどう避けるか

Embedding prediction では、すべての input を同じ vector にしてしまう collapse が問題になります。JEPA 系では、target encoder を stop-gradient / EMA で更新したり、predictor を入れたりして、collapse を防ぎます。

World Model との関係

JEPA は、環境の未来を pixel で生成するのではなく、future representation を予測する world model として解釈できます。

これが World Models との接点です。

数式で見る JEPA の予測目的

JEPA は、入力の一部から別の部分の representation を予測します。Context encoder を fθf_\theta、target encoder を fθˉf_{\bar{\theta}}、predictor を gϕg_\phi とします。Context view を xcx_c、target view を xtx_t とすると、典型的な loss は次のように書けます。

LJEPA=gϕ(fθ(xc))sg(fθˉ(xt))22\mathcal{L}_{\mathrm{JEPA}} =\left\|g_\phi(f_\theta(x_c))-\mathrm{sg}(f_{\bar{\theta}}(x_t))\right\|_2^2

ここで、sg\mathrm{sg} は stop-gradient です。この式の気持ちは、「pixel をそのまま復元するのではなく、抽象化された target representation を予測する」というものです。Pixel-level reconstruction よりも、semantic な情報を学びやすいことが期待されます。

関連ページ

主なソース