1. GAN
GAN有两个模型,分别是生成模型和判别模型。这两个模型进行博弈,先保持生成模型不变,优化判别模型(判别是真实数据还是生成的数据),使得它能最大化区别真实和生成数据;保持判别模型不变,优化生成模型,使它能接近真实数据,让判别模型判别不出来。
1) 定义x是真实数据
2) 定义一个先验变量:(例如:在N(0,1)产生),生成模型G是一个以
为参数的多层感知器。把z通过G映射到与x相同维度的空间。
3)定义一个判别模型,D是一个以
为参数的多层感知器。它输出的是[0,1]的一个数。
损失函数:

在训练一开始,因为很容易判别是不是真实数据,是饱和状态,所以我们
。
具体的算法:

可以通过代码更好得理解GAN(链接中分别有tensorflow和pytorch版本): vanilla gan
GAN的问题:
1. 训练困难,训练过程要小心平衡生成器和判别器
2. 生成器和判别器的loss无法指示训练进程
3. 生成样本缺乏多样性,也就是mode collapse
2. WGAN
WGAN的出现解决了上述GAN的几个问题,(有点小激动),以下是wgan的算法,而改进后相比原始GAN的算法实现流程却只改了四点:
1)判别器最后一层去掉sigmoid
2)生成器和判别器的loss不取log
3)每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4)不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

附上知乎详细链接:wgan
网友评论