GAN 的数学理论及推导(李宏毅GAN课程学习笔记)
在前一篇文章,我们已经从概念及逻辑上,对GAN进行了相对直观易懂的介绍。在这一篇文章中,我们来学习GAN背后的数学理论支撑。
还是以生成二次元人物头像为例,假设每个图都是一个64*64的高维向量,能生成二次元人物头像的高维向量必然只是这个高维空间中的一个固定分布,下图中蓝色的区域表示这个分布,其中的每一个点都是一个很可能可以生成头像的高维向量(为了直观我们用二维来图示)。
在分布区域(蓝色区域)内的向量有比较大的概率生成头像,而分布以外的概率较小。显然,我们的目的就是找到这个未知的分布,我们手里只有样本,并不知道分布的其他信息。
回想一下,对于要求得一个分布这种问题我们都是怎么解决的呢?一般是使用极大似然估计,关于极大似然估计我们在极大似然估计、MLE与MAP中做过比较详细的解读了,极大似然估计是经验风险的最小化,也可以通过KL散度理解为与真实分布最相近的分布,复习一下其步骤:
- 假设概率服从某种确定的概率分布(或者分布已知);
- 写出似然函数:;
- 对似然函数取对数,并整理为;
- 求导数;
- 解似然方程,得到极大似然的参数作为分布的参数,得到分布;
显然,如果使用极大似然估计的话,我们需要知道分布的类型,这是很难的推断的,因为我们只有高维的样本而已,很难说是高斯分布或者什么分布;再者,如果用一个神经网络来表示,那么就没法计算似然,因此在这个问题上,MLE应该是行不通的。来看看GAN中的Generator是怎么做的呢。
GAN中的Generator是一个神经网络,这个神经网络定义了的概率分布,就是说从一个分布里随便给Generator的网络一个输入,网络就会计算一个输出,这个输出就是符合的分布的。
那作为输入的应该服从什么分布呢?可以实验看看哪种效果好,其实可能影响比较小,因为是一个NN,NN是可以拟合非常复杂的函数的,就算输入非常简单,在处理后应该也是可以比较好的拟合出目标分布的,至少是从能力上来说应该没问题。
假设我们输入来自一个正态分布,如果把这些输出集合起来就是一个非常复杂的分布,记为:,我们的目标就是让输出的结果尽可能的与真实分布相似:
以数学语言来描述我们的目标:
最终我们的问题就是,怎样计算和真实分布之间的divergence,如果知道怎么算的话,我们就能通过最小化这个divergence来得到最接近真实分布的了。
在上一篇中我们已经知道,GAN中的生成器Generator用来生成图片,接收一个随机的噪声z,通过z生成图片G(z),也就是我们上面求分布所要实现的功能;判别器Discriminator用来判别一张图片是不是“真实的”,即计算生成的结果与真实数据的差别,也就是说计算和真实分布之间的divergence就是咱们判别器的工作了。
要计算和之间的divergence,梳理下我们现在都掌握了他们的哪些信息:
- 对,我们只有样本,其他全部未知;
- 对,如果我们假设生成器已知,那么我们就掌握了的一切。
要计算两者的区别,我们只能从对两者都有认知的点开始 —— 样本。我们可以从真实样本中抽样作为的样本,也可以通过已知的来生成样本作为的样本。
现在问题转化成通过的样本和的样本来计算和之间的divergence的问题了,真实分布的样本我们需要判别器给高分,生成的样本需要给低分,是不是就是二分类?
的样本是正样本,的样本是负样本,通过得分来划分成两类,像不像分类里的逻辑回归?别急,看了目标函数就觉得更像了,简直一毛一样:
前面一项是表示数据来自 ,值越大越好,后面一项是表示数据来自 ,值越小越好。
综上所述,计算和之间的divergence的问题,转变成了通过对两个分布的样本进行训练,得到一个能区分两个分布的样本的分类器的问题了,训练的目标如下:
2.1 Discriminator与JS散度
为什么说上面的目标函数与JS散度相关呢?
对于固定的,有:
要最大化,对于一个的取值来说,就相当于要最大化。对于固定的,令,,,可得:
求的极大值,只需要求导求极值即可:
求解可得:
代入目标函数:
考虑到JS散度的公式:
我们对目标函数的分子分母同乘1/2,得到:
Discriminator的目标函数最大化后得到的值,就是,可以说Discriminator最终衡量了两个分布间的JS散度。
3.1 理解训练过程
综上所述,Generator希望通过调整来最小化和真实分布之间的divergence,Discriminator在已有的基础上,训练样本得到分类器,来衡量与之间的JS Divergence,真是天造地设的一对:
这个式子看起来,一会一会,感觉还挺复杂的,我们用图像还形式化的理解一下整个过程:
假设现在我们有3个可能的,目标是找到使、最相近的,横坐标是,不同的位置代表不同的,我们先不管求,找到每个对应的的,如图中的红点,然后固定,找对应的:
最终找到了,它使得 和的Divergence最小。根据这个过程,我们可以得出Generator和Discriminator训练的过程:
3.1 训练过程算法
整个训练过程主要涉及到各种、,我们首先想到的是梯度下降、梯度上升。
使用梯度下降:
这里有个问题:,是不是可以求偏导的?其实是没问题的,想想我们在CNN中Max Pooling操作,也是求,不也是可以做梯度下降的吗。如下图所示,只要在处判断一下当前那个函数值最大,以这个函数对求偏导就可以了:
下面分别看看Generator和Discriminator是怎么训练的:
1)Generator
- Given ;
- Find maximizing , is the JS divergence between and ;
这里我们要注意一个问题,一开始的,我们是在给的的条件下训练出来的判别器,可以衡量与的Divergence,但是在训练后,我们更新了生成器,变成了,如果变化比较大的话,这时候分布就发生变化了,的函数曲线也会发生变化,也不再能衡量与的Divergence了:
所以我们跟相比应该更新的小一点,并且每轮训练只更新一次,这个时候从到的曲线变化不会很大,我们就假设=。还可以继续用来衡量变化后 。这也是GAN在训练中的一点技巧:
- Generator不要一次update太多,也不需要太多的iteration;
- Discriminator可以训练到收敛。因为要找到最大值才能衡量出JSD。
2)Discriminator
判别器的训练比较好理解,基本上就是一个二分类的训练过程:
根据上图可以发现,理论上是要取期望值,但是实际上无法操作,使用样本的均值进行估计。
现在,结合我们这次讲的理论内容,再来看看GAN的训练:
- Initialize for D and for G
- In each training iteration:
- Sample examples from data distribution
- Sample noise samples from the prior
- Obtaining generated data,
- Update discriminator parameters to maximize
- Sample another noise samples from the prior
- Update generator parameters to minimize
直观的过程:
黑色线表示真实数据的分布,绿色线表示生成数据的分布,蓝色线表示生成数据在判别器中的分布效果(a)判别网络D未经过训练,分类能力有限,有波动,但是基本可以区分真实数据x和生成数据G(z);
(b)判别网络D训练后,可以很好的区分出生成数据G(z);
(c)生成数据向真实数据靠近,生成器能力提升了;
(d)生成器生成数据G(z)已经和真实数据基本相同了,判别器打分收敛到一个稳定值,完全无法区分真实数据和生成数据了。
按上面的步骤,Generator训练时的目标函数:
而在Goodfellow的原论文中,,这两个函数的图像:
这两个函数的趋势是一样的,但是在同一位置的斜率是不一样的,在一开始训练的时候,是很小的,因为判别器很容易就分辨出来这是生成的fake的,这时候的值在很左边,微分是很小的,训练起来很慢,而不同,微分更大,更容易训练;而且使用实现起来很方便,把生成对象和真实对象的标签换一下就可以训练了。后来对使用的称为Minimax GAN (MMGAN),对使用的称为Non-saturating GAN (NSGAN)。
主要参考
对抗生成网络(GAN)教程(2018) 李宏毅
网友评论