Skip to main content

SAM

SAM、つまり Segment Anything Model は、image と prompt を受け取り、対応する mask を返す segmentation model です。Meta によって 2023 年に提案され、promptable segmentation という枠組みを定着させました。

全体構成

SAM は、大きく分けて三つの component から構成されます。

  1. Image encoder: image を一度だけ重い backbone (ViT) に通して、embedding を作ります。
  2. Prompt encoder: point、box、rough mask などの prompt を token に encode します。
  3. Mask decoder: image embedding と prompt token を組み合わせ、軽量に mask を予測します。

この分離が重要です。Image encoder は重いですが、同じ image に対しては embedding を再利用できるため、prompt を変えながら mask を高速に取り直せます。

Prompt の種類

SAM は次のような prompt を受け付けます。

Prompt役割
Point (foreground)対象 object の内側にある点
Point (background)対象から除外したい領域の点
Box対象を囲む bounding box
Mask粗い mask (前回の出力を refine するときに使える)

Interactive segmentation tool では、user が画像を click する操作が point prompt に対応します。

Ambiguity-aware design

「この点を含む object を segment して」という prompt は、しばしば曖昧です。例えば人物を click すると、

  • 人全体
  • 着ている服
  • 服の柄

のどれも正解になりえます。SAM はこれに対して、1 prompt から複数 mask を返す設計を採用しています。

User や下流 task は、信頼度 score の高いものや、所望の粒度のものを選びます。これは曖昧さを model 側で吸収する、よくできた design です。

Zero-shot generalization

SAM は、特定 dataset 向けに fine-tune されたものではなく、SA-1B という大規模 dataset 上で training された結果として、多様な image domain で zero-shot に動作します。Medical image、satellite image、artwork のような、training data に直接含まれない domain でも、ある程度の segmentation が得られます。

これは、generative model の世界での「foundation model」と同じ位置づけの考え方です。

何が SAM の貢献か

整理すると、SAM の貢献は次の三点に要約できます。

  • Segmentation を promptable な task として再定義した
  • 大規模な promptable segmentation dataset (SA-1B) を構築・公開した
  • Image embedding を一度だけ計算する interactive-friendly な architecture を作った

これにより、segmentation は「特定 class を切り出す classifier」から「任意の object に対する universal な mask producer」へと位置づけが変わりました。

数式で見る promptable segmentation

SAM は、image II と prompt pp(point、box、mask、text)を入力として、mask M^\hat{M} を出力する関数として書けます。

M^=fθ(I,p)\hat{M}=f_\theta(I,p)

Mask は pixel ごとの確率として表されます。

M^(x)[0,1]\hat{M}(\mathbf{x})\in[0,1]

Training では、focal loss と Dice loss を組み合わせます。Focal loss は、easy negative の重みを下げて難しい pixel に学習を集中させます。

Lfocal=1Ωxαt(1pt)γlogpt\mathcal{L}_{\mathrm{focal}} =-\frac{1}{|\Omega|}\sum_{\mathbf{x}}\alpha_t(1-p_t)^\gamma\log p_t

Dice loss は、領域全体の重なりを直接最適化します。

LDice=12xM^(x)M(x)xM^(x)+xM(x)\mathcal{L}_{\mathrm{Dice}} =1-\frac{2\sum_{\mathbf{x}}\hat{M}(\mathbf{x})M(\mathbf{x})} {\sum_{\mathbf{x}}\hat{M}(\mathbf{x})+\sum_{\mathbf{x}}M(\mathbf{x})}

総合 loss は次のように書けます。

L=λfLfocal+λdLDice\mathcal{L}=\lambda_f\mathcal{L}_{\mathrm{focal}}+\lambda_d\mathcal{L}_{\mathrm{Dice}}

この式の気持ちは、「pixel 単位の難しさに合わせて焦点を絞る focal loss と、領域全体の重なりを直接合わせる Dice loss を組み合わせて、boundary も領域も同時に学ぶ」というものです。

数式で見る ambiguous prompt の扱い

一つの prompt から、複数の妥当な mask が考えられる場合があります。たとえば人物の上に点を打つと、髪、上半身、人物全体のいずれも妥当な解になりえます。SAM は、KK 個の mask 候補を同時に出し、それぞれに confidence ckc_k を予測します。

{(M^k,ck)}k=1K=fθ(I,p)\{(\hat{M}_k,c_k)\}_{k=1}^{K}=f_\theta(I,p)

Training では、KK 個の予測のうち、ground truth に最も近いものだけに loss を流します。

k=argminkL(M^k,M),Ltotal=L(M^k,M)+Lconf(ck,IoUk)k^*=\arg\min_k \mathcal{L}(\hat{M}_k,M^*), \qquad \mathcal{L}_{\mathrm{total}}=\mathcal{L}(\hat{M}_{k^*},M^*)+\mathcal{L}_{\mathrm{conf}}(c_{k^*},\mathrm{IoU}_{k^*})

この式の気持ちは、「曖昧な prompt に対して無理に一つの正解を押し付けるのではなく、複数の候補を出し、最も近いものを伸ばし、confidence で順位付けする」というものです。

関連ページ

主なソース