美文网首页程序员人工智能
变分推断|机器学习推导系列(十四)

变分推断|机器学习推导系列(十四)

作者: 酷酷的群 | 来源:发表于2020-08-29 09:59 被阅读0次

一、概述

对于概率模型来说,如果从频率派角度来看就会是一个优化问题,从贝叶斯角度来看就会是一个积分问题

从贝叶斯角度来看,如果已有数据x,对于新的样本\hat{x},需要得到:

p(\hat{x}|x)=\int _{\theta }p(\hat{x},\theta |x)\mathrm{d}\theta =\int _{\theta }p(\hat{x}|\theta ,x)p(\theta |x)\mathrm{d}\theta \\ \overset{\hat{x}与x独立}{=}\int _{\theta }p(\hat{x}|\theta)p(\theta |x)\mathrm{d}\theta =E_{\theta |x}[p(\hat{x}|\theta )]

如果新样本和数据集独立,那么推断就是概率分布依参数后验分布的期望。推断问题的中⼼是参数后验分布的求解,推断分为:

  1. 精确推断
  2. 近似推断-参数空间无法精确求解
    ①确定性近似-如变分推断
    ②随机近似-如 MCMC,MH,Gibbs

二、公式导出

有以下数据:

x:observed variable\rightarrow X:\left \{x_{i}\right \}_{i=1}^{N}
z:latent variable + parameter\rightarrow Z:\left \{z_{i}\right \}_{i=1}^{N}
(X,Z):complete data

我们记z为隐变量和参数的集合。接着我们变换概率p(x)的形式然后引入分布q(z)

log\; p(x)=log\; p(x,z)-log\; p(z|x)=log\; \frac{p(x,z)}{q(z)}-log\; \frac{p(z|x)}{q(z)}

式子两边同时对q(z)求积分:

左边=\int _{z}q(z)\cdot log\; p(x |\theta )\mathrm{d}z=log\; p(x|\theta )\int _{z}q(z )\mathrm{d}z=log\; p(x|\theta )\\ 右边=\underset{ELBO(evidence\; lower\; bound)}{\underbrace{\int _{z}q(z)log\; \frac{p(x,z|\theta )}{q(z)}\mathrm{d}z}}\underset{KL(q(z)||p(z|x,\theta ))}{\underbrace{-\int _{z}q(z)log\; \frac{p(z|x,\theta )}{q(z)}\mathrm{d}z}}\\ =\underset{变分}{\underbrace{L(q)}} + \underset{\geq 0}{\underbrace{KL(q||p)}}

我们的目的是找到一个使得qp更接近,也就是使KL(q||p)越小越好,也就是要使L(q)越大越好:

\tilde{q}(z)=\underset{q(z)}{argmax}\; L(q)\Rightarrow \tilde{q}(z)\approx p(z|x)

在变分推断中我们对q(z)做以下假设(基于平均场假设的变分推断),也就是说我们把多维变量的不同维度分为M组,组与组之间是相互独立的:

q(z)=\prod_{i=1}^{M}q_{i}(z_{i})

求解时我们固定q_{i}(z_{i}),i\neq j来求q_{j}(z_{j}),接下来将L(q)写作两部分:

L(q)=\underset{①}{\underbrace{\int _{z}q(z)log\; p(x,z)\mathrm{d}z}}-\underset{②}{\underbrace{\int _{z}q(z)log\; q(z)\mathrm{d}z}}

对于①:

①=\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; p(x,z)\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ =\int _{z_{j}}q_{j}(z_{j})\underset{\int _{z-z_{j}}log\; p(x,z)\prod_{i\neq j}^{M}q_{i}(z_{i})\mathrm{d}z_{i}}{\underbrace{\left (\int _{z-z_{j}}\prod_{i\neq j}^{M}q_{i}(z_{i})log\; p(x,z)\underset{(i\neq j)}{\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}}\right )}}\mathrm{d}z_{j}\\ =\int _{z_{j}}q_{j}(z_{j})\cdot E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p(x,z)]\cdot \mathrm{d}z_{j}

对于②:

②=\int _{z}q(z)log\; q(z)\mathrm{d}z\\ =\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})\sum_{i=1}^{M}log\; q_{i}(z_{i})\mathrm{d}z\\ =\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})[log\; q_{1}(z_{1})+log\; q_{2}(z_{2})+\cdots +log\; q_{M}(z_{M})]\mathrm{d}z\\ 其中\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{1}(z_{1})\mathrm{d}z\\ =\int _{z_{1}z_{2}\cdots z_{M}}q_{1}(z_{1})q_{2}(z_{2})\cdots q_{M}(z_{M})\cdot log\; q_{1}(z_{1})\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ =\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}\cdot \underset{=1}{\underbrace{\int _{z_{2}}q_{2}(z_{2})\mathrm{d}z_{2}}}\cdot \underset{=1}{\underbrace{\int _{z_{3}}q_{3}(z_{3})\mathrm{d}z_{3}}}\cdots \underset{=1}{\underbrace{\int _{z_{M}}q_{M}(z_{M})\mathrm{d}z_{M}}}\\ =\int _{z_{1}}q_{1}(z_{1})log\; q_{1}(z_{1})\mathrm{d}z_{1}\\ 也就是说\int _{z}\prod_{i=1}^{M}q_{i}(z_{i})log\; q_{k}(z_{k})\mathrm{d}z=\int _{z_{k}}q_{k}(z_{k})log\; q_{k}(z_{k})\mathrm{d}z_{k}\\ 则②=\sum_{i=1}^{M}\int _{z_{i}}q_{i}(z_{i})log\; q_{i}(z_{i})\mathrm{d}z_{i}\\ =\int _{z_{j}}q_{j}(z_{j})log\; q_{j}(z_{j})\mathrm{d}z_{j}+C

然后我们可以得到①-②\;

首先①=\int _{z_{j}}q_{j}(z_{j})\cdot\underset{写作log\; \hat{p}(x,z_{j})}{ \underbrace{E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p(x,z)]}}\cdot \mathrm{d}z_{j}\\ 然后①-②=\int _{z_{j}}q_{j}(z_{j})\cdot log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\mathrm{d}z_{j}+C\\ \int _{z_{j}}q_{j}(z_{j})\cdot log\frac{\hat{p}(x,z_{j})}{q_{j}(z_{j})}\mathrm{d}z_{j}=-KL(q_{j}(z_{j})||\hat{p}(x,z_{j}))\leq 0

q_{j}(z_{j})=\hat{p}(x,z_{j})才能得到最⼤值。

三、回顾EM算法

回想一下广义EM算法中,我们需要固定\theta然后求解与p最接近的q,这里就可以使用变分推断的方法,我们有如下式子:

log\; p_{\theta }(x)=\underset{L(q)}{\underbrace{ELBO}}+\underset{\geq 0}{\underbrace{KL(q||p)}}\geq L(q)

然后求解q

\hat{q}=\underset{q}{argmin}\; KL(q||p)=\underset{q}{argmax}\; L(q)

使用上述平均场变分推断的话,我们就可以得出以下结果(注意这里z_i不是代表z的第i个维度):

log\; q_{j}(z_{j})=E_{\prod_{i\neq j}^{M}q_{i}(z_{i})}[log\; p_{\theta }(x,z)]\\ =\int _{z_{1}}\int _{z_{2}}\cdots \int _{z_{j-1}}\int _{z_{j+1}}\cdots \int _{z_{M}}q_{1}q_{2}\cdots q_{j-1}q_{j+1}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\mathrm{d}z_{2}\cdots \mathrm{d}z_{j-1}\mathrm{d}z_{j+1}\cdots \mathrm{d}z_{M}

一次迭代求解的过程如下:

log\; \hat{q}_{1}(z_{1})=\int _{z_{2}}\cdots \int _{z_{M}}q_{2}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{2}\cdots \mathrm{d}z_{M}\\ log\; \hat{q}_{2}(z_{2})=\int _{z_{1}}\int _{z_{3}}\cdots \int _{z_{M}}\hat{q}_{1}q_{3}\cdots q_{M}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\mathrm{d}z_{3}\cdots \mathrm{d}z_{M}\\ \vdots \\ log\; \hat{q}_{M}(z_{M})=\int _{z_{1}}\cdots \int _{z_{M-1}}\hat{q}_{1}\cdots \hat{q}_{M-1}\cdot log\; p_{\theta }(x,z)\mathrm{d}z_{1}\cdots \mathrm{d}z_{M-1}

我们看到,对每⼀个q_{j}(z_{j}),都是固定其余的q_{i}(z_{i}),求这个值,于是可以使⽤坐标上升的⽅法进⾏迭代求解,上⾯的推导针对单个样本,但是对数据集也是适⽤的。

基于平均场假设的变分推断存在⼀些问题:
①假设太强,⾮常复杂的情况下,假设不适⽤;
②期望中的积分,可能⽆法计算。

四、随机梯度变分推断(SGVI)

  1. 直接求导数的方法

ZX的过程叫做⽣成过程或译码,从XZ过程叫推断过程或编码过程,基于平均场的变分推断可以导出坐标上升的算法,但是这个假设在⼀些情况下假设太强,同时积分也不⼀定能算。我们知道,优化⽅法除了坐标上升,还有梯度上升的⽅式,我们希望通过梯度上升来得到变分推断的另⼀种算法。

假定q(Z)=q_{\phi }(Z),是和\phi这个参数相连的概率分布。于是\underset{q(Z)}{argmax}\; L(q)=\underset{\phi }{argmax}\; L(\phi ),其中L(\phi )=E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)],这里的x表示的是样本。

\nabla_{\phi }L(\phi )=\nabla_{\phi }E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =\nabla_{\phi }\int q_{\phi }(z)[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\underset{①}{\underbrace{\int \nabla_{\phi }q_{\phi }(z)\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}+\underset{②}{\underbrace{\int q_{\phi }(z)\nabla_{\phi }[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z}}\\ 其中②=\int q_{\phi }(z)\nabla_{\phi }[\underset{与\phi 无关}{\underbrace{log\; p_{\theta }(x,z)}}-log\; q_{\phi }(z)]\mathrm{d}z\\ =-\int q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)\mathrm{d}z\\ =-\int q_{\phi }(z)\frac{1}{q_{\phi }(z)}\nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ =-\int \nabla_{\phi }q_{\phi }(z)\mathrm{d}z\\ =-\nabla_{\phi }\int q_{\phi }(z)\mathrm{d}z\\ =-\nabla_{\phi }1\\ =0\\ 因此\nabla_{\phi }L(\phi )=①\\ =\int {\color{Red}{\nabla_{\phi }q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\int {\color{Red}{q_{\phi }(z)\nabla_{\phi }log\; q_{\phi }(z)}}\cdot [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =E_{q_{\phi }}[(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]

这个期望可以通过蒙特卡洛采样来近似,从⽽得到梯度,然后利⽤梯度上升的⽅法来得到参数:

z^{l}\sim q_{\phi }(z)\\ E_{q_{\phi }}[(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]\sim \frac{1}{L}\sum_{i=1}^{L}(\nabla_{\phi }log\; q_{\phi }(z))(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))

但是由于求和符号中存在⼀个对数项,于是直接采样的⽅差很⼤,需要采样的样本⾮常多。为了解决⽅差太⼤的问题,我们采⽤重参数化技巧(Reparameterization)。

  1. 重参数化技巧

我们取z=g_{\phi }(\varepsilon ,x),\varepsilon \sim p(\varepsilon ),对于z\sim q_{\phi }(z|x),我们有\left | q_{\phi }(z|x)\mathrm{d}z \right |=\left | p(\varepsilon )\mathrm{d}\varepsilon \right |。代入上面的梯度中:

\nabla_{\phi }L(\phi )=\nabla_{\phi }E_{q_{\phi }}[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =\nabla_{\phi }\int q_{\phi }(z)[log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\mathrm{d}z\\ =\nabla_{\phi }\int [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]q_{\phi }(z)\mathrm{d}z\\ =\nabla_{\phi }\int [log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]p(\varepsilon )\mathrm{d}\varepsilon \\ =\nabla_{\phi }E_{p(\varepsilon )}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z)]\\ =E_{p(\varepsilon )}[\nabla_{\phi }(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))]\\ =E_{p(\varepsilon )}[\nabla_{z}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))\nabla_{\phi }z]\\ =E_{p(\varepsilon )}[\nabla_{z}(log\; p_{\theta }(x,z)-log\; q_{\phi }(z))\nabla_{\phi }g_{\phi }(\varepsilon ,x)]

对这个式⼦进⾏蒙特卡洛采样,然后计算期望,得到梯度。

SGVI的迭代过程为:

\phi ^{t+1}\leftarrow \phi ^{t}+\lambda ^{t}\cdot \nabla_{\phi }L(\phi )

相关文章

  • 变分推断|机器学习推导系列(十四)

    一、概述 对于概率模型来说,如果从频率派角度来看就会是一个优化问题,从贝叶斯角度来看就会是一个积分问题。 从贝叶斯...

  • 变分推断的原理推导

    一.原理推导 变分推断(VI)要做的事情很朴素,那就是有一个复杂的难以求解的分布,比如后验概率分布:,这里表示观测...

  • 概率图模型-推断|机器学习推导系列(十一)

    一、概述 总的来说,推断的任务就是求概率。假如我们知道联合概率,我们需要使用推断的方法来求: 以下是一些推断的方法...

  • 近似推断|机器学习推导系列(二十七)

    一、推断的动机和困难 推断的动机 推断问题是在概率图模型中经常遇到的问题,也就是给定观测变量的情况下求解后验,这里...

  • 变分推断

    作者:知乎用户 链接:https://www.zhihu.com/question/41765860/answer...

  • 变分推断

    在github看到这个文章写的不错,就转载了,大家一起学习:https://github.com/keithyin...

  • 变分推断

    变分推断(Variational Inference) 一文读懂贝叶斯推理问题:MCMC方法和变分推断

  • 数学知识

    变分推断1

  • python数据分析与机器学习(Numpy,Pandas,Mat

    机器学习怎么学? 机器学习包含数学原理推导和实际应用技巧,所以需要清楚算法的推导过程和如何应用。 深度学习是机器学...

  • 绪论|机器学习推导系列(一)

    一、频率派 vs 贝叶斯派 机器学习主要解决从数据中获取其概率分布的问题,通过一些机器学习的算法可以从大量数据中找...

网友评论

    本文标题:变分推断|机器学习推导系列(十四)

    本文链接:https://www.haomeiwen.com/subject/xhmnjktx.html