Pyro简介:产生式模型实现库(六),Pyro的张量尺寸

作者: WilliamY | 来源:发表于2019-12-03 12:15 被阅读0次

    太长不看版

    • 模型在学习或调试过程中,设置pyro.enable_validation(True)
    • 张量的“广播”,维度对齐自右向左:torch.ones(3,4,5) + torch.ones(5)
    • 分布的尺寸 .sample().shape == batch_shape + event_shape
    • 分布的尺寸 .log_prob(x).shape == batch_shape(没有event_shape);
    • 使用expand()从Pyro中采样一批数据,或使用plate机制自动扩展;
    • 使用my_dist.to_event(1)声明维度为依赖(dependent),或说不独立;
    • 使用with pyro.plate('name', size):声明条件独立;
    • 所有维度要么是依赖的,要么是条件独立的;
    • 支持维度最左方的批处理,启动Pyro的并行处理;
      • 使用负号指标,如x.sum(-1),而不是x.sum(2)
      • 使用省略号,如pixel = image[...,i, j]
      • 如果要枚举i,j,使用Vindex,如pixel = Vindex(image)[...,i, j]

    内容列表

    • 概率分布的形状
    • plate声明条件独立
    • 在plate中部分采样
    • 并行地枚举,张量的广播

    文件头如下

    import os
    import torch
    import pyro
    from torch.distributions import constraints
    from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
    from pyro.distributions.util import broadcast_shape
    from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
    import pyro.poutine as poutine
    from pyro.optim import Adam
    
    smoke_test = ('CI' in os.environ)
    pyro.enable_validation(True) #这句话最好加上
    
    # 我们借助这个函数,检查模型是否正确
    def test_model(model, guide, loss):
        pyro.clear_param_store()
        loss.loss(model, guide)
    

    概率分布的尺寸:batch_shapeevent_shape

    Pytorch的张量Tensor只有一个尺寸.shape,但是Distributions有两个尺寸.batch_shape.event_shape,分别表示条件独立的随机变量的大小和不独立的随机变量的大小。这两部分构成了一个样本的尺寸。

    x = d.sample()
    assert x.shape == d.batch_shape + d.event_shape
    

    由于计算对数似然只牵涉不独立的变量,所以.log_prob()方法后,event_shape就被缩并了,只剩下batch_shape

    assert d.log_prob(x) == d.batch_shape
    

    Distributions.sample()方法可以输入一个参数sample_shape,作为独立同分布(iid)的随机变量,所以指定样本大小的采样,具有三个尺寸。

    x2 = d.sample(sample_shape)
    assert x2.shape == sample_shape + batch_shape + event_shape
    

    总结来说

          |      iid     | independent | dependent
    ------+--------------+-------------+------------
    shape = sample_shape + batch_shape + event_shape
    

    由上可推论,单变量随机分布的event_shape为0,因为每次采样值是一个实数,所以没有不独立的维度。像MultivariateNormal多元高斯分布这样的概率分布,具有len(event_shape) == 1,因为每个采样是一个向量,向量内部是彼此依赖的(这里假定方差矩阵不是对角阵)。而InverseWishart逆威沙特分布具有len(event_shape) == 2,等等。

    关于概率分布尺寸的举例

    从单变量随机分布开始。

    d = Bernoulli(0.5)
    assert d.batch_shape == ()
    assert d.event_shape == ()
    x = d.sample()
    # x是一个Pytorch张量,没有batch_shape和event_shape
    assert x.shape == () 
    assert d.log_prob(x).shape == ()
    

    通过传入批参数,概率分布数据可以分成批。

    d = Bernoulli(0.5 * torch.ones(3, 4))
    assert d.batch_shape == (3,4)
    assert d.event_shape == ()
    x = d.sample()
    assert x.shape == (3, 4)
    assert d.log_prob(x).shape == (3, 4)
    

    另一种成批的方法,是通过expand()。不过只在参数的最左侧维度独立时才可使用。

    d = Bernoulli(torch.tensor([.1, .2, .3, .4])).expand([3, 4])
    # 注意expand的参数写在一个列表中
    assert d.batch_shape == (3, 4)
    assert d.event_shape == ()
    x = d.sample()
    assert x.shape == (3, 4)
    assert d.log_prob(x).shape == (3, 4)
    

    多元高斯分布具有非空的event_shape维度。对于这些分布来说,.sample().log_prob()的维度是不同的。

    d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
    assert d.batch_shape == ()
    assert d.event_shape == (3, )
    x = d.sample()
    assert x.shape == (3, ) # == batch_shape + event_shape
    assert d.log_prob(x).shape == () # == batch_shape
    

    改变分布的维度独立性

    使用关键字.to_event(n)改变不独立维度的情况,其中n表示从数第n维度开始,声明为不独立维度。

    d = Bernoulli(0.5 * torch.ones(3, 4)).to_event(1)
    assert d.batch_shape == (3, )
    assert d.event_shape == (4, )
    x = d.sample()
    assert x.shape == (3, 4)
    assert d.log_prob(x).shape == (3, )
    

    用户必须小心地设置.to_event(n)batch_shape缩减到合适的水平上,或者用pyro.plate声明维度的独立性。采样仍旧会保留batch_shape+event_shape的尺寸,然而log_prob(x)只剩下batch_shape

    声明为不独立,通常是安全的做法

    在Pyro中,我们常常会声明维度是不独立的,哪怕它们实际上是独立的。请看这个例子:

    x = pyro.sample('x', dist.Normal(0, 1).expand([10]).to_event(1))
    assert x.shape == (10,)
    

    上面的例子很容易就可以换成MultivariateNormal分布。它将下面的写法简化了:

    with pyro.plate('x_plate', 10):
        x = pyro.sample('x', dist.Normal(0, 1)) #不需要expand,系统自动补全
        assert x.shape == (10,)
    

    实际上,这两份代码存在一点小小的差别。上面的代码中,Pyro默认x之间是不独立的,而下面的x则是条件独立的。声明为不独立通常是安全的,这与图论中的d-separation基于同一个原理:在不同节点之间连一条边,即便节点之间不存在互相依赖关系,随着优化该边的权重将越来越低,并不影响最终结果;而本就存在依赖的节点连了一条边,任优化策略多么高明,都无法弥补这一错误。这种错误常见于平均场假设的模型中。不过,在实际执行时,Pyro的SVI模块在估算Normal分布时,两份代码的梯度估计值是一样的。

    通过plate声明维度为独立

    Pyro的上下文管理器pyro.plate能够声明特定的维度为独立维度。推断算法可以利用这一独立性做一些算法优化,例如构造低方差的梯度估计器,再如求解推断问题不在指数空间而在线性空间采样。下面的例子中,我们将声明同一批次中的数据之间是互相独立的。
    最简单的方法,是不声明独立维度,系统将缺省值-1——即最右边的维度,作为独立维度。

    with pyro.plate('my_plate'):
        # 在该上下文中,维度-1将作为独立维度
    

    虽然效果是一样的,不过我们仍提倡用户写出来,以帮助用户调试代码:

    with pyro.plate('my_plate', len(data)):
        #  在该上下文中,维度-1将作为独立维度
    

    从Pyro 0.2版本开始,plate语句可以嵌套使用。比如声明图像的每个像素都是独立的:

    with pyro.plate('x_axis', 320):
        #  在该上下文中,维度-1将作为独立维度
        with pyro.plate('y_axis', 200):
            #  在该上下文中,维度-2和-1将作为独立维度
    

    我们习惯上总从右向左声明独立维度,所以指标是负的,如-1,-2,等等。
    有时情况会更复杂一些,比如我们希望声明一些噪声依赖x,另一些噪声依赖y,还有一些噪声依赖二者。这时Pyro允许用户声明多重独立,为了清楚地标明独立维度,必须指定dim这一参数,如下面的例子:

    x_axis = pyro.plate('x_axis', dim = -2)
    y_axis = pyro.plate('y_axis', dim = -3)
    with x_axis:
        #  在该上下文中,维度-2将作为独立维度
    with y_axis:
        #  在该上下文中,维度-3将作为独立维度
    with x_axis, y_axis:
        #  在该上下文中,维度-2和-3将作为独立维度
    

    让我们举更多例子,来展示plate的用法。

    def model1():
        a = pyro.sample('a', Normal(0, 1))
        b = pyro.sample('b', Normal(torch.zeros(2), 1).to_event(1))
        with pyro.plate('c_plate', 2):
            c = pyro.sample('c', Normal(torch.zeros(2), 1))
        with pyro.plate('d_plate', 3):
            d = pyro.sample('d', Normal(torch.zeros(3, 4, 5), 1).to_event(2))
        assert a.shape == ()                  # batch_shape == (), event_shape == ()
        assert b.shape == (2,)                # batch_shape == (), event_shape == (2,)
        assert c.shape == (2,)                # batch_shape == (2,), event_shape == ()
        assert d.shape == (3, 4, 5)           # batch_shape == (3), event_shape == (4, 5)
        ##
        x_axis = pyro.plate('x_axis', 3, dim=-2)
        y_axis = pyro.plate('y_axis', 2, dim=-3)
        with x_axis:
            x = pyro.sample('x', Normal(0, 1))
        with y_axis:
            y = pyro.sample('y', Normal(0, 1))
        with x_axis, y_axis:
            xy = pyro.sample('xy', Normal(0, 1))
            z = pyro.sample('z', Normal(0, 1).expand([5]).to_event(1))
        assert x.shape == (3, 1)               # batch_shape == (3, 1), event_shape==()
        assert y.shape == (2, 1, 1)            # batch_shape == (2, 1, 1), event_shape==()
        assert xy.shape == (2, 3, 1)           # batch_shape == (2, 3, 1), event_shape==()
        assert z.shape == (2, 3, 1, 5)         # batch_shape == (2, 3, 1), event_shape==(5,)
    
    test_model(model1, model1, Trace_ELBO())
    

    可视化如下:

    batch dims | event dims
    -----------+-----------
               |        a = sample("a", Normal(0, 1))
               |2       b = sample("b", Normal(zeros(2), 1)
               |                        .to_event(1))
               |        with plate("c", 2):
              2|            c = sample("c", Normal(zeros(2), 1))
               |        with plate("d", 3):
              3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
               |                       .to_event(2))
               |
               |        x_axis = plate("x", 3, dim=-2)
               |        y_axis = plate("y", 2, dim=-3)
               |        with x_axis:
            3 1|            x = sample("x", Normal(0, 1))
               |        with y_axis:
          2 1 1|            y = sample("y", Normal(0, 1))
               |        with x_axis, y_axis:
          2 3 1|            xy = sample("xy", Normal(0, 1))
          2 3 1|5           z = sample("z", Normal(0, 1).expand([5])
               |                       .to_event(1))
    

    为了在调试代码时方便地查看随机变量的形状,Pyro提供了Trace.format_shapes()
    方法,在采样点上打印分布的形状(包含site['fn'].batch_shapesite['fn'].event_shape)、变量的形状(site['value'].shape)、如果计算对数似然概率时log_prob的形状(site['log_prob'].shape)。

    trace = poutine.trace(model1).get_trace()
    trace.compute_log_prob()  #  可选的,这句话可以打印log_prob的形状
    print(trace.format_shapes())
    

    打印结果:

    Trace Shapes:
     Param Sites:
    Sample Sites:
           a dist       |
            value       |
         log_prob       |
           b dist       | 2
            value       | 2
         log_prob       |
     c_plate dist       |
            value     2 |
         log_prob       |
           c dist     2 |
            value     2 |
         log_prob     2 |
     d_plate dist       |
            value     3 |
         log_prob       |
           d dist     3 | 4 5
            value     3 | 4 5
         log_prob     3 |
      x_axis dist       |
            value     3 |
         log_prob       |
      y_axis dist       |
            value     2 |
         log_prob       |
           x dist   3 1 |
            value   3 1 |
         log_prob   3 1 |
           y dist 2 1 1 |
            value 2 1 1 |
         log_prob 2 1 1 |
          xy dist 2 3 1 |
            value 2 3 1 |
         log_prob 2 3 1 |
           z dist 2 3 1 | 5
            value 2 3 1 | 5
         log_prob 2 3 1 |
    

    plate句块中采样部分张量

    plate最重要的功能之一就是部分采样,plate句块中的随机变量都是条件独立的。如果样本量为总样本的一半,那么样本损失的值将被认为是总损失的一半。
    在实现部分时,用户需要通知Pyro采样量和样本总量的值,Pyro就会随机产生一定量的数据指标作为样本。

    data = torch.arange(100.)
    
    def model2():
        mean = pyro.param('mean', torch.zeros(len(data)))
        with pyro.plate('data', len(data), subsample_size=10) as ind: 
            assert len(ind) == 10
            batch = data[ind]
            mean_batch = mean[ind]
            # 在batch中做一些计算
            x = pyro.sample('x', Normal(mean_batch, 1), obs=batch)
            assert x.shape == (10,)
    
    test_model(model2, guide=lambda: None, loss=Trace_ELBO())
    

    广播功能,实现数据的并行枚举

    Pyro 0.2后的版本都支持离散随机变量的并行枚举功能。这一功能可以极大地减少计算变分推断时梯度估计的方差,确保优化的稳定性。
    为了实现枚举,Pyro需要用户指定哪些维度是不独立的,哪些是独立的,只有不独立的维度才允许枚举。自然地,这一指定需要用到plate语句,我们需要声明最大数量的枚举范围,这一关键字为max_plate_nesting,它是SVI类的一个参数(而且通过TraceEnum_ELBO传入)。通常来说,Pyro可以自动地指定枚举范围(只要运行一次modelguide,系统将了解枚举范围),不过在动态变化的模型中,用户需要人工地指定max_plate_nesting的数值。
    为了弄清楚max_plate_nesting的作用机制,我们重新回顾model1(),这一次我们关心三种维度的形状:最左边的枚举维度,中间的批维度,最右边的不独立维度。而max_plate_nesting规定了中间的批维度

          max_plate_nesting = 3
               |<--->|
    enumeration|batch|event
    -----------+-----+-----
               |. . .|      a = sample("a", Normal(0, 1))
               |. . .|2     b = sample("b", Normal(zeros(2), 1)
               |     |                      .to_event(1))
               |     |      with plate("c", 2):
               |. . 2|          c = sample("c", Normal(zeros(2), 1))
               |     |      with plate("d", 3):
               |. . 3|4 5       d = sample("d", Normal(zeros(3,4,5), 1)
               |     |                     .to_event(2))
               |     |
               |     |      x_axis = plate("x", 3, dim=-2)
               |     |      y_axis = plate("y", 2, dim=-3)
               |     |      with x_axis:
               |. 3 1|          x = sample("x", Normal(0, 1))
               |     |      with y_axis:
               |2 1 1|          y = sample("y", Normal(0, 1))
               |     |      with x_axis, y_axis:
               |2 3 1|          xy = sample("xy", Normal(0, 1))
               |2 3 1|5         z = sample("z", Normal(0, 1).expand([5]))
               |     |                     .to_event(1))
    

    上面的例子中,如果我们声明(过度)充裕的max_plate_nesting=4也是可以的,但不能声明例如max_plate_nesting=2,因为2<3,这时系统将会报错。
    我们再举一个例子:

    @config_enumerate
    #该修饰符表示枚举类型,不能省略!!
    def model3():
        p = pyro.param('p', torch.arange(6) / 6.)
        locs = pyro.param('locs', torch.tensor([-1., 1.]))
        # locs in [-1, 1]
        # a in [0, 1, 2, 3, 4, 5]
        a = pyro.sample('a', Categorical(torch.ones(6) / 6.))
        # p[a] in [0, 1/6, 2/6, 3/6, 4/6, 5/6]
        b = pyro.sample('b', Bernoulli(p[a])) # 声明b依赖于a
        # b in [0, 1]
        with pyro.plate('c_plate', 4):
            c = pyro.sample('c',  Bernoulli(0.4))
            # c in [0, 1]
            with pyro.plate('d_plate', 5):
                d = pyro.sample('d', Bernoulli(0.3))
                # d in [0, 1]
                e_loc = locs[d.long()].unsqueeze(-1)
                # e_loc in [-1, 1]
                e_scale = torch.arange(1., 8.)
                # e_scale in [1, 2, ..., 7]
                e = pyro.sample('e', Normal(e_loc, e_scale).to_event(1)) # 依赖于d
        #                            枚举维度|批维度(独立维度)|不独立维度
        assert a.shape == (                6,            1,1            )  # 多类别分布的维度大小为6
        assert b.shape == (              2,1,            1,1            )  # 枚举伯努利分布,非扩增
        assert c.shape == (            2,1,1,            1,1            )  # 伯努利分布,非扩增
        assert d.shape == (          2,1,1,1,            1,1            )  # 伯努利分布,非扩增
        assert e.shape == (          2,1,1,1,            5,4,          7)  # e是采样出来的,依赖于d
        #
        assert e_loc.shape ==   (    2,1,1,1,            1,1,         1,) # 最后的逗号可以省略
        assert e_scale.shape == (                                     7,) # 注意逗号不能省略!!
    
    test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))
    

    我们重新来可视化一下:

         max_plate_nesting = 2
                |<->|
    enumeration batch event
    ------------|---|-----
               6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
             2 1|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
                |   |     with pyro.plate("c_plate", 4):
           2 1 1|1 1|         c = pyro.sample("c", Bernoulli(0.3))
                |   |         with pyro.plate("d_plate", 5):
         2 1 1 1|1 1|             d = pyro.sample("d", Bernoulli(0.4))
         2 1 1 1|1 1|1            e_loc = locs[d.long()].unsqueeze(-1)
                |   |7            e_scale = torch.arange(1., 8.)
         2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
                |   |                             .to_event(1))
    

    我们分析一下这些维度。我们为Pyro指定了枚举的维度max_plate_nesting:Pyro给a赋予枚举维度-3,给b赋予枚举维度-4,给c赋予枚举维度-5,给d赋予枚举维度-6。当用户不指定维度扩展后的数值时,新维度被默认为1,这方便计算。我们还可以观察到,log_prob的形状广播的范围是枚举维度和独立维度,比如trace.nodes['d']['log_prob'].shape == (2,1,1,1,5,4)

    使用Pyro的自带工具Trace.format_shapes():

    trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
    trace.compute_log_prob() # 可选
    print(trace.format_shapes())
    

    结果:

    Trace Shapes:                
     Param Sites:                
                p             6  
             locs             2  
    Sample Sites:                
           a dist             |  
            value       6 1 1 |  
         log_prob       6 1 1 |  
           b dist       6 1 1 |  
            value     2 1 1 1 |  
         log_prob     2 6 1 1 |  
     c_plate dist             |  
            value           4 |  
         log_prob             |  
           c dist           4 |  
            value   2 1 1 1 1 |  
         log_prob   2 1 1 1 4 |  
     d_plate dist             |  
            value           5 |  
         log_prob             |  
           d dist         5 4 |  
            value 2 1 1 1 1 1 |  
         log_prob 2 1 1 1 5 4 |  
           e dist 2 1 1 1 5 4 | 7
            value 2 1 1 1 5 4 | 7
         log_prob 2 1 1 1 5 4 |  
    

    编写并行代码

    在Pyro中,我们需要掌握两个取巧的技术,来实现并行采样:广播椭圆分片。我们通过下面的例子来分别介绍枚举情形和非枚举情形下的用法。

    width = 8
    height = 10
    sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
    enumeration = None # 设为True或False
    
    def fun(observe):
        p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
        p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
        x_axis = pyro.plate('x_axis', width, dim=-2)
        y_axis = pyro.plate('y_axis', height, dim=-1)
        # 在这些样本点上,分布形状取决于Pyro是否枚举
        with x_axis:
            x_active = pyro.sample('x_active', Bernoulli(p_x))
        with y_axis:
            y_active = pyro.sample('y_active', Bernoulli(p_y))
        if enumerated:
            assert x_active.shape == (2, 1, 1) # max_plate_nesting==2
            assert y_active.shape == (2, 1, 1, 1)
        else:
            assert x_active.shape == (width, 1)
            assert y_active.shape == (height, )
        # 第一个trick:广播,broadcast。枚举和非枚举都可使用。
        p = 0.1 + 0.5 * x_active * y_active
        if enumerated:
            assert p.shape == (2, 2, 1, 1)
        else:
            assert p.shape == (width, height)
        dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
        # 第二个trick:椭圆分片。Pyro可以在左方任意增加维度。
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        if enumerated:
            assert dense_pixels.shape == (2, 2, width, height)
        else:
            assert dense_pixels.shape == (width, height)
        #
        with x_axis, y_axis:
            if observe:
                pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)
    
    def model4():
        fun(observe=True)
    
    def guide4():
        fun(observe=False)
    
    # Test: 非枚举
    enumerated = False
    test_model(model4, guide4, Trace_ELBO())
    
    # Test: 枚举。注意目标函数为TraceEnum_ELBO
    enumerated = True
    test_model(model4, config_enumerate(guide4, 'parallel'), TraceEnum_ELBO(max_plate_nesting=2))
    

    在pyro.plate内部实现自动广播

    在以上所有model/plate的实现中,我们都使用了pyro.plate的自动扩增功能,使变量满足pyro.sample规定的形状。这一广播方式等价于.expand()
    我们稍许更改上面的代码作为例子,注意几点区别:

    • 我们仅考虑并行枚举的情况,但对于串行的、非枚举的情况也适用;
    • 我们将采样函数分离出来,model代码使用常规的形式,这样做有利于代码的维护;
    • pyro.plate使用ELBO的num_particles参数,将上下文中最远的内容打包。
    # 规定采样的样本量
    num_particals = 100
    width = 8
    height = 10
    sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
    
    def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
        with x_axis:
            x_active = pyro.sample('x_active', Bernoulli(p_x).expand([num_particals, width, 1]))
        with y_axis:
            y_active = pyro.sample('y_active', Bernoulli(p_y).expand([num_particals, 1, height]))
        return x_active, y_active
    
    def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
        with x_axis:
            x_active = pyro.sample('x_active', Bernoulli(p_x))
        with y_axis:
            y_active = pyro.sample('y_acitve', Bernoulli(p_y))
        return x_active, y_active
    
    def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
        with x_axis:
            x_active = pyro.sample('x_active', Bernoulli(p_x).expand([width, 1]))
        with y_axis:
            y_active = pyro.sample('y_active', Bernoulli(p_y).expand([height]))
        return x_acitve, y_active
    
    def fun(observe, sample_fn):
        p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
        p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
        x_axis = pyro.plate('x_axis', width, dim=-2)
        y_axis = pyro.plate('y_axis', height, dim=-1)
        # 
        with pyro.plate('num_particals', 100, dim=-3):
            x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
            ## 并行枚举指标被扩增在“num_particals”最左边
            assert x_active.shape == (2, 1, 1, 1) 
            assert y_active.shape == (2, 1, 1, 1, 1)
            p = 0.1 + 0.5 * x_active * y_active
            assert p.shape == (2, 2, 1, 1, 1)
            dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
            for x, y in sparse_pixels:
                dense_pixels[..., x, y] = 1
            assert dense_pixels.shape == (2, 2, 1, width, height)
            #
            with x_axis, y_axis:
                if observe:
                    pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)
    
    def test_model_with_sample_fn(sample_fn):
        def model():
            fun(observe=True, sample_fn=sample_fn)
        #
        @config_enumerate
        def guide():
            fun(observe=False, sample_fn=sample_fn)
    
    test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
    test_model_with_sample_fn(sample_pixel_locations_full_broadcasting) 
    test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)
    

    在第一个采样函数中,我们像账房先生那样,仔细规定了Bernoulli分布的的形状。请仔细观察num_particles, widthheight传入sample_pixel_locations函数的方式。这一方式有些笨拙。
    对于第二个采样函数,我们需要注意pyro.plate的参数必须要提供,这样系统才能猜出批维度的形状。
    我们可以看到,对于张量操作,使用pyro.plate实现并行是多么容易!
    pyro.plate还具有将代码模块化的效果。

    相关文章

      网友评论

        本文标题:Pyro简介:产生式模型实现库(六),Pyro的张量尺寸

        本文链接:https://www.haomeiwen.com/subject/ctkpwctx.html