1.【生成模型开山之作】GAN 论文精读

约 1762 字大约 6 分钟...

1.【生成模型开山之作】GAN 论文精读

原文连接:Generative Adversarial Netsopen in new window 源码连接:http://www.github.com/goodfeli/adversarialopen in new window

0. 核心总结

本文的核心是通过训练一组生成器-辨别器的对抗网络的思想去得到我们想要的生成网络,本文设计了一个精妙的损失函数和训练算法,使得生成器能最终能收敛到真实数据分布,辨别器能收敛到50%的辨别概率。

1. 摘要

GAN的框架包含两个模型:一个生成模型GG和辨别模型DD,通过同时训练这两个模型进行对抗, GG的目标是通过捕获数据分布来使DD犯错的概率最大,而DD的目标是估计数据是来自于训练集而不是GG。随着对抗训练的完成,GG能够恢复训练集的分布且DD的分辨概率为50%。

模型GGDD都是由神经网络构建,并且通过反向传播来训练。

2. 对抗网络

GAN网络的目标通过数学的表示是:

  • 生成模型GG:输入一个噪声随机变量pz(z)p_z(z),最优化参数θg\theta_g,使得模型G(z;θg)G(z;\theta_g)能够将输入映射到真实数据x\pmb{x}
  • 辨别模型DD:输入一个数据x\pmb{x},最优化参数θd\theta_d,输出是真实数据的概率

于是可以构建如下损失函数V(G,D)V(G,D)来对两个参数进行最优化:

minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))] \min\limits_G \max\limits_D V(D,G)=\mathbb{E}_{\pmb{x}\sim p_{\text{data}(\pmb{x})}}[\log D(\pmb{x})]+\mathbb{E}_{z\sim p_z(z)}[\log (1-D(G(z)))]

根据以下公式,期望的梯度等于原梯度,所以可以使用梯度下降法来进行反向传播:

limσ0xEϵN(0,σI)f(x+ϵ)=xf(x) \lim\limits_{\sigma \to 0}\nabla_{\pmb{x}}\mathbb{E}_{\epsilon\sim N(0,\sigma \pmb{I})} f(\pmb{x}+\epsilon) = \nabla_{\pmb{x}}f(\pmb{x})

下图展示了一个GAN网络对抗训练的过程:

image.png
image.png

其中,蓝色虚线表示判别模型的分布DD,黑色虚线表示数据真实分布pdata(x)p_{\text{data}}(x),绿色实现表示生成分布pg(x)p_g(x)。两条水平线表示被采样的区域,zz是被均匀采样的,向上的箭头表示x=G(z)x=G(z)的映射。

  • (a): pgp_g近似pdatap_\text{data}DD开始只能部分辨别准确
  • (b):在算法内循环中,先最优化训练DDDD的最优解为:D(x)=pdata(x)pdata(x)+pg(x)D^*(x)=\frac {p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)}
  • (c):在训练完DD后,DD的梯度会引导训练G(z)G(z)更接近真实分布
  • (d):最终如果训练收敛,则pg(x)=pdata(x)p_g(x) = p_{\text{data}}(x)D(X)=12D(X)=\frac 12

GAN网络训练算法的伪代码如下:

for iterations do

for kk steps do

  • pz(z)p_z(z)中采样mm个minibatch噪声{z(1),,z(m)}\{z^{(1)}, \ldots ,z^{(m)}\}
  • pdata(x)p_{\text{data}}(x)中采样mm个minibatch真实分布{x(1),,x(m)}\{x^{(1)}, \ldots ,x^{(m)}\}
  • 通过随机梯度下降算法更新辨别器:

θd1mi=1m[logD(x(i))+log(1D(G(z(i))))] \nabla_{\theta_d} \frac 1m \sum^m_{i=1} [\log D(x^{(i)})+\log (1-D(G(z^{(i)})))]

end for

  • pz(z)p_z(z)中采样mm个minibatch噪声{z(1),,z(m)}\{z^{(1)}, \ldots ,z^{(m)}\}
  • 通过随机梯度下降算法更新生成器:

θg1mi=1mlog(1D(G(z(i)))) \nabla_{\theta_g} \frac 1m \sum^m_{i=1} \log(1-D(G(z^{(i)})))

end for

梯度下降算法可以使用任意基于梯度的学习算法,本文使用了Momentum。

其中,kk为超参数。kk的设置很重要,若kk较小则辨别器DD变化较小导致生成器GG也变化较小;若kk较大则辨别器DD辨别能力很强导致D(G(Z))D(G(Z))的梯度接近于0,也影响GG的更新。

3. 理论证明

已知我们的目标是使生成数据等于真实数据,辨别器的分辨概率为50%。此节通过理论证明本文设计的损失函数的合理性,为什么通过最优化该损失函数可以收敛至 pg(x)=pdata(x)p_g(x) = p_{\text{data}}(x)D(X)=12D(X)=\frac 12

命题1. 对于给定的GG,最优的辨别器DD是:

DG(x)=pdata(x)pdata(x)+pg(x) D^*_G(x)=\frac {p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)}

证明:辨别器DD训练的准则是,对于任意给定的GG,令V(G,D)V(G,D)最大。将损失函数展开,得:

V(G,D)=xpdata(x)log(D(x))dx+zpz(z)log(1D(G(z)))dz=xpdata(x)log(D(x))+pg(x)log(1D(x))dx \begin{aligned} V(G,D) &= \int_x p_{\text{data}}(x) \log (D(x))dx + \int_z p_z(z) \log(1-D(G(z)))dz \\ &= \int_x p_{\text{data}}(x) \log (D(x)) + p_g(x) \log(1-D(x))dx \end{aligned}

现要求V(G,D)V(G,D)关于DD的最大值,则固定GGDD的偏导,得:

D(x)(pdata(x)log(D(x))+pg(x)log(1D(x)))=pdata(x)D(x)pg(x)1D(x)=0,D(x)[0,1]D(x)=pdata(x)pdata(x)+pg(x) \begin{aligned} &\frac {\partial}{\partial D(x)} (p_{\text{data}}(x) \log (D(x)) + p_g(x) \log(1-D(x))) \\ &= \frac {p_{\text{data}}(x)}{D(x)} - \frac {p_g(x)}{1-D(x)} = 0, \quad D(x) \in [0,1] \\ &\Rightarrow D^*(x)=\frac {p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)} \end{aligned}

根据命题1的结果,损失函数可以重新表示为:

C(G)=maxDV(G,D)=Expdata[logDG(x)]+Ezpz[log(1DG(G(z)))]=Expdata[logDG(x)]+Expg[log(1DG(x))]=Expdata[logpdata(x)pdata(x)+pg(x)]+Expg[logpg(x)pdata(x)+pg(x)] \begin{aligned} C(G) &= \max_D V(G,D) \\ &= \mathbb{E}_{x\sim p_{\text{data}}}[\log D^*_G(x)]+\mathbb{E}_{z\sim p_z}[\log (1-D^*_G(G(z)))] \\ &= \mathbb{E}_{x\sim p_{\text{data}}}[\log D^*_G(x)]+\mathbb{E}_{x\sim p_g}[\log (1-D^*_G(x))] \\ &= \mathbb{E}_{x\sim p_{\text{data}}}[\log \frac {p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)}]+\mathbb{E}_{x\sim p_g}[\log \frac {p_g(x)} {p_{\text{data}}(x)+p_g(x)}] \end{aligned}

定理1. 当且仅当 pg(x)=pdata(x)p_g(x) = p_{\text{data}}(x)时,C(G)C(G)达到全局最小值,最小值为log4-\log 4

证明:对C(G)C(G)做如下变换:

C(G)=xpdata(x)logpdata(x)pdata(x)+pg(x)+pg(x)logpg(x)pdata(x)+pg(x)dx=xpdata(x)logpdata(x)2pdata(x)+pg(x)2+pg(x)logpg(x)2pdata(x)+pg(x)2dx=log4+xpdata(x)logpdata(x)pdata(x)+pg(x)2+pg(x)logpg(x)pdata(x)+pg(x)2dx=log4+KL(pdatapdata+pg2)+KL(pgpdata+pg2)=log4+2JS(pdatapg) \begin{aligned} C(G) &= \int_x p_{\text{data}}(x) \log \frac {p_{\text{data}}(x)} {p_{\text{data}}(x)+p_g(x)} + p_g(x) \log \frac {p_g(x)} {p_{\text{data}}(x)+p_g(x)} dx \\ &= \int_x p_{\text{data}}(x) \log \frac{\frac{p_{\text{data}}(x)}2}{\frac{p_{\text{data}}(x)+p_g(x)}2} + p_g(x) \log \frac{\frac{p_g(x)}2}{\frac{p_{\text{data}}(x)+p_g(x)}2} dx \\ &= -\log4 + \int_x p_{\text{data}}(x) \log \frac{p_{\text{data}}(x)}{\frac{p_{\text{data}}(x)+p_g(x)}2} + p_g(x) \log \frac{p_g(x)}{\frac{p_{\text{data}}(x)+p_g(x)}2} dx \\ &= -\log4 + KL(p_{\text{data}} \parallel \frac{p_{\text{data}}+p_g}2) + KL(p_g \parallel \frac{p_{\text{data}}+p_g}2) \\ &= -\log4 + 2 \cdot JS(p_{\text{data}} \parallel p_g) \end{aligned}

根据JS散度的非负性,JS(pdatapg)0JS(p_{\text{data}} \parallel p_g) \geq 0,且当且仅当pg=pdatap_g = p_{\text{data}}时,JS(pdatapg)=0JS(p_{\text{data}} \parallel p_g) = 0,因此C(G)C(G)有最小值log4-\log 4

如果不了解KL和JS散度,可以阅读我的这篇文章:信息量和熵open in new window

定理2.GGDD有足够容量的时候,且在训练算法中DD可以达到其最优解,如果对pgp_g的优化是按照如下公式,那么pgp_g最终可以收敛到 pdatap_{\text{data}}

Expdata[logDG(x)]+Expg[log(1DG(x))] \mathbb{E}_{x\sim p_{\text{data}}}[\log D^*_G(x)]+\mathbb{E}_{x\sim p_g}[\log (1-D^*_G(x))]

证明:把V(G,D)=U(pg,D)V(G,D) = U(p_g,D)看成是一个关于pgp_g的函数,且是凸函数,由于一个凸函数的积分上限函数还是凸函数,所以对凸函数做梯度下降时会得到一个最优解。

实际上,训练算法每次迭代并没有对DD优化到极致,只是迭代了kk步,但实践中训练算法的表现效果很好。

上次编辑于:
贡献者: lisenjie757
已到达文章底部,欢迎留言、表情互动~
  • 赞一个
    0
    赞一个
  • 支持下
    0
    支持下
  • 有点酷
    0
    有点酷
  • 啥玩意
    0
    啥玩意
  • 看不懂
    0
    看不懂
评论
  • 按正序
  • 按倒序
  • 按热度
Powered by Waline v2.14.9