Skip to main content

Wasserstein GAN

Wasserstein GAN(以下 WGAN)は、vanilla GAN の Jensen-Shannon Divergence に基づく見方を、Wasserstein Distance に基づく見方へ置き換えることで、GAN training を安定させようとする方法です。

なぜ Wasserstein Distance を使うのか

Vanilla GAN では、Discriminator が optimal であると仮定すると、Generator は real data distribution prp_r と generated distribution pgp_g の JS Divergence を小さくするように training されます。

しかし、高次元空間では prp_rpgp_g の support が重なりにくく、JS Divergence は有用な gradient を与えにくくなります。

Wasserstein Distance は、二つの distribution が overlap していない場合でも、距離を滑らかに表せます。そのため、Generator にとって、どの方向へ動けば real distribution に近づくのかが学びやすくなります。

Kantorovich-Rubinstein duality

Wasserstein Distance を直接計算するには、Π(pr,pg)\Pi(p_r, p_g) に含まれるすべての joint distribution を探索する必要があります。これは現実的ではありません。

そこで WGAN では、Kantorovich-Rubinstein duality によって、Wasserstein Distance を次の形に書き換えます。

W(pr,pg)=1KsupfLKExpr[f(x)]Expg[f(x)]W(p_r, p_g) = \frac{1}{K}\sup_{\|f\|_L \leq K} \mathbb{E}_{x \sim p_r}\left[f(x)\right] - \mathbb{E}_{x \sim p_g}\left[f(x)\right]

ここで、sup\sup は supremum、つまり上限を意味します。ffKK-Lipschitz continuous な function である必要があります。

Lipschitz constraint

Function f:RRf : \mathbb{R} \to \mathbb{R}KK-Lipschitz continuous であるとは、ある実定数 K0K \geq 0 が存在して、すべての x1,x2Rx_1, x_2 \in \mathbb{R} に対して次の不等式が成り立つことを意味します。

f(x1)f(x2)Kx1x2|f(x_1) - f(x_2)| \leq K |x_1 - x_2|

この条件は、function の変化が急激になりすぎないことを表します。WGAN では、Critic が学習する function がこの条件を満たす必要があります。

WGAN における Critic

WGAN では、従来の Discriminator は、real か fake かを直接分類する binary classifier ではなくなります。代わりに、Wasserstein Distance を推定するための function fwf_w を学習する Critic として使われます。

WGAN の objective は、概念的には次の形になります。

L(pr,pg)=W(pr,pg)=maxwWExpr[fw(x)]Ezpz(z)[fw(gθ(z))]L(p_r, p_g) = W(p_r, p_g) = \max_{w \in W}\mathbb{E}_{x \sim p_r}\left[f_w(x)\right] - \mathbb{E}_{z \sim p_z(z)}\left[f_w(g_\theta(z))\right]

Critic は、この差を大きくするように training されます。Generator は、この差を小さくするように training されます。training が進んで Wasserstein Distance が小さくなれば、generated distribution pgp_g は real distribution prp_r に近づいていると解釈できます。

Weight clipping

WGAN では、Critic が KK-Lipschitz continuous であることを training の間も維持する必要があります。Original WGAN では、実用的な方法として weight clipping が使われます。Critic の gradient update のたびに、weight ww を小さな範囲、たとえば [0.01,0.01][-0.01, 0.01] に収めます。

wclip(w,c,c)w \leftarrow \mathrm{clip}(w, -c, c)

これによって parameter space を compact に保ち、Critic の function が過度に急激に変化することを防ごうとします。

Algorithm

WGAN algorithm

画像出典: Lilian Weng, “From GAN to WGAN”。Original WGAN の training algorithm では、Critic を複数回更新し、そのたびに weight clipping を行います。

Vanilla GAN からの主な変更点

WGAN では、vanilla GAN と比べて次のような変更があります。

  • Discriminator は binary classifier ではなく、Wasserstein Distance を推定する Critic として扱われます。
  • Loss function では logarithm を使わず、Wasserstein Distance に由来する objective を使います。
  • Critic の gradient update のたびに、weight を [c,c][-c, c] の範囲へ clipping します。
  • Original WGAN では、momentum を持つ optimizer の代わりに RMSProp が推奨されています。

限界と WGAN-GP

WGAN は vanilla GAN の問題を改善しますが、完全な解決策ではありません。特に、weight clipping は Lipschitz constraint を満たすための粗い方法です。Clipping の範囲が広すぎると training が不安定になり、狭すぎると Critic の表現力が落ちて gradient が弱くなります。

この問題を改善する方法として、WGAN-GP では weight clipping の代わりに gradient penalty が使われます。Gradient penalty は、Critic の gradient norm を制御することで、Lipschitz constraint をより自然に扱おうとします。

数式で見る Wasserstein distance

WGAN は、Jensen-Shannon divergence の代わりに Wasserstein-1 distance を最小化します。Kantorovich-Rubinstein duality により、Wasserstein-1 distance は次のように書けます。

W(pr,pg)=supfL1Expr[f(x)]Expg[f(x)]W(p_r,p_g)=\sup_{\|f\|_L\le 1} \mathbb{E}_{x\sim p_r}[f(x)]- \mathbb{E}_{x\sim p_g}[f(x)]

ここで、ff は 1-Lipschitz な critic です。この式の気持ちは、「real sample には高い score を、generated sample には低い score を付ける critic を探し、その差を generator が小さくする」というものです。

WGAN-GP では、critic が 1-Lipschitz に近くなるように gradient penalty を入れます。

LGP=λEx^[(x^D(x^)21)2]\mathcal{L}_{GP}=\lambda\mathbb{E}_{\hat{x}} \left[(\|\nabla_{\hat{x}}D(\hat{x})\|_2-1)^2\right]

この式は、「critic の勾配 norm が 1 から大きく外れないようにする」制約です。Weight clipping よりも滑らかに Lipschitz 条件を促せるため、training が安定しやすくなります。

関連ページ