Skip to main content

Contrastive Learning Overview

Contrastive Learning は、同じ image の異なる augmentation 同士は近く、異なる image 同士は遠くなるように representation を学ぶ self-supervised learning です。Vision foundation model の前半期 (2019-2021) を牽引しました。

基本 idea

InfoNCE loss が代表的です。

数式で見る InfoNCE と alignment-uniformity

Contrastive learning では、anchor view ii と positive view jj を近づけ、同じ batch または queue にある他の view を相対的に遠ざけます。Cosine similarity を sim(,)\mathrm{sim}(\cdot,\cdot)、temperature を τ\tau とすると、InfoNCE loss は次のように書けます。

Li,j=logexp(sim(zi,zj)/τ)kiexp(sim(zi,zk)/τ)\mathcal{L}_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i,z_j)/\tau)} {\sum_{k\ne i}\exp(\mathrm{sim}(z_i,z_k)/\tau)}

ここで、ziz_izjz_j は同じ画像から作った二つの view の representation です。分子は positive pair を近づける力を表し、分母は batch 内の他の view を含む正規化項です。この式の気持ちは、「同じ画像に由来する view は同じ意味を持つので近くに置き、別の画像に由来する view は混ざらないように表現空間を広げる」というものです。

この性質は、alignment と uniformity という二つの観点でも整理できます。

Lalign=E(x,x+)f(x)f(x+)22\mathcal{L}_{\mathrm{align}} = \mathbb{E}_{(x,x^+)} \left\|f(x)-f(x^+)\right\|_2^2 Luniform=logE(x,y)exp ⁣(tf(x)f(y)22)\mathcal{L}_{\mathrm{uniform}} = \log \mathbb{E}_{(x,y)} \exp\!\left(-t\left\|f(x)-f(y)\right\|_2^2\right)

Lalign\mathcal{L}_{\mathrm{align}} は positive pair をどれだけ近づけるかを表し、Luniform\mathcal{L}_{\mathrm{uniform}} は異なる sample の representation が hypersphere 上でどれだけ一様に散らばるかを表します。Contrastive learning の難しさは、この二つのバランスにあります。Alignment だけを強めると collapse に近づき、uniformity だけを強めると意味的に近い sample まで離れすぎます。

主要 family

種別代表例特徴
ContrastiveSimCLR、MoCoNegative sample を明示的に使う
Non-contrastiveBYOL、SimSiamNegative なしで collapse を回避
ClusteringSwAVOnline clustering を contrastive 化

詳細ページ

ページ内容
SimCLRAugmentation と large batch の contrastive
MoCoMomentum encoder + memory queue
BYOL and SimSiamNegative 不要の self-distillation 系

CLIP との関係

CLIP は image-text contrastive ですが、image-image contrastive と同じ family に属します。

JEPA との違い

JEPA は contrastive と違い、negative や augmentation invariance を直接使わず、target representation を予測します。ただし、JEPA でも collapse 回避のために stop-gradient / EMA が使われるため、設計思想は self-distillation 系と近接しています。

関連ページ

主なソース