Gan网络

作者: 愤怒的熊猫V | 来源:发表于2019-08-28 21:06 被阅读0次

本文所有图片来自北京睿邦科技有限公司Gan网络教学课程教师王昊然。

判别模型与生成模型

在说Gan网络之前有必要先讲一讲什么是判别模型和生成模型,这里也是常问的一个问题,当时只记得判别模型是学习一个条件概率分布P (Y\vert X),而生成模型是学习一个联合概率分布P(X,Y),然后用公式P(Y\vert X)=\frac{P(X,Y)}{P(X)} 来完成一些诸如分类之类的任务,其实这个东西说起很抽象,很难以理解,但是接触到了Gan网络之后,才慢慢对一些概念有了进一步的认识。如下图所示,图来自网络。

判别模型与生成模型

可以清楚地看到,判别模型是根据一些具有代表性的特征数据来找到一条所谓的“分界线”来对我们的结果进行预测,而生成模型则是学习到了我们的输入数据样本与最终希望得到的输出结果之间的分布关系,由于早年间的数据量或者算力受限,使用的多是判别式模型,但随着硬件不断升级以及大数据时代的到来,人们已经不满足于使用机器进行预测这一项工作,还要使得机器完成模拟这一项工作。如果对于预测和模拟这一概念还不太清楚,可以看下面这张图,依然来自网络。

生成与判别举例

生成模型还能解决带有隐变量的问题,但判别模型就无法解决,因为它缺少一些关键特征。

以语言为例,假设我们的任务是要区分一句话是中文还是英文还是法文,生成式模型先做的是找到中文的一些构成规则,再找到英文的构成规则,再找到法文的构成规则,然后对于给定的一句话,来依次对应这些规则去进行计算当我这句话来的时候,在我的这些分布当中属于哪一类语言的概率是最高的;而判别式模型做的就是找到这三者语言之间的一些明显的差别,这种差别我们叫做特征,例如一句话如果有4个声调,那么我们可以根据声调数量这一特征来判别这一句话属于哪种语言。

再以图像为例,假设我们要区分这两张人脸是谁,范冰冰还是佟丽娅,那么对生成式模型来说,我就可以学习到属于范冰冰的人脸构成分布和佟丽娅的人脸构成分布,然后再对输入图像计算条件概率;如果是判别式模型,我们就直接对一些关键特征例如眼睛的长度,眉毛的宽度,脸型等等来判别输入的人脸属于哪个人。

Gan网络原理

之所以要先将生成式模型与判别式模型的一个重要原因就是Gan网络结合了生成式模型,我们叫它生成器G,和判别式模型,我们叫它N。

以Gan网络常用的场景图像问题来说,我们的生成器就是使用一堆随机的高斯噪声点来不断地去进行调整直到和我们期望的输出达到相同的分布,也就是说能生成我们想要的照片;而如何达到我们想要生成的图片则完全由判别器来说了算,它可以对生成器生成的结果来做一个判定。

如果得分为0,那么就代表生成的图片看上去太假了,还需要调整生成器的能力;如果得分为1,那么由于图片本身就是一张假图,这只能说明我们的判别器判别能力太弱了;只有当判别器得分为0.5的时候才能说明我们生成的图已经达到了真假难辨的地步,这其实就是达到了我们博弈论中所说的纳什均衡点。

Gan网络示意

原论文中证明了生成器生成的分布一定会收敛,也就是与真实分布完全重合,图中绿色分布代表生成器分布

生成器分布图示意

对于任务的实际解决,我们需要先训练判别器,然后训练生成器,然后交替不断训练达到纳什均衡,然后最后我们往往需要的是生成器,记住这两者训练的先后顺序。

Gan目标函数

D是判别器,G是生成器,我们先把G看做一个常量来看这个函数,对于这个公式,D(x)增大前半部分增大后半部分减小,所以达到一个0.5的均衡才能最大,这也符合我们的先验知识;再看生成器,要想使minG最小,前半部分与G无关,那么只有当G不断的增大才能使minG最小;因此这个函数表达的是想要使G不断增大,而D不断趋向于均衡。

Gan网络家族

CGAN:Gan网络只能生成一种类别,例如判别器喂进去人民币,美元,那么它的随机性就不能很好帮助我们完成任务,为了解决这个问题,CGAN就诞生了。

Seq-Gan:最早起的Gan网络不太适合生成离散数据,例如我们上面讲到的生成的分布都是连续的,为了解决这个问题,生成一些离散的分布,例如文本,Seq-Gan能生成离散序列的生成式模型。

Pix2Pix,CycleGan,StarGan:图像翻译


Gan网络训练过程

在例中的手写数字生成demo中,生成器和判别器都是全连接网络。

以手写体数字的生成为例,定义一个discriminator判别器,再定义一个generator生成器,高斯噪声Z为一个100维的噪声。

假设我们的真实图片是MNIST数据集,我们无需给它打上标签,假设我们的Batch是32,那么我们就生成一个32维的全1的Valid向量,再生成一个32维的全0的Fake向量。

判别器训练,记住我们上面所说的,我们先要训练判别器Discriminator,我们由32*100的高斯噪声经由generator生成32张Fake图;然后我们将32张真实图片输入,以Valid全1向量作为标签,告诉判别器,这些图片全部都是真的;然后我们将32张由generator生成的Fake图输入,以Fake全0向量作为标签,告诉判别器,这些图片全部都是假的。我们把所有的Loss_real和Loss_fake加起来然后取平均值。

生成器训练,我们高斯噪声Z再重新输入到生成器,然后再把生成的Fake图输入到判别器中,然后再用全1来作为标签进行训练和反向传播。

至此,一个batch的训练就完成了,这里的损失函数用的是二分类交叉熵损失函数Binary_crossentropy,也就是softmax的一个基础款。

最开始训练的时候Loss_real和LOSS_fake后面接的准确率会是100%,但随着训练进行,最终达到50%,就代表整个网络收敛了,那么我们的生成器就能用啦~

CGAN

之前我们讲的基础的Gan网络输入都没有标签,这也就造成了我们会随机产生0-9中的图片,而在CAGN中,我们加入了标签。

CGAN

相关文章

网友评论

      本文标题:Gan网络

      本文链接:https://www.haomeiwen.com/subject/ubqyectx.html