Variational Inference | Machine Learning Derivation Series (14)

Author: 酷酷的群 | Published 2020-08-29 09:59

    I. Overview

    For a probabilistic model, the frequentist perspective leads to an optimization problem, while the Bayesian perspective leads to an integration problem.

    From the Bayesian perspective, given observed data x, for a new sample \hat{x} we need to compute:

    p(\hat{x}|x)=\int _{\theta }p(\hat{x},\theta |x)\mathrm{d}\theta =\int _{\theta }p(\hat{x}|\theta ,x)p(\theta |x)\mathrm{d}\theta \\ \overset{\hat{x}\perp x\,|\,\theta }{=}\int _{\theta }p(\hat{x}|\theta)p(\theta |x)\mathrm{d}\theta =E_{\theta |x}[p(\hat{x}|\theta )]

    If the new sample is independent of the observed data given the parameters, the prediction is simply the expectation of p(\hat{x}|\theta ) under the parameter posterior. The heart of the inference problem is computing the parameter posterior p(\theta |x), and inference methods divide into:

    1. Exact inference
    2. Approximate inference, used when the integral over the parameter space is intractable
      ① Deterministic approximation, e.g. variational inference
      ② Stochastic approximation, e.g. MCMC, MH, Gibbs sampling
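    To make the expectation E_{\theta |x}[p(\hat{x}|\theta )] concrete, here is a minimal Monte Carlo sketch under an assumed conjugate Beta-Bernoulli model (the model, prior, and all numbers are illustrative, not from the original text): draw posterior samples of \theta and average p(\hat{x}|\theta ) over them.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed toy model: theta ~ Beta(1, 1), x_i | theta ~ Bernoulli(theta).
# Conjugacy gives the exact posterior p(theta | x) = Beta(1 + k, 1 + N - k).
x = rng.binomial(1, 0.7, size=100)                  # illustrative observed data
k, N = x.sum(), x.size

# Monte Carlo: theta^(s) ~ p(theta | x), then average p(x_hat = 1 | theta^(s))
theta_samples = rng.beta(1 + k, 1 + N - k, size=10_000)
p_pred = theta_samples.mean()                       # E_{theta|x}[p(x_hat=1 | theta)]

print(f"Monte Carlo p(x_hat=1 | x) ~= {p_pred:.3f}")
print(f"closed form (1+k)/(2+N)    = {(1 + k) / (2 + N):.3f}")
```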

    II. Deriving the Formulas

    We have the following data:

    x:observed\; variable\rightarrow X:\left \{x_{i}\right \}_{i=1}^{N}
    z:latent\; variable+parameter\rightarrow Z:\left \{z_{i}\right \}_{i=1}^{N}
    (X,Z):complete\; data

    We write z for the collection of latent variables and parameters. Next we rewrite log\; p(x) using the product rule and introduce a distribution q(z):

    log\; p(x)=log\; p(x,z)-log\; p(z|x)=log\; \frac{p(x,z)}{q(z)}-log\; \frac{p(z|x)}{q(z)}

    Integrating both sides against q(z), i.e., taking the expectation under q(z):

    LHS=\int _{z}q(z)\cdot log\; p(x)\mathrm{d}z=log\; p(x)\int _{z}q(z)\mathrm{d}z=log\; p(x)\\ RHS=\underset{ELBO(evidence\; lower\; bound)}{\underbrace{\int _{z}q(z)log\; \frac{p(x,z)}{q(z)}\mathrm{d}z}}\underset{KL(q(z)||p(z|x))}{\underbrace{-\int _{z}q(z)log\; \frac{p(z|x)}{q(z)}\mathrm{d}z}}\\ =\underset{variational\; lower\; bound}{\underbrace{L(q)}} + \underset{\geq 0}{\underbrace{KL(q||p)}}

    Our goal is to find a q as close to the posterior p as possible, i.e., to make KL(q||p) as small as possible; since log\; p(x) does not depend on q, this is equivalent to making L(q) as large as possible:

    \tilde{q}(z)=\underset{q(z)}{argmax}\; L(q)\Rightarrow \tilde{q}(z)\approx p(z|x)
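    The identity log\; p(x)=L(q)+KL(q||p) can be checked numerically. Below is a minimal sketch with a discrete latent variable (the joint table is made up for illustration): for any q(z) the two terms sum to log\; p(x), and L(q) reaches log\; p(x) exactly when q(z)=p(z|x).

```python
import numpy as np

# Made-up discrete model: z in {0, 1}, a single observation x,
# with the joint p(x, z) given directly as two numbers.
p_xz = np.array([0.1, 0.3])        # p(x, z=0), p(x, z=1)
p_x = p_xz.sum()                   # evidence p(x)
post = p_xz / p_x                  # posterior p(z | x)

def elbo_and_kl(q):
    elbo = np.sum(q * (np.log(p_xz) - np.log(q)))   # L(q)
    kl = np.sum(q * (np.log(q) - np.log(post)))     # KL(q || p(z|x))
    return elbo, kl

for q in (np.array([0.5, 0.5]), np.array([0.2, 0.8]), post):
    elbo, kl = elbo_and_kl(q)
    print(f"ELBO={elbo:+.4f}  KL={kl:.4f}  sum={elbo + kl:+.4f}  log p(x)={np.log(p_x):+.4f}")
```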

    In variational inference we make the following assumption on q(z) (the mean-field assumption): the dimensions of the multivariate variable z are partitioned into M groups that are mutually independent:

    q(z)=\prod_{i=1}^{M}q_{i}(z_{i})

    To solve for q_{j}(z_{j}) we fix all the other factors q_{i}(z_{i}),i\neq j. We first split L(q) into two parts:

    L(q)=\underset{①}{\underbrace{\int _{z}q(z)log\; p(x,z)\mathrm{d}z}}-\underset{②}{\underbrace{\int _{z}q(z)log\; q(z)\mathrm{d}z}}

    For term ①:

    ①=\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; p(x,z)\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ =\int _{z_{j}}q_{j}(z_{j})\left (\int _{z-z_{j}}\prod_{i\neq j}^{M}q_{i}(z_{i})log\; p(x,z)\prod_{i\neq j}\mathrm{d}z_{i}\right )\mathrm{d}z_{j}\\ =\int _{z_{j}}q_{j}(z_{j})\cdot E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p(x,z)]\; \mathrm{d}z_{j}

    For term ②:

    ②=\int _{z}q(z)log\; q(z)\mathrm{d}z\\ =\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})\sum_{i=1}^{M}log\; q_{i}(z_{i})\mathrm{d}z\\ =\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})[log\; q_{1}(z_{1})+log\; q_{2}(z_{2})+\cdots +log\; q_{M}(z_{M})]\mathrm{d}z

    Consider the first summand:

    \int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{1}(z_{1})\mathrm{d}z\\ =\int _{z_{1}z_{2}\cdots z_{M}}q_{1}(z_{1})q_{2}(z_{2})\cdots q_{M}(z_{M})\cdot log\; q_{1}(z_{1})\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ =\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}\cdot \underset{=1}{\underbrace{\int _{z_{2}}q_{2}(z_{2})\mathrm{d}z_{2}}}\cdot \underset{=1}{\underbrace{\int _{z_{3}}q_{3}(z_{3})\mathrm{d}z_{3}}}\cdots \underset{=1}{\underbrace{\int _{z_{M}}q_{M}(z_{M})\mathrm{d}z_{M}}}\\ =\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}

    In general, \int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{k}(z_{k})\mathrm{d}z=\int _{z_{k}}q_{k}(z_{k})log\; q_{k}(z_{k})\mathrm{d}z_{k}, so

    ②=\sum_{i=1}^{M}\int _{z_{i}}q_{i}(z_{i})log\; q_{i}(z_{i})\mathrm{d}z_{i}=\int _{z_{j}}q_{j}(z_{j})log\; q_{j}(z_{j})\mathrm{d}z_{j}+C

    where C collects the i\neq j terms, which are constants when only q_{j} is being optimized.

    Subtracting, we can now compute ①-②:

    First, ①=\int _{z_{j}}q_{j}(z_{j})\cdot \underset{written\; as\; log\; \hat{p}(x,z_{j})}{\underbrace{E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p(x,z)]}}\; \mathrm{d}z_{j}, so

    ①-②=\int _{z_{j}}q_{j}(z_{j})\cdot log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\mathrm{d}z_{j}+C\\ \int _{z_{j}}q_{j}(z_{j})\cdot log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\mathrm{d}z_{j}=-KL(q_{j}(z_{j})||\hat{p}(x,z_{j}))\leq 0

    The maximum is attained when q_{j}(z_{j})=\hat{p}(x,z_{j}), i.e., when the KL term vanishes; equivalently, log\; q_{j}(z_{j})=E_{\prod_{i\neq j}q_{i}(z_{i})}[log\; p(x,z)]+const, since \hat{p} is defined only up to normalization.
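    As a sanity check on this result, here is a minimal discrete sketch (a made-up 2x2 joint table over (z_{1},z_{2}) for a fixed x): with q_{2} held fixed, the factor q_{1}\propto exp(E_{q_{2}}[log\; p]) achieves a higher L(q) than arbitrary alternatives.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up joint p(x, z1, z2) over binary z1, z2 for one fixed x
p = np.array([[0.2, 0.1],
              [0.4, 0.3]])                 # p[z1, z2]
q2 = np.array([0.6, 0.4])                  # fixed factor q2(z2)

def elbo(q1):
    q = np.outer(q1, q2)                   # mean-field q(z) = q1(z1) q2(z2)
    return np.sum(q * (np.log(p) - np.log(q)))

# Optimal factor: log q1*(z1) = E_{q2}[log p(x, z1, z2)] + const
log_q1 = np.log(p) @ q2
q1_star = np.exp(log_q1 - log_q1.max())
q1_star /= q1_star.sum()                   # normalize

print("ELBO at q1*:", elbo(q1_star))
for _ in range(3):
    a = rng.uniform(0.05, 0.95)
    print("ELBO at a random q1:", elbo(np.array([a, 1 - a])))
```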

    III. Revisiting the EM Algorithm

    Recall that in the generalized EM algorithm we need to fix \theta and find the q closest to p; this is exactly where variational inference applies. We have:

    log\; p_{\theta }(x)=\underset{L(q)}{\underbrace{ELBO}}+\underset{\geq 0}{\underbrace{KL(q||p)}}\geq L(q)

    Then we solve for q:

    \hat{q}=\underset{q}{argmin}\; KL(q||p)=\underset{q}{argmax}\; L(q)

    With the mean-field variational inference above, we obtain the following result (note that z_{i} here denotes the i-th group of the mean-field partition, not the i-th dimension of z):

    log\; q_{j}(z_{j})=E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p_{\theta }(x,z)]\\ =\int _{z_{1}}\int _{z_{2}}\cdots \int _{z_{j-1}}\int _{z_{j+1}}\cdots \int _{z_{M}}q_{1}q_{2}\cdots q_{j-1}q_{j+1}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{j-1}\mathrm{d}z_{j+1}\cdots \mathrm{d}z_{M}

    One round of iterative updates proceeds as follows:

    log\; \hat{q}_{1}(z_{1})=\int _{z_{2}}\cdots \int _{z_{M}}q_{2}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ log\; \hat{q}_{2}(z_{2})=\int _{z_{1}}\int _{z_{3}}\cdots \int _{z_{M}}\hat{q}_{1}q_{3}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\mathrm{d}z_{3}\cdots \mathrm{d}z_{M}\\ \vdots \\ log\; \hat{q}_{M}(z_{M})=\int _{z_{1}}\cdots \int _{z_{M-1}}\hat{q}_{1}\cdots \hat{q}_{M-1}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\cdots \mathrm{d}z_{M-1}

    We can see that each q_{j}(z_{j}) is computed with all the other factors q_{i}(z_{i}) held fixed, so the solution can be found iteratively by coordinate ascent. The derivation above is for a single sample, but it applies to a data set as well.
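    As a concrete sketch of the coordinate-ascent scheme, consider the classic toy problem of approximating a 2D Gaussian p(z)=N(\mu ,\Lambda ^{-1}) with a factorized q(z)=q_{1}(z_{1})q_{2}(z_{2}) (this example follows Bishop's PRML §10.1.2; the numbers are made up). Each optimal factor is Gaussian with variance \Lambda _{jj}^{-1} and mean m_{j}=\mu _{j}-\Lambda _{jj}^{-1}\Lambda _{jk}(m_{k}-\mu _{k}), so the updates reduce to iterating two scalar equations:

```python
import numpy as np

# Made-up 2D Gaussian target: mean mu, precision matrix Lam
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 0.9],
                [0.9, 1.0]])

m = np.zeros(2)                    # initialize the two factor means
for t in range(50):
    # Coordinate updates from log q_j(z_j) = E_{q_k}[log p(z)] + const
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])

print("factor means after CAVI :", m)                 # converge to mu
print("factor variances        :", 1 / np.diag(Lam))  # 1/Lam_jj
print("true marginal variances :", np.diag(np.linalg.inv(Lam)))
```

    Note that the factor means recover \mu exactly, while the factor variances \Lambda _{jj}^{-1} come out smaller than the true marginal variances: the mean-field approximation tends to be overconfident.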

    Mean-field variational inference has some drawbacks:
    ① the assumption is too strong and does not hold in very complex situations;
    ② the integral inside the expectation may be intractable.

    IV. Stochastic Gradient Variational Inference (SGVI)

    1. Direct gradient method

    The process from Z to X is called the generative process or decoding, and the process from X to Z is called the inference process or encoding. Mean-field variational inference yields a coordinate-ascent algorithm, but its assumption can be too strong, and the integrals are not always computable. Besides coordinate ascent, optimization can also proceed by gradient ascent, so we would like a gradient-based algorithm for variational inference.

    Assume q(Z)=q_{\phi }(Z), a distribution governed by the parameter \phi. Then \underset{q(Z)}{argmax}\; L(q)=\underset{\phi }{argmax}\; L(\phi ), where L(\phi )=E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)] and x denotes a single sample.

    \nabla_{\phi }L(\phi )=\nabla_{\phi }E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =\nabla_{\phi }\int q_{\phi }(z)[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\underset{①}{\underbrace{\int \nabla_{\phi }q_{\phi }(z)\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}+\underset{②}{\underbrace{\int q_{\phi }(z)\nabla_{\phi }[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}

    For term ②:

    ②=\int q_{\phi }(z)\nabla_{\phi }[\underset{independent\; of\; \phi }{\underbrace{log\; p_{\theta }(x,z)}}-log\; q_{\phi }(z)]\mathrm{d}z\\ =-\int q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)\mathrm{d}z\\ =-\int q_{\phi }(z)\frac{1}{q_{\phi }(z)}\nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ =-\int \nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ =-\nabla_{\phi }\int q_{\phi }(z)\mathrm{d}z\\ =-\nabla_{\phi }1\\ =0

    Therefore, using the log-derivative trick \nabla_{\phi }q_{\phi }(z)=q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z):

    \nabla_{\phi }L(\phi )=①\\ =\int {\color{Red}{\nabla_{\phi }q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\int {\color{Red}{q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =E_{q_{\phi }}[(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]

    This expectation can be approximated by Monte Carlo sampling, which yields the gradient; the parameters are then obtained by gradient ascent:

    z^{(l)}\sim q_{\phi }(z),\; l=1,\cdots ,L\\ E_{q_{\phi }}[(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]\approx \frac{1}{L}\sum_{l=1}^{L}(\nabla_{\phi }log\; q_{\phi }(z^{(l)}))(log\; p_{\theta }(x,z^{(l)})-log\; q_{\phi }(z^{(l)}))
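    A minimal sketch of this score-function estimator, with an assumed variational family q_{\phi }(z)=N(\phi ,1) and a made-up toy model log\; p_{\theta }(x,z)=log\; N(z;0,1)+log\; N(x;z,1) for one fixed observation x (all names and numbers are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = 2.0                                   # one illustrative observation

def log_p(z):
    # log p(x, z) for z ~ N(0,1), x | z ~ N(z,1), constants dropped
    # (additive constants do not bias the score-function estimator)
    return -0.5 * z**2 - 0.5 * (x - z)**2

def log_q(z, phi):                        # q_phi(z) = N(phi, 1), constants dropped
    return -0.5 * (z - phi)**2

def score_grad(phi, L=1000):
    z = rng.normal(phi, 1.0, size=L)      # z^(l) ~ q_phi(z)
    score = z - phi                       # grad_phi log q_phi(z^(l))
    return np.mean(score * (log_p(z) - log_q(z, phi)))

# For this toy model the exact gradient is 2(1 - phi); the estimate is noisy
print("estimate at phi=0:", score_grad(0.0), " (exact: 2.0)")
```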

    However, because the summand contains a log term (which blows up where q_{\phi }(z) is small), this direct sampling estimator has high variance and requires a very large number of samples. To address the variance problem, we use the reparameterization trick.

    2. The reparameterization trick

    Let z=g_{\phi }(\varepsilon ,x) with \varepsilon \sim p(\varepsilon ). Since z\sim q_{\phi }(z|x), the change of variables gives \left | q_{\phi }(z|x)\mathrm{d}z \right |=\left | p(\varepsilon )\mathrm{d}\varepsilon \right |. Substituting into the gradient above:

    \nabla_{\phi }L(\phi )=\nabla_{\phi }E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =\nabla_{\phi }\int [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]q_{\phi }(z)\mathrm{d}z\\ =\nabla_{\phi }\int [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]p(\varepsilon )\mathrm{d}\varepsilon \\ =\nabla_{\phi }E_{p(\varepsilon )}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =E_{p(\varepsilon )}[\nabla_{\phi }(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]\\ =E_{p(\varepsilon )}[\nabla_{z}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))\nabla_{\phi }z]\\ =E_{p(\varepsilon )}[\nabla_{z}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))\nabla_{\phi }g_{\phi }(\varepsilon ,x)]

    Monte Carlo sampling of this expression then gives the expectation, and hence the gradient.

    The SGVI update is:

    \phi ^{t+1}\leftarrow \phi ^{t}+\lambda ^{t}\cdot \nabla_{\phi }L(\phi )
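    Putting the pieces together, here is a minimal SGVI sketch using the reparameterization z=g_{\phi }(\varepsilon )=\mu +\sigma \varepsilon, \varepsilon \sim N(0,1), with q_{\phi }(z)=N(\mu ,\sigma ^{2}) and \phi =(\mu ,log\; \sigma ), applied to the same made-up toy model as above (z\sim N(0,1), x|z\sim N(z,1), x=2; the exact posterior there is N(1,0.5)):

```python
import numpy as np

rng = np.random.default_rng(0)
x = 2.0                                        # illustrative observation

def grad_z_log_p(z):                           # grad_z log p(x, z) for the toy model
    return -z + (x - z)

mu, log_sigma = 0.0, 0.0                       # phi = (mu, log sigma)
lr, L = 0.05, 64                               # step size lambda^t, batch size

for t in range(500):
    sigma = np.exp(log_sigma)
    eps = rng.normal(size=L)                   # eps ~ p(eps) = N(0, 1)
    z = mu + sigma * eps                       # z = g_phi(eps)
    # grad_z [log p(x,z) - log q_phi(z)], with grad_z log q = -(z - mu)/sigma^2
    gz = grad_z_log_p(z) + (z - mu) / sigma**2
    # Chain rule through g_phi: grad_mu z = 1, grad_{log sigma} z = sigma * eps
    mu += lr * np.mean(gz)
    log_sigma += lr * np.mean(gz * sigma * eps)

print(f"mu = {mu:.3f} (exact 1.0), sigma^2 = {np.exp(2 * log_sigma):.3f} (exact 0.5)")
```

    The fitted q recovers the exact posterior here because the variational family contains it; near the optimum the bracket inside the expectation vanishes pointwise, which is why this estimator is much less noisy than the score-function one.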
