
Handling variable-length sequences in PyTorch

Author: 全村希望gone | Published 2019-12-30 22:55
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torch.nn import utils as nn_utils
    
    batch_size = 3
    max_length = 3
    hidden_size = 5
    n_layers = 1
    
    tensor_in = torch.FloatTensor([[1, 2, 3], [4, 5, 0], [1, 0, 0]]).resize(3, 3, 1)
    # note: non-inplace .resize() is deprecated (see the warning in the output below); .view(3, 3, 1) reshapes without it
    tensor_in = Variable(tensor_in)  # [batch, seq, feature] = [3, 3, 1]
    seq_lengths = [3, 2, 1]  # true (unpadded) length of each sequence, sorted in descending order as pack_padded_sequence expects
    
    # pack it
    # pack_padded_sequence drops the padded positions and flattens the batch into a PackedSequence
    pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
    print('000', pack)
    # initialize
    rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
    h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
    
    # forward
    # run the RNN over the packed input
    out, _ = rnn(pack, h0)
    print('out', out)
    # unpack
    # pad_packed_sequence restores the padded layout ([seq, batch, hidden] by default); padded positions come back as zero vectors
    unpacked = nn_utils.rnn.pad_packed_sequence(out)
    print('111', unpacked)
    

Output

    E:\python\python.exe G:/zlx/github-projects/ChineseNER-master-v20/testProgram.py
    E:\python\lib\site-packages\torch\tensor.py:287: UserWarning: non-inplace resize is deprecated
      warnings.warn("non-inplace resize is deprecated")
    000 PackedSequence(data=tensor([[1.],
            [4.],
            [1.],
            [2.],
            [5.],
            [3.]]), batch_sizes=tensor([3, 2, 1]))
    out PackedSequence(data=tensor([[ 0.7497,  0.3511, -0.9212, -0.1683,  0.4390],
            [ 0.0385, -0.1618, -0.9935,  0.5341, -0.5963],
            [ 0.7259,  0.3769, -0.9672, -0.4545, -0.0359],
            [-0.0717,  0.6466, -0.9100,  0.7203, -0.4931],
            [-0.3446,  0.2835, -0.9960,  0.6210, -0.8128],
            [-0.4544,  0.6179, -0.9786,  0.5303, -0.8303]], grad_fn=<CatBackward>), batch_sizes=tensor([3, 2, 1]))
    111 (tensor([[[ 0.7497,  0.3511, -0.9212, -0.1683,  0.4390],
             [ 0.0385, -0.1618, -0.9935,  0.5341, -0.5963],
             [ 0.7259,  0.3769, -0.9672, -0.4545, -0.0359]],
    
            [[-0.0717,  0.6466, -0.9100,  0.7203, -0.4931],
             [-0.3446,  0.2835, -0.9960,  0.6210, -0.8128],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
    
            [[-0.4544,  0.6179, -0.9786,  0.5303, -0.8303],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<CopySlices>), tensor([3, 2, 1]))
    
    Process finished with exit code 0
    

Summary: the RNN's input size has to equal the last dimension of the batch tensor. The RNN then turns every non-padded value (every element that is not padding) into a hidden vector of size hidden_size. Feeding that packed output to pad_packed_sequence gives the hidden vectors back, but they are no longer arranged one sequence per row; instead, the rows that correspond to the same time step across the different sequences are grouped together (pad_packed_sequence defaults to batch_first=False, so its first dimension is the time step rather than the sample). Why is it done this way, and is there any relationship between the data of different sequences? (to be resolved)
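The row layout above is largely a matter of the batch_first flag. As a minimal sketch (assuming it is appended to the script above, so `out` is still the PackedSequence returned by the RNN), unpacking with batch_first=True lines the rows back up with tensor_in:

    # Sketch (appended to the script above): unpack with batch_first=True so the
    # first dimension is the sample again instead of the time step.
    unpacked_bf, lens = nn_utils.rnn.pad_packed_sequence(out, batch_first=True)
    print(unpacked_bf.shape)  # torch.Size([3, 3, 5]) -> [batch, seq, hidden]
    print(lens)               # tensor([3, 2, 1]), the original sequence lengths
    # Row i now holds the hidden states of sequence i; the padded steps are zero vectors.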


----------------------------------- Addendum: important! ------------------------------------
The question above leads to another one: how does an LSTM actually process a batch of data? Once I understood this, I realized I had either never really thought about the process before or had it wrong. That is probably today's biggest takeaway.
Let's go straight to an example (taken from Zhihu): how does an LSTM process the data in the figure below?

[Figure: a batch of seven five-character poem lines such as 床前明月光 and 白毛浮绿水, one line per row]

Is the batch processed one whole line at a time, feeding in 床前明月光, then 白毛浮绿水, and so on, or column by column, like 床白黄松随红锄 (the first character of every line at once)? The answer is the second. If it were the first, it would be no different from not batching at all, and the batch would be pointless. The second also explains the question above, namely why the three values that originally belonged to the first sequence end up in different rows after unpacking. More precisely, the values that belonged to different sequences at the same time step end up in one row, and that is exactly how an LSTM processes a batch. Every sequence in the batch has the same number of time steps (if they differ, they are padded to the same length before going into the LSTM), so at time t the LSTM processes the step-t data of every sequence, and at the next step the step-(t+1) data. In the figure, at t=1 it processes 床白黄松随红锄 and at t=2 it processes 前毛河下风豆禾.
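To make this step-by-step mechanism concrete, here is a minimal hand-unrolled sketch (not from the original post; the sizes match the second example below: 7 sequences, 5 steps, 3 input features, hidden size 8). It feeds one time step of the whole batch per iteration, which mirrors what nn.RNN / nn.LSTM compute internally:

    import torch
    import torch.nn as nn

    batch_size, seq_len, input_size, hidden_size = 7, 5, 3, 8
    x = torch.randn(batch_size, seq_len, input_size)  # [batch, seq, feature]

    cell = nn.RNNCell(input_size, hidden_size)
    h = torch.zeros(batch_size, hidden_size)          # h_0 for every sequence

    outputs = []
    for t in range(seq_len):
        # x[:, t, :] is "one character from every line": the step-t inputs of
        # all 7 sequences go through the same cell (same weights) together.
        h = cell(x[:, t, :], h)
        outputs.append(h)

    out_manual = torch.stack(outputs, dim=1)          # [batch, seq, hidden]
    print(out_manual.shape)                           # torch.Size([7, 5, 8])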
In numbers: batch_size=7, seq_len=5, embed_size=3, hidden_size=8. Let's look at the code and its output below.

Code

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torch.nn import utils as nn_utils
    
    batch_size = 7
    # max_length = 3
    hidden_size = 8
    n_layers = 1
    
    tensor_in = torch.FloatTensor(
        [[[1, 1, 1], [1, 2, 0], [1, 0, 0], [5, 0, 0], [6, 0, 0]],
         [[3, 2, 5], [1, 2, 0], [4, 0, 0], [5, 0, 0], [6, 0, 0]],
         [[7, 6, 8], [3, 4, 0], [1, 0, 0], [9, 0, 0], [1, 0, 0]],
         [[1, 1, 1], [1, 2, 0], [1, 0, 0], [5, 0, 0], [6, 0, 0]],
         [[3, 2, 5], [1, 2, 0], [4, 0, 0], [5, 0, 0], [6, 0, 0]],
         [[7, 6, 8], [3, 4, 0], [1, 0, 0], [9, 0, 0], [1, 0, 0]],
         [[7, 6, 8], [3, 4, 0], [1, 0, 0], [9, 0, 0], [1, 0, 0]]])
    tensor_in = Variable(tensor_in)  # [batch, seq, feature] = [7, 5, 3]
    seq_lengths = [5, 5, 5, 5, 5, 5, 5]  # true length of each of the 7 sequences (all 5 here, so nothing is actually stripped by packing)
    
    # pack it
    # pack_padded_sequence drops the padded positions and flattens the batch into a PackedSequence
    pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
    print('000', pack)
    # initialize
    rnn = nn.RNN(3, hidden_size, n_layers, batch_first=True)
    h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
    
    # forward
    # run the RNN over the packed input
    out, _ = rnn(pack, h0)
    print('out', out)
    # unpack
    # pad_packed_sequence restores the padded layout ([seq, batch, hidden] by default); padded positions come back as zero vectors
    unpacked = nn_utils.rnn.pad_packed_sequence(out)
    print('111', unpacked)
    

Output

    E:\python\python.exe G:/zlx/github-projects/ChineseNER-master-v20/testProgram.py
    000 PackedSequence(data=tensor([[1., 1., 1.],
            [3., 2., 5.],
            [7., 6., 8.],
            [1., 1., 1.],
            [3., 2., 5.],
            [7., 6., 8.],
            [7., 6., 8.],
            [1., 2., 0.],
            [1., 2., 0.],
            [3., 4., 0.],
            [1., 2., 0.],
            [1., 2., 0.],
            [3., 4., 0.],
            [3., 4., 0.],
            [1., 0., 0.],
            [4., 0., 0.],
            [1., 0., 0.],
            [1., 0., 0.],
            [4., 0., 0.],
            [1., 0., 0.],
            [1., 0., 0.],
            [5., 0., 0.],
            [5., 0., 0.],
            [9., 0., 0.],
            [5., 0., 0.],
            [5., 0., 0.],
            [9., 0., 0.],
            [9., 0., 0.],
            [6., 0., 0.],
            [6., 0., 0.],
            [1., 0., 0.],
            [6., 0., 0.],
            [6., 0., 0.],
            [1., 0., 0.],
            [1., 0., 0.]]), batch_sizes=tensor([7, 7, 7, 7, 7]))
    out PackedSequence(data=tensor([[-0.1825,  0.9101, -0.8702, -0.6278,  0.2190,  0.2605, -0.7234, -0.0163],
            [ 0.5106,  0.9338, -0.8618,  0.2582,  0.9622,  0.6992,  0.8183,  0.8412],
            [ 0.9439,  0.7868, -0.9984, -0.9804,  0.9951,  0.8878,  0.9840,  0.9986],
            [ 0.8220,  0.7124,  0.6617,  0.3772, -0.0475, -0.8009,  0.1510, -0.0459],
            [ 0.7401,  0.9499, -0.9404, -0.3977,  0.7578,  0.5945,  0.6236,  0.9589],
            [ 0.8393,  0.9940, -0.9980, -0.9685,  0.9733,  0.8201,  0.8973,  0.9990],
            [ 0.9496,  0.9072, -0.9972, -0.9506,  0.9938,  0.8273,  0.9807,  0.9951],
            [ 0.3781,  0.5988, -0.4116, -0.4798,  0.3336, -0.1607, -0.5875, -0.2394],
            [ 0.6116, -0.2055, -0.4288, -0.8675,  0.8222,  0.4301, -0.4239,  0.2674],
            [ 0.8690, -0.4402, -0.6296, -0.9601,  0.9292,  0.3328,  0.2067,  0.3849],
            [ 0.3107, -0.2493, -0.2772, -0.3696,  0.5967, -0.1408, -0.3301, -0.6034],
            [ 0.5993, -0.1124, -0.3376, -0.7963,  0.8304,  0.3821, -0.3199,  0.0244],
            [ 0.8510, -0.3599, -0.6366, -0.9556,  0.9273,  0.3357,  0.1457,  0.3025],
            [ 0.8583, -0.4251, -0.6412, -0.9593,  0.9302,  0.3588,  0.1705,  0.3394],
            [ 0.3930,  0.7486,  0.0825, -0.1555,  0.1927, -0.4077, -0.3563, -0.6264],
            [ 0.9382,  0.7612,  0.2074, -0.7751,  0.6090, -0.7778,  0.7459, -0.5338],
            [ 0.7345,  0.4647,  0.3572, -0.4922,  0.4445, -0.3764,  0.0423, -0.1379],
            [ 0.5474,  0.6905,  0.1789, -0.2973, -0.0790, -0.5587, -0.3300, -0.2668],
            [ 0.9280,  0.7770,  0.1928, -0.7777,  0.5854, -0.7682,  0.7138, -0.5229],
            [ 0.7093,  0.5077,  0.3272, -0.4869,  0.4249, -0.3667, -0.0038, -0.1663],
            [ 0.7258,  0.4849,  0.3400, -0.4955,  0.4330, -0.3742,  0.0242, -0.1371],
            [ 0.9171,  0.8391,  0.0225, -0.7376,  0.5398, -0.8142,  0.7290, -0.7934],
            [ 0.9258,  0.7765,  0.3642, -0.7576,  0.6806, -0.7491,  0.7744, -0.7603],
            [ 0.9954,  0.8601,  0.2062, -0.9487,  0.8624, -0.9523,  0.9880, -0.8740],
            [ 0.9320,  0.7798,  0.1400, -0.6574,  0.6326, -0.8292,  0.8132, -0.8488],
            [ 0.9255,  0.7795,  0.3575, -0.7542,  0.6814, -0.7493,  0.7747, -0.7654],
            [ 0.9952,  0.8678,  0.1843, -0.9483,  0.8586, -0.9517,  0.9874, -0.8764],
            [ 0.9954,  0.8627,  0.1992, -0.9484,  0.8619, -0.9519,  0.9878, -0.8756],
            [ 0.9561,  0.8275,  0.2489, -0.8477,  0.6742, -0.7943,  0.8440, -0.7589],
            [ 0.9541,  0.8267,  0.3492, -0.8365,  0.7170, -0.8217,  0.8681, -0.7829],
            [ 0.2107,  0.6971,  0.4662, -0.1486,  0.2935, -0.2319, -0.4283, -0.5989],
            [ 0.9556,  0.8180,  0.2660, -0.8563,  0.6727, -0.7986,  0.8433, -0.7441],
            [ 0.9539,  0.8271,  0.3458, -0.8374,  0.7158, -0.8207,  0.8670, -0.7819],
            [ 0.2107,  0.6980,  0.4604, -0.1533,  0.2908, -0.2252, -0.4337, -0.5959],
            [ 0.2104,  0.6976,  0.4642, -0.1504,  0.2924, -0.2296, -0.4304, -0.5978]],
           grad_fn=<CatBackward>), batch_sizes=tensor([7, 7, 7, 7, 7]))
    111 (tensor([[[-0.1825,  0.9101, -0.8702, -0.6278,  0.2190,  0.2605, -0.7234,
              -0.0163],
             [ 0.5106,  0.9338, -0.8618,  0.2582,  0.9622,  0.6992,  0.8183,
               0.8412],
             [ 0.9439,  0.7868, -0.9984, -0.9804,  0.9951,  0.8878,  0.9840,
               0.9986],
             [ 0.8220,  0.7124,  0.6617,  0.3772, -0.0475, -0.8009,  0.1510,
              -0.0459],
             [ 0.7401,  0.9499, -0.9404, -0.3977,  0.7578,  0.5945,  0.6236,
               0.9589],
             [ 0.8393,  0.9940, -0.9980, -0.9685,  0.9733,  0.8201,  0.8973,
               0.9990],
             [ 0.9496,  0.9072, -0.9972, -0.9506,  0.9938,  0.8273,  0.9807,
               0.9951]],
    
            [[ 0.3781,  0.5988, -0.4116, -0.4798,  0.3336, -0.1607, -0.5875,
              -0.2394],
             [ 0.6116, -0.2055, -0.4288, -0.8675,  0.8222,  0.4301, -0.4239,
               0.2674],
             [ 0.8690, -0.4402, -0.6296, -0.9601,  0.9292,  0.3328,  0.2067,
               0.3849],
             [ 0.3107, -0.2493, -0.2772, -0.3696,  0.5967, -0.1408, -0.3301,
              -0.6034],
             [ 0.5993, -0.1124, -0.3376, -0.7963,  0.8304,  0.3821, -0.3199,
               0.0244],
             [ 0.8510, -0.3599, -0.6366, -0.9556,  0.9273,  0.3357,  0.1457,
               0.3025],
             [ 0.8583, -0.4251, -0.6412, -0.9593,  0.9302,  0.3588,  0.1705,
               0.3394]],
    
            [[ 0.3930,  0.7486,  0.0825, -0.1555,  0.1927, -0.4077, -0.3563,
              -0.6264],
             [ 0.9382,  0.7612,  0.2074, -0.7751,  0.6090, -0.7778,  0.7459,
              -0.5338],
             [ 0.7345,  0.4647,  0.3572, -0.4922,  0.4445, -0.3764,  0.0423,
              -0.1379],
             [ 0.5474,  0.6905,  0.1789, -0.2973, -0.0790, -0.5587, -0.3300,
              -0.2668],
             [ 0.9280,  0.7770,  0.1928, -0.7777,  0.5854, -0.7682,  0.7138,
              -0.5229],
             [ 0.7093,  0.5077,  0.3272, -0.4869,  0.4249, -0.3667, -0.0038,
              -0.1663],
             [ 0.7258,  0.4849,  0.3400, -0.4955,  0.4330, -0.3742,  0.0242,
              -0.1371]],
    
            [[ 0.9171,  0.8391,  0.0225, -0.7376,  0.5398, -0.8142,  0.7290,
              -0.7934],
             [ 0.9258,  0.7765,  0.3642, -0.7576,  0.6806, -0.7491,  0.7744,
              -0.7603],
             [ 0.9954,  0.8601,  0.2062, -0.9487,  0.8624, -0.9523,  0.9880,
              -0.8740],
             [ 0.9320,  0.7798,  0.1400, -0.6574,  0.6326, -0.8292,  0.8132,
              -0.8488],
             [ 0.9255,  0.7795,  0.3575, -0.7542,  0.6814, -0.7493,  0.7747,
              -0.7654],
             [ 0.9952,  0.8678,  0.1843, -0.9483,  0.8586, -0.9517,  0.9874,
              -0.8764],
             [ 0.9954,  0.8627,  0.1992, -0.9484,  0.8619, -0.9519,  0.9878,
              -0.8756]],
    
            [[ 0.9561,  0.8275,  0.2489, -0.8477,  0.6742, -0.7943,  0.8440,
              -0.7589],
             [ 0.9541,  0.8267,  0.3492, -0.8365,  0.7170, -0.8217,  0.8681,
              -0.7829],
             [ 0.2107,  0.6971,  0.4662, -0.1486,  0.2935, -0.2319, -0.4283,
              -0.5989],
             [ 0.9556,  0.8180,  0.2660, -0.8563,  0.6727, -0.7986,  0.8433,
              -0.7441],
             [ 0.9539,  0.8271,  0.3458, -0.8374,  0.7158, -0.8207,  0.8670,
              -0.7819],
             [ 0.2107,  0.6980,  0.4604, -0.1533,  0.2908, -0.2252, -0.4337,
              -0.5959],
             [ 0.2104,  0.6976,  0.4642, -0.1504,  0.2924, -0.2296, -0.4304,
              -0.5978]]], grad_fn=<CopySlices>), tensor([5, 5, 5, 5, 5, 5, 5]))
    
    Process finished with exit code 0
    
The figure below shows the batch as 7 sequences and 5 time steps; the input at each time step is one element from each of the 7 sequences.

[Figure: the batch laid out as 5 time steps, each holding the inputs of all 7 sequences]

pad_packed_sequence returns the RNN's output at every time step; I think the output in Figure 1 is exactly Figure 2. But that raises another question: how does the RNN internally handle the input of a single time step? It goes from input dimension 3 to output dimension 8. I know how this works when batch=1 (a 3-dimensional vector passes through a few fully connected layers and becomes an 8-dimensional vector), and my guess is that batch=7 works the same way as batch=1: the seven 3-dimensional vectors share the same weights, and likewise at every following time step. Once the whole batch has been fed through, the loss is computed and backpropagated.
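That guess can be checked directly. A minimal sketch (not part of the original post; it reuses pack, rnn, h0 and out from the second example above): for nn.RNN with the default tanh nonlinearity each step computes h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh), and the same W_ih / W_hh are applied to all 7 rows of a time step at once.

    # Sketch (appended to the second script): recompute the first packed time
    # step by hand and compare it with the RNN's own output.
    x_t0 = pack.data[:7]   # step-0 inputs of all 7 sequences, shape [7, 3]
    h_prev = h0[0]         # initial hidden states, shape [7, 8]
    manual = torch.tanh(
        x_t0 @ rnn.weight_ih_l0.t() + rnn.bias_ih_l0
        + h_prev @ rnn.weight_hh_l0.t() + rnn.bias_hh_l0
    )
    # True: the seven 3-dim vectors are transformed by one shared weight matrix;
    # only the batch dimension differs from the batch=1 case.
    print(torch.allclose(manual, out.data[:7], atol=1e-6))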


[Figure 1] [Figure 2]
