Skip to main content

GAN Training Techniques

GAN の training instability を緩和するために、さまざまな practical technique が提案されています。ここでは、Lilian Weng の記事で紹介されている代表的な方法を、知識要素として整理します。

Feature Matching

Feature Matching は、Generator の出力が real sample の統計量に近づくように training する方法です。Discriminator の中間 layer などから feature f(x)f(x) を取り出し、real sample の feature の期待値と generated sample の feature の期待値が近くなるように Generator を更新します。

Exprf(x)Ezpz(z)f(G(z))22\left\| \mathbb{E}_{x \sim p_r} f(x) - \mathbb{E}_{z \sim p_z(z)} f(G(z)) \right\|_2^2

Generator は、Discriminator の最終判定だけをだますのではなく、real data の feature statistics に近い sample を作るように促されます。

Minibatch Discrimination

Minibatch Discrimination は、Discriminator が sample を一つずつ独立に見るのではなく、同じ minibatch 内の sample 同士の関係も見るようにする方法です。

一つの minibatch の中で、sample のペア (xi,xj)(x_i, x_j) の近さ c(xi,xj)c(x_i, x_j) を近似的に計算し、ある data point についての要約を次のように作ります。

o(xi)=jc(xi,xj)o(x_i) = \sum_j c(x_i, x_j)

そして、o(xi)o(x_i) を Discriminator の入力に明示的に加えます。これによって、同じような sample ばかりが並んでいる状態を検出しやすくなり、Mode Collapse を緩和しやすくなります。

Historical Averaging

Historical Averaging は、model parameter が過去の平均から大きく離れすぎないように penalty を加える方法です。現在の parameter を Θ\Theta、過去の時刻 ii における parameter を Θi\Theta_i とすると、次のような項を loss に加えます。

Θ1ti=1tΘi2\left\| \Theta - \frac{1}{t}\sum_{i=1}^{t}\Theta_i \right\|^2

この penalty は、parameter が時間とともに激しく変化することを抑え、training の振動を緩和する効果が期待されます。

One-sided Label Smoothing

One-sided Label Smoothing は、Discriminator に与える label を 1100 の hard な値ではなく、0.90.90.10.1 のような少し柔らかい値にする方法です。

この方法は、Discriminator が過度に自信を持ちすぎることを防ぎます。GAN では Discriminator が強すぎると Generator の gradient が弱くなるため、label smoothing は training を安定させる方向に働きます。

Virtual Batch Normalization

Virtual Batch Normalization(VBN)は、各 sample を現在の minibatch の統計量だけで normalize するのではなく、固定された reference batch に基づいて normalize する方法です。Reference batch は training の最初に一度だけ選ばれ、その後は固定されます。

通常の Batch Normalization では、同じ sample であっても minibatch の内容によって normalize の結果が変わります。VBN は、この依存性を弱め、training を安定させる目的で使われます。

Input noise

Real distribution prp_r と generated distribution pgp_g の support が disjoint であると、Discriminator は二つを完全に分けやすくなります。その結果、Vanishing Gradient in GAN が起きやすくなります。

この問題への一つの対処法は、Discriminator の input に continuous noise を加えることです。Noise を加えると distribution が人工的に広がり、prp_rpgp_g の overlap が増える可能性があります。Overlap が増えれば、Discriminator が完全に分離することが難しくなり、Generator が有用な gradient を受け取りやすくなります。

より良い metric を使う

Vanilla GAN の loss は、optimal Discriminator のもとでは Jensen-Shannon Divergence と関係します。しかし、support が重ならない場合、JS Divergence は有用な training signal を提供しにくくなります。

この問題を避けるために、Wasserstein GAN では Wasserstein Distance が使われます。

数式で見る non-saturating loss と regularization

Vanilla GAN の generator loss は saturation しやすいため、実用では non-saturating loss がよく使われます。

LGNS=Ezpz[logD(G(z))]\mathcal{L}_G^{NS}=-\mathbb{E}_{z\sim p_z}\left[\log D(G(z))\right]

この式の気持ちは、「fake sample を real と判定させる確率を直接上げる」ことです。Discriminator が強すぎて D(G(z))0D(G(z))\approx 0 でも、gradient が比較的残りやすくなります。

Discriminator の過学習を抑えるために、R1 regularization も使われます。

LR1=γ2Expr[xD(x)2]\mathcal{L}_{R1}=\frac{\gamma}{2}\mathbb{E}_{x\sim p_r}\left[\|\nabla_x D(x)\|^2\right]

この式は、「real data 近傍で discriminator の勾配を大きくしすぎない」ことを促します。Discriminator が細かい artifact に過剰適合するのを抑え、generator により滑らかな signal を返す効果があります。

関連ページ