Kingma D P, Welling M. Auto-Encoding Variational Bayes[J]. arXiv: Machine Learning, 2013.
主要内容
自编码, 通过引入Encoder和Decoder来估计联合分布, 其中表示隐变量(我们也可以让为样本标签, 使得Encoder成为一个判别器).
在Decoder中我们建立联合分布以估计, 在Encoder中建立一个后验分布去估计, 然后极大似然:
上式俩边关于在分布下求期望可得:
既然KL散度非负, 我们极大似然可以退而求其次, 最大化(ELBO, 记为).
又(为认为给定的先验分布)
我们接下来通过对Encoder和Decoder的一些构造进一步扩展上面俩项.
Encoder (损失part1)
Encoder 将, 就相当于在中进行采样, 但是如果是直接采样的话, 就没法利用梯度回传进行训练了, 这里需要一个重参化技巧.
我们假设为高斯密度函数, 即.
注: 文中还提到了其他的一些可行假设.
我们构建一个神经网络, 其输入为样本, 输出为(输出是为了保证为正), 则
其中表示按元素相乘.
注: 我们可以该输出为(为三角矩阵, 且对角线元素非负), 而假设的分量不独立, 其协方差函数为, 则).
当, 我们可以显示表达出:
在这里插入图片描述
在这里插入图片描述
Decoder (损失part2)
现在我们需要处理的是第二项, 文中这地方因为直接设计不容易, 在我看来存粹是做不到的, 但是又用普通的分布代替不符合常理, 所以首先设计一个网络, 其输出为, 然后假设的分布, 第二项就改为近似.
这么做的好处是显而易见的, 因为Decoder部分, 我们可以通过给定一个然后获得一个, 这是很有用的东西, 但是我认为这种不是很合理, 因为除非是可逆的, 那么 (当然, 别无选择).
伯努利分布
此时是的概率, 则此时第二项的损失为
为(二分类)交叉熵损失.
高斯分布
一种简单粗暴的, , 此时损失为类平方损失, 文中也有别的变换.
代码
import torch
import torch.nn as nn
class Loss(nn.Module):
def __init__(self, part2):
super(Loss, self).__init__()
self.part2 = part2
def forward(self, mu, sigma, real, fake, lam=1):
part1 = (1 + torch.log(sigma ** 2)
- mu ** 2 - sigma ** 2).sum() / 2
part2 = self.part2(fake, real)
return part1 + lam * part2
网友评论