Skip to main content

Generative Adversarial Network

Generative Adversarial Network(以下 GAN)は、Generator と Discriminator という二つの model を競わせながら training する generative model です。画像、自然言語、音楽のような現実世界の豊かな content を生成するために使われます。

基本構造

GAN は、次の二つの model から構成されます。

Model役割
Generator GGnoise variable zz を受け取り、synthetic sample G(z)G(z) を生成します。
Discriminator DD与えられた sample が real dataset から来たものか、Generator が作った fake sample なのかを判定します。

Generator は、Discriminator をだませるほど本物らしい sample を作ろうとします。Discriminator は、real sample と fake sample を正しく見分けようとします。この二つの model の競争によって、Generator は real data distribution に近い sample を生成する方向へ更新されます。

GAN の構造

画像出典: Lilian Weng, “From GAN to WGAN”。Generator が noise から fake sample を作り、Discriminator が real sample と fake sample を見分ける構造が示されています。

記号

Symbol意味
pzp_znoise input zz の distribution です。多くの場合には uniform distribution が使われます。
pgp_gGenerator が作る data xx の distribution です。
prp_rreal sample xx の data distribution です。

Minimax objective

Discriminator は、real sample に対して高い確率を出し、fake sample に対して低い確率を出すように training されます。一方で、Generator は、fake sample に対して Discriminator が高い確率を出すように training されます。

この関係は、次の minimax game として書けます。

minGmaxDL(D,G)=Expr(x) ⁣[logD(x)]+Ezpz(z) ⁣[log(1D(G(z)))]\min_G \max_D L(D, G) = \mathbb{E}_{x \sim p_r(x)}\!\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\!\left[\log\bigl(1 - D(G(z))\bigr)\right]

Generator が作った distribution pgp_g を使うと、次のようにも書けます。

L(D,G)=Expr(x) ⁣[logD(x)]+Expg(x) ⁣[log(1D(x))]L(D, G) = \mathbb{E}_{x \sim p_r(x)}\!\left[\log D(x)\right] + \mathbb{E}_{x \sim p_g(x)}\!\left[\log\bigl(1 - D(x)\bigr)\right]

第一項は、Discriminator が real sample を正しく real と判断することを促します。第二項は、Discriminator が fake sample を正しく fake と判断することを促します。Generator は、この第二項を小さくする方向に更新されます。

Optimal Discriminator

Generator を固定したとき、Discriminator にとって最適な出力は次の形になります。

D(x)=pr(x)pr(x)+pg(x)D^{*}(x) = \frac{p_r(x)}{p_r(x) + p_g(x)}

ある xx が real distribution の中でよく現れ、generated distribution の中ではあまり現れない場合、D(x)D^{*}(x)11 に近づきます。反対に、ある xx が generated distribution の中でよく現れ、real distribution の中ではあまり現れない場合、D(x)D^{*}(x)00 に近づきます。

Generator が十分に良くなり、pg=prp_g = p_r が成り立つと、Discriminator は real sample と fake sample を区別できなくなります。このとき、すべての xx について D(x)=12D^{*}(x) = \frac{1}{2} になります。

Global optimum

pg=prp_g = p_r であり、D(x)=12D^{*}(x) = \frac{1}{2} であるとき、GAN は global optimum に到達しています。このとき、loss は次の値になります。

L(G,D)=2log2L(G, D^{*}) = -2\log 2

この状態では、Generator は real data distribution を再現しており、Discriminator は coin flip と同程度の判断しかできません。

Loss と JS Divergence の関係

Discriminator が optimal であるとき、GAN の loss は Jensen-Shannon Divergence と次のように関係します。

L(G,D)=2DJS(prpg)2log2L(G, D^{*}) = 2D_{JS}(p_r \,\|\, p_g) - 2\log 2

つまり、vanilla GAN は、pgp_gprp_r に近づけるように JS Divergence を小さくしていると解釈できます。ただし、real distribution と generated distribution の support が重ならない場合には、この見方が training の難しさにつながります。

関連ページ