Skip to main content

I-JEPA

I-JEPA は、image を対象にした Joint Embedding Predictive Architecture です。Image の一部を context として見せ、隠された target block の representation を embedding space で予測します。

Masked image modeling との違い

I-JEPA は MAE のように image patch を mask しますが、target pixel を復元するわけではありません。

Pixel reconstruction ではなく feature prediction を行うため、texture より semantic な structure を捉えやすくなります。

Context block と target block

I-JEPA では、context block と target block の位置をランダムに選びます。Context から target を予測するには、object layout や scene structure を理解する必要があります。

なぜ semantic representation が得られるのか

Target encoder が抽象的な representation を出すため、predictor は low-level pixel を完全に復元する必要がありません。これにより、class や object-level な情報が embedding に入りやすくなります。

数式で見る I-JEPA

I-JEPA では、画像から context block と複数の target block を切り出します。Context encoder の出力を zc=fθ(xc)z_c=f_\theta(x_c)、target block jj の中の位置 ii に対する teacher representation を zt,i=fθˉ(xt,i)z_{t,i}=f_{\bar{\theta}}(x_{t,i}) とします。Predictor gϕg_\phi は、context representation と target の位置情報 mi,jm_{i,j} から、target representation を予測します。

LI-JEPA=j=1MiBjgϕ(zc,mi,j)sg(zt,i)22\mathcal{L}_{\mathrm{I\text{-}JEPA}} = \sum_{j=1}^{M}\sum_{i\in\mathcal{B}_j} \left\| g_\phi(z_c,m_{i,j}) - \mathrm{sg}(z_{t,i}) \right\|_2^2

ここで、Bj\mathcal{B}_jjj 番目の target block に含まれる patch index の集合であり、MM は target block の数です。mi,jm_{i,j} は「どの位置の target を予測しているか」を predictor に伝えるための mask token または位置埋め込みです。

この式の気持ちは、「隠れた patch の RGB 値を復元するのではなく、target encoder が見たときの抽象表現を当てる」というものです。そのため、texture や細かな色の揺らぎを完全に再現する必要がなく、object layout や semantic なまとまりを表現に入れやすくなります。sg()\mathrm{sg}(\cdot) により target 側を固定しているので、student と teacher が同時に同じ方向へ崩れることも抑えられます。

関連ページ

主なソース