Pyro进阶(二):贝叶斯回归问题 一

作者: WilliamY | 来源:发表于2019-12-03 15:08 被阅读0次

    回归是一类常见的有监督的机器学习任务。假如数据集合为\mathcal{D}
    \mathcal{D} = \{(X_i, y_i)\} \qquad for \qquad i=1,2,...,N
    目标是拟合线性回归方程:
    y = wX + b + \epsilon
    其中wb是参数,分别代表权重和偏置;\epsilon为噪声。
    在本教程中,我们分别使用Pytorch和Pyro求解wb,Pyro的方法为贝叶斯回归。我们将学习使用Pyro的预测功能。

    问题定义

    首先我们引入必要的头文件

    %reset -s -t #重启IPython的kernel;如果不使用IPython Notebook,请忽略
    import os
    import torch
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from functools import partial
    
    import pyro
    import pyro.distributions as dist
    
    # 测试的flag
    smoke_test = ('CI' in os.environ)
    pyro.enable_validation(True)
    pyro.set_rng_seed(1)
    
    #设置matplotlib,如果不使用IPython Notebook,请忽略
    %matplotlib inline
    plt.style.use('default')
    

    数据集

    我们希望研究一个国家的地形特征和其人均GDP之间的关系。根据一些文献的结论,地形崎岖不平将对非洲以外国家的经济状态产生负面影响,但对非洲国家的影响却是正面的。我们希望通过数据来探究这种关系。数据中三个变量是我们关心的:rugged表示某国地形的崎岖程度;cont_africa表示某国是否在非洲;rgdppc_2000表示在2000年某国的人均GDP。
    由于人均GDP差别较大,我们先将其取对数。

    DATA_URL = 'https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv'
    data = pd.read_csv(DATA_URL, encoding='ISO-8859-1')
    df = data[['rugged', 'cont_africa', 'rgdppc_2000']]
    df = df[np.isfinite(df.rgdppc_2000)]
    df['rgdppc_2000'] = np.log(df['rgdppc_2000'])
    
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
    african_nations = df[df['cont_africa'] == 1]
    non_african_nations = df[df['cont_africa'] == 0]
    sns.scatterplot(non_african_nations['rugged'], 
                    non_african_nations['rgdppc_2000'], 
                    ax=ax[0])
    ax[0].set(xlabel='Terrain Ruggedness Index',
              ylabel='log GDP (2000)'
              title='Non African Nations')
    sns.scatterplot(african_nations['rugged'],
                    african_nations['rgdppc_2000'],
                    ax=ax[1])
    ax[1].set(xlabel='Terrain Ruggedness Index',
              ylabel='log GDP (2000)',
              title='African Nations')
    

    线性回归

    我们构建模型的类叫做PyroModule[nn.Linear],它是PyroModule的一个子类,也是torch.nn.Linear的子类。PyroModule和Pytorch的类nn.Module是很类似的,只是能够支持Pyro primitives操作(包括paramsampleplate等),也支持通过effect handlers控制。这里有一些泛泛的注意事项:

    • 在Pytorch的模型中的可学习的参数,是nn.Parameter的实例。线性回归中,weightbiasnn.Linear类的参数的实例。PyroModule中定义的参数, 被存储在Pyro的参数容器中。尽管PyroModule并不要求限制这些参数的取值,要限制它们也可以通过PyroParam语句来轻松实现。
    • 尽管PyroModule[nn.Linear]forward方法继承于nn.Linear,这一方法可以很容易地改变nn.Linear的设定。例如在逻辑斯蒂回归中,我们将使用sigmoid函数来转换线性预测器的输出。
    from torch import nn
    from pyro.nn import PyroModule
    
    assert issubclass(PyroModule[nn.Linear], nn.Linear)
    assert issubclass(PyroModule[nn.Linear], PyroModule)
    

    使用Pytorch的优化器训练

    我们使用均方差(MSE)损失作为优化的目标函数,优化器为torch.optim模块。优化的过程将使weightbias向最优值拉近。

    # 增加一个变量来捕捉'cont_africa'和'rugged'的关系
    df['cont_africa_x_rugged'] = df['cont_africa'] * df['rugged']
    data = torch.tensor(df[['cont_africa', 'rugged', 'cont_africa_x_rugged', 'rgdppc_2000']].values, dtype=torch.float)
    x_data, y_data = data[:, :-1], data[:, -1]
    
    # 回归模型
    linear_reg_model = PyroModule[nn.Linear](3, 1)
    
    # 定义损失和优化器
    loss_fn = torch.nn.MSELoss(reduction='sum')
    optim = torch.optim.Adam(linear_reg_model.parameters(), lr=0.05)
    num_iterations = 1500 if not smoke_test else 2
    
    def train():
        # data前向传播
        y_pred = linear_reg_model(x_data).squeeze(-1)
        # 均方差损失
        loss = loss_fn(y_pred, y_data)
        # 将梯度置为零
        optim.zero_grad()
        # 梯度反传
        loss.backward()
        # 梯度更新一步
        optim.step()
        return loss
    
    for j in range(num_iterations):
        loss = train()
        if (j + 1) % 50 == 0:
            print('[iteration %04d] loss: %.4f' %(j + 1, loss.item()))
    
    # 监听参数的变化
    print('Learned parameters:')
    for name, param in linear_reg_model.named_parameters():
        print(name, param.data.numpy())
    

    [iteration 0050] loss: 3509.0015
    [iteration 0100] loss: 1952.4127
    [iteration 0150] loss: 1341.5183
    [iteration 0200] loss: 989.3399
    [iteration 0250] loss: 742.0116
    [iteration 0300] loss: 556.1625
    [iteration 0350] loss: 418.7539
    [iteration 0400] loss: 321.2036
    [iteration 0450] loss: 254.8719
    [iteration 0500] loss: 211.6130
    [iteration 0550] loss: 184.5197
    [iteration 0600] loss: 168.2088
    [iteration 0650] loss: 158.7643
    [iteration 0700] loss: 153.5028
    [iteration 0750] loss: 150.6820
    [iteration 0800] loss: 149.2266
    [iteration 0850] loss: 148.5042
    [iteration 0900] loss: 148.1591
    [iteration 0950] loss: 148.0007
    [iteration 1000] loss: 147.9307
    [iteration 1050] loss: 147.9010
    [iteration 1100] loss: 147.8890
    [iteration 1150] loss: 147.8842
    [iteration 1200] loss: 147.8825
    [iteration 1250] loss: 147.8819
    [iteration 1300] loss: 147.8816
    [iteration 1350] loss: 147.8816
    [iteration 1400] loss: 147.8815
    [iteration 1450] loss: 147.8815
    [iteration 1500] loss: 147.8815
    Learned parameters:
    weight [[-1.9478928 -0.20279995 0.39331874]]
    bias [9.223109]

    绘制回归结果

    下面我们分别对非洲国家和非洲以外国家的情况绘制回归的拟合结果。

    fit = df.copy()
    fit['mean'] = linear_reg_model(x_data).detach().cpu().numpy()
    fig, ax = plt.subplot(nrows=1, ncols=2, figsize=(12, 6), shary=True)
    african_nations = fit[fit['cont_africa'] == 1]
    non_african_nations = fit[fit['cont_africa'] == 0]
    fig.subtitle('Regression Fit', fontsize=16)
    ax[0].plot(non_african_nations['rugged'], non_african_nations['rgdppc_2000'], 'o')
    ax[0].plot(non_african_nations['rugged'], non_african_nations['mean'], linewidth=2)
    ax[0].set(xlabel="Terrain Ruggedness Index",
              ylabel="log GDP (2000)",
              title="Non African Nations")
    ax[1].plot(african_nations['rugged'], african_nations['rgdppc_2000'], 'o')
    ax[1].plot(african_nations['rugged'], african_nations['mean'], linewidth=2)
    ax[1].set(xlabel="Terrain Ruggedness Index",
              ylabel="log GDP (2000)",
              title="African Nations")
    

    我们注意到,在非洲以外的国家,国土崎岖程度越大,国民的人均生产总值越低;而在非洲国家,情况刚好反过来。我们希望了解这一结论的可靠性,即当参数在不确定区域变化时,这一结论是否会翻转。基于这个原因,我们搭建另外一种模型,贝叶斯线性回归模型,贝叶斯建模能够有效地给出模型的不确定性。与刚才模型的点估计不同,贝叶斯模型将给出基于观察的参数的分布

    Pyro的随机变分分布和贝叶斯回归

    Model

    为了让线性回归问题纳入贝叶斯体系中,我们须给wb以先验。也就是说,在未观察到任何数据前,我们对权重和偏置做合理的猜测。
    我们将使用PyroModule[nn.Linear]类组建模型。注意下面几点:

    • BaysianRegression模块在内部和PyroModule[nn.Linear]模块的方法是一样的。不过我们不直接使用weightbias,而是PyroSample模块。这样我们可以给weightbias指定先验,而非认定它们是固定的学习参数。对于偏置,我们给先验设置较大的范围。
    • BayesianRegression.forward方法进行产生过程。我们产生linear模块的激励产生均值(也就是采样先验得到weightbias并返回平均激励),利用obs参数指定pyro.sample的观察变量y_data,其学习的参数是观察变量的方差sigma。模型返回贝叶斯回归的结果mean
    from pyro.nn import PyroSample
    
    class BayesianRegression(PyroModule):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear = PyroModule[nn.Linear](in_features, out_features)
            self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
            self.linear.bias = PyroSample(dist.Normal(0., 1.).expand([out_features]).to_event(1))
        #
        def forward(self, x, y=None):
            sigma = pyro.sample('sigma', dist.Uniform(0., 10.))
            mean = self.linear(x).squeeze(-1)
            with pyro.plate('data', x.shape[0]):
                obs = pyro.sample('obs', dist.Normal(mean, sigma), obs=y)
            return mean
    

    使用自动生成的guide函数AutoGuide

    为了实现变分推断,我们需要定义后验概率的近似估计函数guide。guide规定了函数的大致范围,SVI将在其中选取使KL散度最低的的那一组参数作为后验估计的结果。
    在Pyro中,用户可以根据需求编写自己的guide,而现在我们使用Pyro的自动guide生成器autoguide library。在下一份教程中,我们再学习手写guide函数。
    我们可以选择AutoDiagonalNormal作为guide函数,该函数规定,非观察随机变量服从多元高斯分布,协方差矩阵是对角阵,即认为不同变量间不存在互相依赖关系。(这是一个很强的假设。)在这一情景中,我们定义guide对可学习的参数在Normal正态分布中采样sample。例如我们的采样大小是(5,),对应于3个回归系数、1个‘阻截项’、1个sigma项。
    自动生成的guide支持通过AutoDelta学习最大后验估计(MAP),或通过AutoGuideList组织不同的guide。

    from pyro.infer.autoguide import AutoDiagonalNormal
    
    model = BayesianRegression(3, 1)
    guide = AutoDiagonalNormal(model)
    

    最优化ELBO

    我们使用变分推断(SVI)来优化模型。正如非贝叶斯框架的线性回归模型那样,这里我们在每个训练的步骤中计算梯度并回传,只是目标函数从均方误差(MSE)换成证据下限(ELBO)。损失的形式为SVI类中的Trace_ELBO函数。

    from pyro import optim
    from pyro.infer import SVI, Trace_ELBO
    
    adam = optim.Adam({'lr':0.03})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())
    

    我们使用的优化器是Adam,出于Pyro的优化库,而不是上面的torch.optimAdam封装了torch.optim.Adam
    优化器pyro.optim优化Pyro定义的参数,注意我们并不需要明确指定哪些参数需要优化,guide函数会自动识别可学习的部分,由SVI类自动完成。只要调用SVI.step,参数就被优化一次。step的参数被自动地输入model()guide()。整体的优化过程如下:

    pyro.clear_param_store()
    for j in range(num_iterations):
        # 计算损失,梯度反传并更新
        loss = svi.step(x_data, y_data)
        if j % 100 == 0:
            print('[iteration %04d] loss: %.4f' % (j + 1, loss / len(data)))
    

    结果为:

    [iteration 0000] loss: 12.2027
    [iteration 0100] loss: 5.2126
    [iteration 0200] loss: 4.7754
    [iteration 0300] loss: 4.5059
    [iteration 0400] loss: 4.0690
    [iteration 0500] loss: 3.5674
    [iteration 0600] loss: 2.9755
    [iteration 0700] loss: 2.8650
    [iteration 0800] loss: 2.9133
    [iteration 0900] loss: 2.8924
    [iteration 1000] loss: 2.8472
    [iteration 1100] loss: 2.8776
    [iteration 1200] loss: 2.8298
    [iteration 1300] loss: 2.8727
    [iteration 1400] loss: 2.8828
    

    通过查看Pyro的参数容器,我们检查优化后的参数:

    guide.requires_grad_(False)
    for name, value in pyro.get_param_store().items():
        print(name, pyro.param(name))
    

    结果:

    AutoDiagonalNormal.loc Parameter containing:
    tensor([-2.2410, -1.6331, -0.1248,  0.2918,  9.0075])
    AutoDiagonalNormal.scale tensor([0.0617, 0.1391, 0.0426, 0.0822, 0.0795])
    

    正如上面的结果,除了点估计的结果外,我们得到了不确定性的数值估计AutoDiagonalNormal.scale。我们注意到Autograd将不同的变量打包放在一个张量中,loc和scale都具有长度为(5,)的参数,正如我们前面提到的那样。
    为了将隐变量的参数看的更清楚,我们可以使用AutoDiagonalNormal.quantiles方法来解析隐变量的情况,并将它们限制在一定范围内(如sigma必须在(0,10)的范围内)。我们看下面的结果,其中位数(即quantile==0.5)很接近我们在最大似然法中求得的值。

    guide.quantiles([0.25, 0.5, 0.75])
    
    {'sigma':
    [tensor(0.9257), tensor(0.9613), tensor(0.9981)], 
    'linear.weight': 
    [tensor([[-1.7269, -0.1536,  0.2363]]), 
    tensor([[-1.6331, -0.1248,  0.2918]]), 
    tensor([[-1.5392, -0.0961,  0.3473]])], 
    'linear.bias': 
    [tensor([8.9539]), tensor([9.0075]), tensor([9.0612])]}
    

    模型评估

    我们从产生数据的模型采样预测数据,检验后验分布是否准确。这里我们用到Predictive
    工具类。步骤如下:

    • 从训好的模型中采样800个样本点。具体来说,从guide中采样非观察数据,而后将数据输入model,前向传播,我们就得到了样本点。(具体用法参考Model Serving部分的Predictive类)
    • return_sites中我们制定输出obs,也指定model的返回值"_RETURN"来获得回归的直线。除此之外,我们还返回linear.weight回归系数来做进一步的分析。
    • 剩下的代码用来绘制90%信心的图像。
    from pyro.infer import Predictive
    
    def summary(samples):
        site_stats = {}
        for k, v in samples.items():
            site_stats[k] = {
                'mean': torch.mean(v, 0),
                'std':torch.std(v, 0),
                '5%':v.kthvalue(int(len(v) * 0.05), dim=0)[0],
                '95%':v.kthvalue(int(len(v) * 0.95), dim=0)[0],
             }
        return site_stats
    
    predictive = Predictive(model, guide=guide, num_samples=800,
                            return_sites=('linear.weight', 'obs', '_RETURN'))
    samples = predictive(x_data)
    pred_summary = summary(samples)
    
    mu = pred_summary['_RETURN']
    y = pred_summary['obs']
    predictions = pd.DataFrame({
        'cont_africa': x_data[:, 0],
        'rugged': x_data[:, 1],
        'mu_mean': mu['mean'],
        'mu_prec_5': mu['5%'],
        'mu_prec_95': mu['95%'],
        'y_mean': y['mean'],
        'y_prec_5': y['5%'],
        'y_prec_95': y['95%'],
        'true_gdp': y_data,
    })
    
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
    african_nations = predictions[predictions['cont_africa'] == 1]
    non_african_nations = predictions[predictions['cont_africa'] == 0]
    

    上图展示了回归曲线,和90%信心的可能区域。然而,大多数点偏离了这一区域,这是由于我们忽略了sigma的影响,仅仅考虑了均值。下面的代码来绘制全部musigma的影响。
    fig, ax = plt.subplot(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
    fig.suptitle('Posterior predictive distribution with 90% CI', fontsize=6)
    ax[0].plot(non_african_nations['rugged'], non_african_nations['y_mean'])
    ax[0].fill_between(non_african_nations['rugged'], non_african_nations['y_perc_5'],
                       non_african_nations['y_prec_95'])
    ax[0].plot(non_african_nations["rugged"],
               non_african_nations["true_gdp"],
               "o")
    ax[0].set(xlabel="Terrain Ruggedness Index",
              ylabel="log GDP (2000)",
              title="Non African Nations")
    idx = np.argsort(african_nations["rugged"])
    
    ax[1].plot(african_nations["rugged"],
               african_nations["y_mean"])
    ax[1].fill_between(african_nations["rugged"],
                       african_nations["y_perc_5"],
                       african_nations["y_perc_95"],
                       alpha=0.5)
    ax[1].plot(african_nations["rugged"],
               african_nations["true_gdp"],
               "o")
    ax[1].set(xlabel="Terrain Ruggedness Index",
              ylabel="log GDP (2000)",
              title="African Nations");
    

    我们观察到大多数点都在范围内。观察后验预测的情况,可以检查我们的模型是否准确。
    最后,我们给出预测的可靠性。由下图可知,对于非洲国家而言,国土崎岖程度的峰值为正,对非洲以外国家则是负。这进一步核实了最初的猜测。

    weight = samples['linear.weight']
    weight = weight.reshape(weight.shape[0], 3)
    # weight 的前两维表示 'cont_africa', 'rugged'
    gamma_with_africa = weight[:, 1] + weight[:, 2]
    gamma_outside_africa = weight[:, 1]
    #
    fig = plt.figure(figsize=(12, 6))
    sns.distplot(gamma_with_africa, kde_kws={'label': 'African nations'},)
    sns.distplot(gamma_with_africa, kde_kws={'label': 'Non_African nations'})
    fig.suptitle('Density of Slope : log(GDP) vs. Terrain Ruggedness');
    

    通过TorchScript建立模型

    上面的guidemodelPredictivetorch.nn.Module的实例,它们可以被TorchScript串行。
    我们重写Predictive类,使用

    • trace来捕捉代码执行的“轨迹”;
    • replay在采样guide时实现条件概率。
    from collections import defaultdict
    from pyro import poutine
    from pyro.poutine.util import prune_subsample_sites
    import warning
    
    class Predict(torch.nn.Module):
        def __init__(self, model, guide):
            super().__init__()
            self.model = model
            self.guide = guide
        #
        def forward(self, *args, **kwargs):
            samples = {}
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace(*args, **kwargs)
            for site in prune_subsample_sites(model_trace).stochastic_nodes:
                samples[site] = model_trace.nodes[site]['value']
            return tuple(v for _, v in sorted(samples.items()))
    
    predict_fn = Predict(model, guide)
    predict_module = torch.jit.trace_module(predict_fn, {'forward': (x_data,)}, check_trace=False)
    print(predict_module(x_data))
    

    结果

    (tensor([9.0981]),
    tensor([[-1.7491, -0.0743,  0.3191]]), 
    tensor([ 6.5791, 11.1204,  9.2202,  9.6960,  8.8708,  8.9990,  9.1509,  8.5572,
             8.1153,  7.8486,  8.1661,  6.4192,  8.9866,  8.2883,  8.0250,  8.0258,
             8.7247, 10.0411,  9.6446,  9.8050,  7.8771, 10.9617,  9.8290,  7.6839,
             9.3751,  8.9871, 10.6739,  7.9197,  7.9699,  7.7173,  6.7463,  7.0445,
             7.5290,  8.7151,  8.3060,  6.4522,  8.6809,  8.9207,  7.9165,  8.3513,
             8.2212,  8.0382,  9.4009,  9.5989,  7.9512,  6.3474,  8.0598,  6.6843,
             9.7230,  8.1434,  9.5843,  8.6182,  9.0642,  7.8121,  9.0719,  9.3829,
             9.0696,  9.0074,  8.6261,  8.5372,  7.2110,  8.6863,  9.7784,  9.6941,
             7.8893,  9.7475, 10.0347,  8.3116,  8.5901,  9.4015,  7.7256,  8.2334,
             9.9464, 10.2544,  7.1801,  8.4415,  9.0320,  9.0331,  8.1440,  8.1207,
             9.3618,  8.4149,  6.2104,  7.8214,  6.7098,  8.5715,  9.0018,  9.4816,
             8.0279,  8.6050,  9.1565,  9.0454,  9.6953,  9.3202,  7.4633,  6.8954,
             7.8622,  7.3698,  9.5622,  8.0203,  9.4199,  9.3618,  8.0169,  9.9572,
             8.0362,  8.1701,  6.6665,  6.1006,  7.5254,  8.2853,  7.9838, 10.2436,
             6.7060,  8.8127,  8.6282,  8.7436, 10.3548,  7.9303, 10.2233,  9.2562,
             7.2063,  9.4613, 10.0881,  9.6515, 10.4115, 10.2262,  7.5288,  8.5231,
             9.3029,  7.3952,  8.5516,  9.9590,  8.6210,  9.0980,  7.1726,  6.5949,
             9.7424,  8.3487,  7.1420, 11.2635,  8.5057,  8.6983,  8.5284,  7.5776,
             9.5280, 10.0641,  6.7022,  6.1189,  7.6534, 10.3216,  9.7166, 10.7897,
             8.3233,  8.1185,  8.2430,  7.1593,  6.6486,  9.8121,  9.0420,  7.3089,
             8.3944,  9.9636,  9.3352,  8.8867,  7.8940,  9.0260,  9.5576,  8.3397,
             8.4302,  6.4737]),
    tensor(0.9356))
    

    使用torch.jit.save可以保存模型,Pytorch的C++API可以重载模型,也可以使用Python的API重载。

    torch.jit.save(predict_module, 'tmp/reg_predict.pt')
    pred_load = torch.jit.load('tmp/reg_predict.pt')
    pred_load(x_data)
    

    我们检查Predict模型是否串行成功,产生出一定量的采样并重复前面的绘图。

    weight = []
    for _ in range(800):
        # index = 1 对应于linear.weight
        weight.append(predict_module(x_data)[1])
    
    weight = torch.stack(weight).detach()
    weight = weight.reshape(weight.shape[0], 3)
    gamma_within_africa = weight[:, 1] + weight[:, 2]
    gamma_outside_africa = weight[:, 1]
    fig = plt.figure(figsize = (12, 6))
    sns.distplot(gamma_within_africa, kde_kws={'label': 'African nations'},)
    sns.distplot(gamma_outside_africa, kde_kws={'label': 'Non_African nations'})
    fig.suptitle('TorchScript Module : log(GDP) vs. Terrain Ruggedness');
    

    相关文章

      网友评论

        本文标题:Pyro进阶(二):贝叶斯回归问题 一

        本文链接:https://www.haomeiwen.com/subject/edovwctx.html