VI

作者: 圣元素 | 来源:发表于2019-01-11 00:01 被阅读0次

    varational inference

    背景

    在贝叶斯框架下,推断一般指的是后验分布,即p(\theta|x)=\frac{p(x|\theta)p(\theta)}{p(x)}= \frac{p(x|\theta)p(\theta)}{\int_{\theta}p(x,\theta)d\theta },我们并不知道后验分布的形式,甚至知道后验分布的形式,但仍然难以计算出该形式下的参数。因此,我们希望找到一种近似分布用于描述后验分布。

    变分推断

    log(p(x))=log(p(z,x))-log(p(z|x))

    =log(p(z,x))-log(q(z))-(log(p(z|x))-log(q(z)))

    =log\frac{p(z,x)}{q(z)} -log\frac{p(z|x)}{q(z)}(1)

    对(1)式左右两边对变量z\sim q(z)求期望得到

    \int_{z}q(z)log(p(x))dz=\int_{z}q(z)(log\frac{p(z,x)}{q(z)} -log\frac{p(z|x)}{q(z)})dz

    log(p(x))=E_{z\sim q}[log\frac{p(z,x)}{q(z)} -log\frac{p(z|x)}{q(z)}]

    =E_{z\sim q}[log\frac{p(z,x)}{q(z)}] +E_{z\sim q}[log\frac{q(z)}{p(z|x)}]

    =E_{z\sim q}[log(p(z,x))]- E_{z\sim q}[log(q(z))]+D_{KL}(q(z)\|p(z|x))(2)

    由于KL-Divergence恒大于0,有

    log(p(x))\geq E_{z\sim q}[log(p(z,x))]- E_{z\sim q}[log(q(z))](3)

    L(q)=E_{z\sim q}[log(p(z,x))]- E_{z\sim q}[log(q(z))]=E_{z\sim q}[log(p(z,x))]+H(q)

    如果能在q\in Q找到一个q^*=argmax_{q}L(q),此时D_{KL}(q^*(z)\|p(z|x))取到极小值

    等价于在Q中找到与p(z|x)最接近的分布。


    变分推断与神经网络

    背景

    如何使用一个神经网络来表示一个密度函数?

    p(z|x;\theta)=N(\mu_{NN}(x),\sigma _{NN}^{2}(x))

    变分推断(使用神经网络表示)的实质

    根据背景介绍中的(3)式,我们可知E_{z\sim q}[log\frac{p(x|z)p(z)}{q(z|x)}]log(p(x))的LOWER BOUND,为了使log(p(x))最大,我们可以通过调整LOWER BOUND来使得log(p(x))最大化,但是由于log(p(x))并不依赖于q(z|x),只通过调整q(z|x)可以使得LOWER BOUND增大,但并不保证log(p(x))最大化。因此,为了使log(p(x))最大化,应该对p(x|z)和q(z|x)调整使得L(q(z|x),p(x|z))最大化。

    L(q(z|x), p(x|z))=E_{z\sim q}[log\frac{p(x|z)p(z)}{q(z|x)}]=E_{z\sim q}[log(p(x|z))+log\frac{p(z)}{q(z|x)} ]=E_{z\sim q}[log(p(x|z))]-D_{KL}(q(z|x)\|p(z))

    分别使用两个神经网络来表示q(z|x),p(x|z),参数分别为φ、θ,可以形式化表达为优化问题:

    \theta^*,\phi^*=argmax_{\theta,\phi}L(q_{\phi}(z|x),p_{\theta}(x|z))=argmax_{\theta,\phi}(E_{z\sim q}[log(p_{\theta}(x|z))]-D_{KL}(q_{\phi}(z|x)\|p(z)))

    为了找到\theta^*,\phi^*,使用梯度上升法,通过调整\theta,\phi使得L(q_{\phi}(z|x),p_{\theta}(x|z))增大。

    其中\nabla_{\phi}E_{z\sim q}[log(p_{\theta}(x|z))]=\nabla_{\phi}\int_{z}q_{\phi}(z|x)log(p_{\theta}(x|z)dz=\int_{z}\nabla_{\phi}q_{\phi}(z|x)log(p_{\theta}(x|z)dz

    =\int_{z}q_{\phi}(z|x)\nabla_{\phi}log(q_{\phi}(z|x))log(p_{\theta}(x|z)dz

    =E_{z\sim q}[\nabla_{\phi}log(q_{\phi}(z|x))log(p_{\theta}(x|z)]

    另外,由于z\sim N(\mu_{\phi}(x),\sigma_{\phi}^2(x)),令z=\mu_{\phi}(x)+\epsilon\sigma_{\phi}(x),其中\epsilon \sim N(0,1)

    则有E_{z\sim q}[log(p_{\theta}(x|z))]=E_{\epsilon \sim N(0,1)}[log(p_{\theta}(x|\mu_{\phi}(x)+\epsilon\sigma_{\phi}(x)))]\approx log(p_{\theta}(x|\mu_{\phi}(x)+\tilde{\epsilon} \sigma_{\phi}(x)))

    从下图可以看出,\nabla_{\theta}L(q_{\phi}(z|x),p_{\theta}(x|z))只作用到网络的后半部分,黄色部分标记。

    对于\nabla_{\phi}L(q_{\phi}(z|x),p_{\theta}(x|z))作用到整体网络,黄色部分标记,同时需要求得\nabla_{\phi}(-D_{KL}(q_{\phi}(z|x)\|p(z))))

    相关文章

      网友评论

          本文标题:VI

          本文链接:https://www.haomeiwen.com/subject/rnfkhqtx.html