美文网首页机器学习
WGAN的来龙去脉

WGAN的来龙去脉

作者: 6e845d5ac37b | 来源:发表于2018-12-23 11:53 被阅读49次

    渐渐体会到GAN训练难确实是一个让人头疼的问题,一个多月前我曾粗略地了解了一下WGAN,知道这是一个着眼于提高GAN训练稳定性的成果,但后来发现,我对其原理理解得还不是很充足。于是我把WGAN的一作作者Martin Arjovsky在2017年先后参与的三篇相关论文找来看,对WGAN的来龙去脉有了一个更清晰的理解。

    Towards Principled Methods for Training GenerativeAdversarial Networks

    这篇论文是WGAN发表前的铺垫,它最大的贡献是从理论上解释了GAN训练不稳定的原因。

    人们在应用GAN时经常发现一个现象:不能把Discriminator训练得太好,否则Generator的性能很难提升上去。该文以此为出发点,分析了GAN目标函数的理论缺陷。

    原始目标函数的缺陷

    在最早提出GAN的论文中,Goodfellow把GAN的目标函数设置为:

    他也证明了,固定Generator时,最优的Discriminator是

    然后在面对最优Discriminator时,Generator的优化目标就变成了

    可以把上述公式简洁地写成JS散度的形式:

    也就是说,如果把Discriminator训练到极致,那么整个GAN的训练目标就成了最小化真实数据分布与合成数据分布之间的JS散度。

    该文花了大量的篇幅进行数学推导,证明在一般的情况下,上述有关JS散度的目标函数会带来梯度消失的问题。也就是说,如果Discriminator训练得太好,Generator就无法得到足够的梯度继续优化,而如果Discriminator训练得太弱,指示作用不显著,同样不能让Generator进行有效的学习。这样一来,Discriminator的训练火候就非常难把控,这就是GAN训练难的根源。

    该文还用实验对这一结论进行了验证:让Generator固定,然后从头开始训练Discriminator,绘制出Generator目标函数梯度和训练迭代次数的关系如下。

    可以看到,经过25 epochs的训练以后,Generator得到的梯度已经非常小了,出现了明显的梯度消失问题。

    -logD目标函数的缺陷

    Goodfellow提到过可以把Generator的目标函数改为-logD的形式,在实际应用中,人们也发现这个形式更好用,该文把这个技巧称为the - log D alternative。此时Generator的梯度是:

    该文证明在最优的Discriminator下,这个梯度可以转化为KL散度和JS散度的组合:

    该文对这一结论有两点评论:

    1. 该公式的第二项意味着最大化真实数据分布和生成数据分布之间的JS散度,也就是让两者差异化更大,这显然违背了最初的优化目标,算是一种缺陷。

    2. 同时,第一项的KL散度会被最小化,这会带来严重的mode dropping问题。

    关于上述第二点,下面补充一点说明。

    mode dropping在更多的情况下被称作mode collapse,指的是生成样本只集中于部分的mode从而缺乏多样性的情况。例如,MNIST数据分布一共有10个mode(0到9共10个数字),如果Generator生成的样本几乎只有其中某个数字,那么就是出现了很严重的mode collapse现象。

    接下来解释为什么上述的KL散度

    会导致mode collapse。借用网上某博客的图,真实的数据分布记为P,生成的数据分布记为Q,图的左边表示两个分布的轮廓,右边表示两种KL散度的分布(由于KL散度的不对称性,KL(P||Q)与KL(Q||P)是不同的)。

    右图蓝色的曲线代表KL(Q||P),相当于上述的

    可以看到,KL(Q||P)会更多地惩罚q(x) > 0而p(x) -> 0的情况(如x = 2附近),也就是惩罚“生成样本质量不佳”的错误;另一方面,当p(x) > 0而q(x) -> 0时(如x = -3附近),KL(Q||P)给出的惩罚几乎是0,表示对“Q未能广泛覆盖P涉及的区域”不在乎。如此一来,为了“安全”起见,最终的Q将谨慎地覆盖P的一小部分区域,即Generator会生成大量高质量却缺乏多样性的样本,这就是mode collapse问题。

    另外,通过类似的分析可以知道,KL(P||Q)则会导致Generator生成多样性强却低质量的样本。

    除了上述的缺陷,该文还通过数学证明这种-logD的目标函数还存在梯度方差较大的缺陷,导致训练的不稳定。然后同样通过实验直观地验证了这个现象,如下图,在训练的早期(训练了1 epoch和训练了10 epochs),梯度的方差很大,因此对应的曲线看起来比较粗,直到训练了25 epochs以后GAN收敛了才出现方差较小的梯度。

    小结

    该文通过严谨的理论推导分析了当前GAN训练难的根源:原始的目标函数容易导致梯度消失;改进后的-logD trick虽然解决了梯度消失的问题,然而又带来了mode collapse、梯度不稳定等问题,同样存在理论缺陷。既然深入剖析了问题的根源,该文自然在最后也提出了一个解决方案,然而该方案毕竟不如后来的WGAN那样精巧,因此我把这部分略过了。

    Wasserstein GAN

    EM距离

    这是最早提出WGAN的论文,沿着上篇论文的思路,该文认为需要对“生成分布与真实分布之间的距离”探索一种更合适的度量方法。作者们把眼光转向了Earth-Mover距离,简称EM距离,又称Wasserstein距离。

    EM距离的定义为:

    解释如下:\Pi (P_{r}, P_{g})P_{r}P_{g}组合起来的所有可能的联合分布的集合,对于每一个可能的联合分布\gamma而言,可以从中采样(x, y) \sim \gamma得到一个真实样本x和一个生成样本y,并算出这对样本的距离||x-y||,所以可以计算该联合分布下样本对距离的期望值\mathbb{E}_{(x, y) \sim \gamma}[||x-y||]。在所有可能的联合分布中能够对这个期望值取到的下界,就定义为EM距离。

    Earth-Mover的本意是推土机的意思,这个命名很贴切,因为从直观上理解,EM距离就是在衡量把Pr这堆“沙土”“推”到Pg这个“位置”所要花费的最小代价,其中的γ就是一种“推土”方案。

    该文接下来又通过数学证明,相比JS、KL等距离,EM距离的变化更加敏感,能提供更有意义的梯度,理论上显得更加优越。

    WGAN

    作者们自然想到把EM距离用到GAN中。直接求解EM距离是很难做到的,不过可以用一个叫Kantorovich-Rubinstein duality的理论把问题转化为:

    这个公式的意思是对所有满足1-Lipschitz限制的函数f取到\mathbb{E}_{x \sim \mathbb{P}_{r} }[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_{\theta} }[f(x)]的上界。简单地说,Lipschitz限制规定了一个连续函数的最大局部变动幅度,如K-Lipschitz就是:|f(x_{1}) - f(x_{2}) | \le K|x_{1} - x_{2}|

    然后可以用神经网络的方法来解决上述优化问题:

    这个神经网络和GAN中的Discriminator非常相似,只存在一些细微的差异,作者把它命名为Critic以便与Discriminator作区分。两者的不同之处在于:

    1. Critic最后一层抛弃了sigmoid,因为它输出的是一般意义上的分数,而不像Discriminator输出的是概率。

    2. Critic的目标函数没有log项,这是从上面的推导得到的。

    3. Critic在每次更新后都要把参数截断在某个范围,即weight clipping,这是为了保证上面讲到的Lipschitz限制。

    4. Critic训练得越好,对Generator的提升更有利,因此可以放心地多训练Critic。

    这样的简单修改就是WGAN的核心了,虽然数学证明很复杂,最后的变动却十分简洁。总结出来的WGAN算法为:

    GAN与WGAN的对比如下图:

    GAN WGAN

    WGAN的优越之处

    最后,该文用一系列的实验说明了WGAN的几大优越之处:

    1. 不再需要纠结如何平衡Generator和Discriminator的训练程度,大大提高了GAN训练的稳定性:Critic(Discriminator)训练得越好,对提升Generator就越有利。

    2. 即使网络结构设计得比较简陋,WGAN也能展现出良好的性能,包括避免了mode collapse的现象,体现了出色的鲁棒性。

    3. Critic的loss很准确地反映了Generator生成样本的质量,因此可以作为展现GAN训练进度的定性指标。

    Improved Training ofWasserstein GANs

    紧接着上面的工作,这篇论文对刚提出的WGAN做了一点小改进。

    作者们发现WGAN有时候也会伴随样本质量低、难以收敛等问题。WGAN为了保证Lipschitz限制,采用了weight clipping的方法,然而这样的方式可能过于简单粗暴了,因此他们认为这是上述问题的罪魁祸首。

    具体而言,他们通过简单的实验,发现weight clipping会导致两大问题:模型建模能力弱化,以及梯度爆炸或消失。

    他们提出的替代方案是给Critic loss加入gradient penalty (GP),这样,新的网络模型就叫WGAN-GP

    GP项的设计逻辑是:当且仅当一个可微函数的梯度范数(gradient norm)在任意处都不超过1时,该函数满足1-Lipschitz条件。至于为什么限制Critic的梯度范数趋向1(two-sided penalty)而不是小于1(one-sided penalty),作者给出的解释是,从理论上最优Critic的梯度范数应当处处接近1,对Lipschitz条件的影响不大,同时从实验中发现two-sided penalty效果比one-sided penalty略好。

    另一个值得注意的地方是,用于计算GP的样本\hat {x}是生成样本和真实样本的线性插值,直接看算法流程更容易理解:

    最后,该论文也通过实验说明,WGAN-GP在训练的速度和生成样本的质量上,都略胜WGAN一筹。

    相关文章

      网友评论

        本文标题:WGAN的来龙去脉

        本文链接:https://www.haomeiwen.com/subject/sxogkqtx.html