Pyro简介:产生式模型实现库(三),SVI 一

作者: WilliamY | 来源:发表于2019-11-26 16:32 被阅读0次

问题设定

我们在前面的教程中,用Pyro定义过model函数(过程如简介(一))。这里快速回忆一下model的用法model(*args, **kwargs),在model的参数中,包括下面三要素

  1. 观察 \Longleftrightarrow 带有obspyro.sample
  2. 隐变量\Longleftrightarrow pyro.sample
  3. 模型参数 \Longleftrightarrowpyro.sample

现在我们定义符号:观察变量\bf x,隐变量\bf z,参数变量\bf \theta,联合分布
p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})
联合分布的分解方式是灵活的,分解后的每个函数p_i只要符合3个条件即可。

  1. 每个p_i都可以采样;
  2. 每个p_i在每个样本点上都可以计算对数概率密度 \log \rm{pdf} \it p_i
  3. 每个p_i对于\theta都是可导的。

模型(参数)的学习

我们对于优秀模型的标准,在于最大化证据的对数似然,即寻找适当的\theta
\theta_{\rm{max}} = \underset{\theta}{\operatorname{argmax}} \log p_{\theta}({\bf x})
证据的对数似然为:
\log p_{\theta}({\bf x}) = \log \int\! d{\bf z}\; p_{\theta}({\bf x}, {\bf z})
上述形式化引发了双重困难:一、对于\bf z的积分,往往没有解析的办法(即使固定了\theta也不行);二、哪怕对\bf z的积分成功了,即我们可以计算任何一个\theta点上的对数似然,对\theta取最大的对数似然值是一个非凸优化问题,仍旧十分困难。
除了要寻找\theta_{\rm{max}},我们还要计算对\bf z的后验概率:
p_{\theta_{\rm{max}}}({\bf z} | {\bf x}) = \frac{p_{\theta_{\rm{max}}}({\bf x} , {\bf z})}{ \int \! d{\bf z}\; p_{\theta_{\rm{max}}}({\bf x} , {\bf z}) }
上式分母就是证据的概率密度函数(也被叫做“配分函数”),它往往也没有解析形式。变分推断的任务,就是既要寻找\theta_{\rm{max}},又要计算后验概率p_{\theta_{\rm{max}}}({\bf z} | {\bf x})

guide函数

变分推断最基本的想法,是利用另一个参数化的概率分布函数q_\phi(\bf z)来近似后验概率p_{\theta_{\rm{max}}}({\bf z} | {\bf x})。这个q_\phi(\bf z)被称为变分分布,其参数为\phi,在Pyro中我们叫它guide
和model一样,guide()也可以进行pyro.samplepyro.param操作。guide中包含观察变量,因为它要恰当地归一化。guide()model()具有相同的调用结构,即二者具有相同的输入参数(argument)。
为了近似后验分布p_{\theta_{\rm{max}}}({\bf z} | {\bf x}),guide需要提供联合分布形式,这就需要配准guide和model的变量,确保二者使用的变量是一致的。在上一教程里,我们讲到pyro.sample()第一次声明随机变量时,容器中将存储该变量的键值对。利用这一机制,我们就可以在guide和model中使用相同的声明格式,确保二者的统一。举例来说,比如我们要声明随机变量z_1,在model中我们输入:

def model():
    pyro.sample('z_1', ...)

在guide中我们采用相同的格式:

def guide():
    pyro.sample('z_1', ...)

二者的分布形式可以是不同的,但名字必须1对1地配准。
一旦确定了guide,我们就可以进行推断了。学习参数的问题,被设定为在\theta-\phi空间搜索最优值的问题,学习过程将引导guide越来越接近后验分布的真值。下一部分我们介绍优化的目标,即损失函数。

ELBO(对数证据下界)

ELBO(evidence lower bound)被定义为\theta-\phi的函数,它是guide函数采样得到的期望:
{\rm ELBO} \equiv \mathbb{E}_{q_{\phi}({\bf z})} \left [\log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z})\right]
上式中,p_{\theta}({\bf x}, {\bf z})q_{\phi}({\bf z})的模型是已知的,我们采样它们并用蒙特卡罗法计算结果。
之所以ELBO取名为证据下界,因为对于所有的\theta\phi来说,不等式是严格成立的:
\log p_{\theta}({\bf x}) \ge {\rm ELBO}
所以最大化ELBO,我们将推高证据的期望。ELBO和证据之间的差异为:
\log p_{\theta}(\bf x) - \rm ELBO = \rm KL(q_\phi(\bf z) || p_\theta(\bf z|\bf x))
上式的KL距离衡量两个分布间的近似程度。
在优化过程中\theta\phi被计算梯度,并沿着梯度引导目标函数下降,使guide动态地“追逐”\log p_\theta(\bf z |\bf x)。即使后者是动态变化的,对大多数问题来说,优化过程仍旧是有效的。这里存在一个问题,怎样计算ELBO的梯度,我们在后续教程中继续讲解。现在介绍着重用于优化随机变分推断的ELBO的类SVI。

SVI

在Pyro中,和变分推断相关的操作集成在SVI类中。已经实现的Pyro代码只支持ELBO作为目标函数,其他类型的目标函数在将来完成。
用户需要提供给SVI三个输入:model、guide、优化器。假定我们已经定义好了三者,用户调用SVI类只需要如是声明:

import pyro
from pyro.infer import SVI, Trace_ELBO
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

SVI自带两种方法:step()evaluate_loss()

  1. step()执行单步梯度下降,返回损失的估计值(即负ELBO)。如果step带有参数,这些参数必须是model()guide()声明的格式。
  2. evaluate_loss()返回损失的估计值,但不执行梯度下降。和step一样,如果带有参数,这些参数必须是model()guide()声明的格式。
    二者还可输入一个可选的参数num_particals,来指定计算期望时的采样次数。对于step()计算损失和梯度,对于evaluate_loss()只计算损失。

优化器

SVI的三个输入,我们只剩优化器需要进一步介绍。回顾model和guide,它们需要满足:

  1. guide函数中的pyro.sample不能带有obs参数。
  2. modelguide具有相同的参数名称(或称“签名”)。

这里就引发了一个困难,model()guide()的行为可能相当不同。举例来说,某个随机变量只在一些时候出现,参数在优化的过程中是动态采样得到的。这就可能导致\theta\phi在动态中剧烈振荡。

为了规避这一问题,Pyro需要对每一个新出现的参数,动态生成对应的优化器。幸运的是,Pytorch已经实现了一个轻量级的优化器(详见torch.optim),可以轻松地改为支持动态的情况。
optim.PyroOptim类打包了Pytorch的优化器。PyroOptim输入两个参数:Pytorch优化器的构造器optim_constructor、优化器自己参数optim_args。每当optim_constructor实例化一个优化器,optim_args就提供一组参数。
大部分用户都不希望直接操作PyroOptim,而仅仅通过定义在optim/__init__.py中的名称与优化器交互操作。这类操作有两种实现方法。比较简单的情况下,对于所有的优化器,我们采用相同的优化参数optim_args

from pyro.optim import Adam

adam_params = {'lr': 0.005, 'betas': (0.95, 0.999)}
optimizer = Adam(adam_params)

第二种方法允许用户对优化器进行细致的控制。举个简单的例子:

from pyro.optim import Adam

def per_param_callable(module_name, param_name):
    if param_name == 'my_special_parameter':
        return {'lr': 0.010}
    else:
        return {'lr': 0.001}

optimizer = Adam(per_param_callable)

上面的例子中,对与my_special_parameter这些参数,Pyro用户设置的学习率为0.01,除此之外的参数学习率为0.001。

一个综合上述几点的例子

假如你有一枚硬币,在投掷硬币的实验中观察硬币的正反,计数正面(heads)和反面(tails)的次数。你的先验知识是该硬币是公平的(即没有偏向性的),你将在实验中观察结果,并据其修正观点。
解释一下,公平的意思,是正反面出现的次数差不多。如果正反次数的比例为11:10,你不会感到奇怪;但如果正反比为5:1,你将非常惊讶。
我们将正面记为1反面记为0,硬币的公平性记为ff满足f \in [0.0, 1.0]f=0.5表示硬币完全公平。我们的先验认为f服从beta分布\rm{Beta}(10,10),该分布在[0, 1]上的图像为对称的钟形曲线,峰值为f=0.5

Beta分布表示了我们对硬币的先验知识
假如我们抛掷了10次硬币,将结果存放在data中(数据类型是list)。我们定义model如下:
import pyro.distributions as dist

def model(data):
    # 定义beta分布的超参
    alpha0 = torch.tensor(10.)
    beta0 = torch.tensor(10.)
    # 从先验分布采样
    f = pyro.sample('latent_fairness', dist.Beta(alpha0, beta0))
    # 循环所有的观察数据
    for i in range(len(data)):
        # 观察的数据点 i 服从伯努利分布
        # 似然为 Bernoulli(f)
        pyro.sample('obs_{}'.format(i), dist.Bernoulli(f), obs=data[i])

这里我们定义了隐变量latent_fairness,它服从\rm Beta(10, 10)。每个样本点在给定隐变量的条件概率下,(即似然函数)服从伯努利分布。注意到每个观察数据都注册了单独的名字obs_i

下面我们定义相对应的guide函数,即对于隐变量f做近似的变分分布。对于q(f)只需要满足阈值在[0.0,1.0]之间即可。一个简单的选择是选取两个可学习的参数\alpha_q\beta_q,这是因为伯努利分布和beta分布是共轭分布,后验概率仍服从beta分布。我们这样写:

def guide(data):
    # 注册两个变分参数
    alpha_q = pyro.param('alpha_q', torch.tensor(15.), constraint = constraints.positive)
    beta_q  = pyro.param('alpha_q', torch.tensor(15.), constraint = constraints.positive)
    # 采样Beta(alpha_q, beta_q)得到latent_fairness
    pyro.sample('latent_fairness', dist.Beta(alpha_q, beta_q))

需要注意以下几点:

  • guide和model的随机变量名称必须一致;
  • model(data)guide(data)的参数形式必须一致;
  • 变分参数必须是torch.tensor类型,requires_grad属性被pyro.param自动设为true
  • constraint=constaints.positive保证了alpha_qbeta_q在优化过程中保持非负性。
    下面我们开始变分推断。
# 设定优化器参数
adam_param = {'lr': 0.005, 'betas':(0.9, 0.999)}
optimizer = Adam(adam_param)

# 设定推断算法
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 500
# 执行梯度下降
for i in range(n_steps):
    svi.step(data)

step()方法中,data分别被传到model和guide中。
我们模拟data的情况,并将全部代码补充完整。

import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

# 该声明用于测试
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

# 允许验证(比如验证参数的分布
pyro.enable_validation(True)

# 清空参数的容器
pyro.clear_param_store()

# 创建观察数据。这里假定实验结果为,前6次为正,后4次为反
data = list()
for _ in range(6):
    data.append(torch.tensor(1.))

for _ in range(4):
    data.append(torch.tensor(0.))

def model(data):
    # 先验的 beta 分布的超参
    alpha0 = torch.tensor(10.)
    beta0 = torch.tensor(10.)
    # 从先验分布中采样f
    f = pyro.sample('latent_fairness', dist.Beta(alpha0, beta0))
    # 遍历整个观察数据集
    for i in range(len(data)):
        # 似然函数在数据点i服从伯努利分布
        pyro.sample('obs_{}'.format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    # 在Pyro中注册变分分布的参数
    # 两个参数值均为15.0
    # 我们对没有约束的参数采用梯度下降
    # 注意,这里是pyro.param,不是pyro.sample!!!
    alpha_q = pyro.param('alpha_q', torch.tensor(15.), constraint=constraints.positive)
    beta_q = pyro.param('beta_q', torch.tensor(15.), constraint=constraints.positive)
    # 从Beta(alpha_q, beta_q)采样得到latent_fairness
    pyro.sample('latent_fairness', dist.Beta(alpha_q, beta_q))

# 设置优化器参数
adam_params = {'lr': 0.0005, 'betas': (0.9, 0.999)}
optimizer = Adam(adam_params)

# 设置推断算法
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# 梯度下降
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print('.', end=' ')

# 监听变分参数的值
alpha_q = pyro.param('alpha_q').item()
beta_q = pyro.param('beta_q').item()

# 根据beta分布的特点, 我们计算推断出的公平性系数
inferred_mean = alpha_q / (alpha_q + beta_q)
# 计算其标准差
factor = beta_q / (alpha_q * (1. + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print('\nbased on the data and our prior belief, the fairness '+
        'of the coin is %.3f +- %.3f' % (inferred_mean, inferred_std))
采样结果:
based on the data and our prior belief, the fairness of the coin is 0.532 +- 0.090

根据贝塔分布的公式,后验分布的均值为16÷30=0.53。上面给出的结果和解析计算的结果是一样的。我们看到,0.53在先验信念0.5和经验频率0.6之间。

在后面的教程中,我们将继续介绍SVI类,敬请期待!

相关文章

网友评论

    本文标题:Pyro简介:产生式模型实现库(三),SVI 一

    本文链接:https://www.haomeiwen.com/subject/uepjwctx.html