美文网首页
Auto-Encoding Variational Bayes

Auto-Encoding Variational Bayes

作者: 馒头and花卷 | 来源:发表于2020-04-02 19:05 被阅读0次

    Kingma D P, Welling M. Auto-Encoding Variational Bayes[J]. arXiv: Machine Learning, 2013.

    主要内容

    自编码, 通过引入Encoder和Decoder来估计联合分布p(x,z), 其中z表示隐变量(我们也可以让z为样本标签, 使得Encoder成为一个判别器).

    在Decoder中我们建立联合分布p_{\theta}(x,z)以估计p(x,z), 在Encoder中建立一个后验分布q_{\phi}(z|x)去估计p_{\theta}(z|x), 然后极大似然:
    \begin{array}{ll} \log p_{\theta}(x) &= \log \frac{p_{\theta}(x,z)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ & = \log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)} \\ \end{array},
    上式俩边关于z在分布q_{\phi}(z)下求期望可得:
    \begin{array}{ll} \log p_{\theta}(x) & = \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} + \log \frac{q_{\phi}(z|x)}{p_{\theta}(z|x)}) \\ &= \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} )+D_{KL}(q_{\phi}(z|x)\| p_{\theta}(z |x ))\\ & \ge \mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} ) \end{array}.

    既然KL散度非负, 我们极大似然\log p_{\theta}(x)可以退而求其次, 最大化\mathbb{E}_{q_{\phi}(z|x)}(\log \frac{p_{\theta}(x,z)}{q_{\phi}(z|x)} )(ELBO, 记为\mathcal{L}).

    又(p_{\theta}(z)为认为给定的先验分布)
    \begin{array}{ll} \mathcal{L}(\theta, \phi; x) &= -D_{KL}(q_{\phi}(z|x)\|p_{\theta}(z))+\mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)], \end{array}
    我们接下来通过对Encoder和Decoder的一些构造进一步扩展上面俩项.

    Encoder (损失part1)

    Encoder 将x\rightarrow z, 就相当于在q_{\phi}(z|x)中进行采样, 但是如果是直接采样的话, 就没法利用梯度回传进行训练了, 这里需要一个重参化技巧.

    我们假设q_{\phi}(z|x)为高斯密度函数, 即\mathcal{N}(\mu, \sigma^2 I).
    注: 文中还提到了其他的一些可行假设.

    我们构建一个神经网络f, 其输入为样本x, 输出为(\mu, \log \sigma)(输出\log \sigma是为了保证\sigma为正), 则
    z= \mu + \epsilon \odot \sigma, \epsilon \sim \mathcal{N}(0, I),
    其中\odot表示按元素相乘.
    注: 我们可以该输出为(\mu, L)(L为三角矩阵, 且对角线元素非负), 而假设q_{\phi}(z|x)的分量不独立, 其协方差函数为L^TL, 则(z=\mu + L \epsilon).

    p_{\theta}(z)=\mathcal{N}(0, I), 我们可以显示表达出:

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    Decoder (损失part2)

    现在我们需要处理的是第二项, 文中这地方因为直接设计p_{\theta}(x,z)不容易, 在我看来存粹是做不到的, 但是又用普通的分布代替不符合常理, 所以首先设计一个网络g_{\theta}(z), 其输出为\hat{x}, 然后假设p(x|\hat{x})的分布, 第二项就改为近似\mathbb{E}_{q_{\phi}(z|x)}p_{\theta}(x|\hat{x}).

    这么做的好处是显而易见的, 因为Decoder部分, 我们可以通过给定一个z然后获得一个\hat{x}, 这是很有用的东西, 但是我认为这种不是很合理, 因为除非g是可逆的, 那么p_{\theta}(x|z)= p _{\theta}(x|\hat{x}) (当然, 别无选择).

    伯努利分布

    此时\hat{x}=g(z)x=1的概率, 则此时第二项的损失为
    \log p(\mathbf{x}| \hat{\mathbf{x}})= \sum_{i=1} x_i \log \hat{x}_i + (1-x_i) \log (1- \hat{x}_i),
    为(二分类)交叉熵损失.

    高斯分布

    一种简单粗暴的, p(x|\hat{x})=\mathcal{N}(\hat{x},\sigma^2 I), 此时损失为类平方损失, 文中也有别的变换.

    代码

    import torch
    import torch.nn as nn
    
    
    class Loss(nn.Module):
        def __init__(self, part2):
            super(Loss, self).__init__()
            self.part2 = part2
    
        def forward(self, mu, sigma, real, fake, lam=1):
            part1 = (1 + torch.log(sigma ** 2)
                     - mu ** 2 - sigma ** 2).sum() / 2
            part2 = self.part2(fake, real)
            return part1 + lam * part2
    

    相关文章

      网友评论

          本文标题:Auto-Encoding Variational Bayes

          本文链接:https://www.haomeiwen.com/subject/ofrfphtx.html