美文网首页
2020-05-14pytorch之stack、cat、tran

2020-05-14pytorch之stack、cat、tran

作者: lzjngu | 来源:发表于2020-05-14 21:49 被阅读0次

    stack
    使用stack是为了保留两个信息: 序列(先后)和 张量矩阵信息。比如在循环神经网络中,网络的输出数据通常是:包含了n个数据大小[batch_size, num_outputs]的list,这个和[n, batch_size, num_outputs]是完全不一样的!!!!不利于计算,需要使用stack进行拼接,保留–[1.时间步]和–[2.张量的矩阵乘积属性]。

    官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
    浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

    形式:
    outputs = torch.stack(inputs, dim=0) → Tensor
    重点

    1. 函数中的输入inputs只允许是list或tuple;且序列内部的张量元素,必须shape相等
      ----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

    2. dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小
      例子:

    x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
    x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
    x3 = torch.tensor([[13,23,33],[23,33,43]],dtype=torch.int)
    x4 = torch.tensor([[14,24,34],[24,34,44]],dtype=torch.int)
    
    inputs = [x1, x2, x3, x4]
    In [19]: torch.stack(inputs, dim=0)
    Out[19]: 
    tensor([[[11, 21, 31],
             [21, 31, 41]],
    
            [[12, 22, 32],
             [22, 32, 42]],
    
            [[13, 23, 33],
             [23, 33, 43]],
    
            [[14, 24, 34],
             [24, 34, 44]]], dtype=torch.int32)
    
    In [21]: torch.stack(inputs, dim=1)
    Out[21]: 
    tensor([[[11, 21, 31],
             [12, 22, 32],
             [13, 23, 33],
             [14, 24, 34]],
    
            [[21, 31, 41],
             [22, 32, 42],
             [23, 33, 43],
             [24, 34, 44]]], dtype=torch.int32)
    
    In [20]: torch.stack(inputs, dim=2)
    Out[20]: 
    tensor([[[11, 12, 13, 14],
             [21, 22, 23, 24],
             [31, 32, 33, 34]],
    
            [[21, 22, 23, 24],
             [31, 32, 33, 34],
             [41, 42, 43, 44]]], dtype=torch.int32)
    
    aa = torch.tensor([[[1,2,3],[4,5,6],[7, 8,9]]])
    bb = torch.tensor([[[11, 21, 31],[41,51,61],[71,81,91]]])
    cc = torch.tensor([[[101,201,301],[401,501,601],[701,801,901]]])
    
    inputs1 = [aa, bb, cc]
    In [29]: torch.stack(inputs1, dim=0)
    Out[29]: 
    tensor([[[[  1,   2,   3],
              [  4,   5,   6],
              [  7,   8,   9]]],
    
            [[[ 11,  21,  31],
              [ 41,  51,  61],
              [ 71,  81,  91]]],
    
            [[[101, 201, 301],
              [401, 501, 601],
              [701, 801, 901]]]])
    
    In [30]: torch.stack(inputs1, dim=1)
    Out[30]: 
    tensor([[[[  1,   2,   3],
              [  4,   5,   6],
              [  7,   8,   9]],
    
             [[ 11,  21,  31],
              [ 41,  51,  61],
              [ 71,  81,  91]],
    
             [[101, 201, 301],
              [401, 501, 601],
              [701, 801, 901]]]])
    
    In [31]: torch.stack(inputs1, dim=2)
    Out[31]: 
    tensor([[[[  1,   2,   3],
              [ 11,  21,  31],
              [101, 201, 301]],
    
             [[  4,   5,   6],
              [ 41,  51,  61],
              [401, 501, 601]],
    
             [[  7,   8,   9],
              [ 71,  81,  91],
              [701, 801, 901]]]])
    
    In [32]: torch.stack(inputs1, dim=3)
    Out[32]: 
    tensor([[[[  1,  11, 101],
              [  2,  21, 201],
              [  3,  31, 301]],
    
             [[  4,  41, 401],
              [  5,  51, 501],
              [  6,  61, 601]],
    
             [[  7,  71, 701],
              [  8,  81, 801],
              [  9,  91, 901]]]])
    
    In [33]: torch.stack(inputs1, dim=-1)
    Out[33]: 
    tensor([[[[  1,  11, 101],
              [  2,  21, 201],
              [  3,  31, 301]],
    
             [[  4,  41, 401],
              [  5,  51, 501],
              [  6,  61, 601]],
    
             [[  7,  71, 701],
              [  8,  81, 801],
              [  9,  91, 901]]]])
    

    Cat
    对数据沿着某一维度进行拼接。cat后数据的总维数不变.

    In [34]: x = torch.randn(2,3)
    In [35]: y = torch.randn(1,3)
    In [37]: print(x, '\n', y)
    tensor([[ 1.8932,  0.8820, -0.3152],
            [ 0.4488,  1.7583, -0.0939]]) 
     tensor([[-1.0298,  0.8602, -0.5422]])
    
    In [38]: torch.cat((x, y), dim=0)
    Out[38]: 
    tensor([[ 1.8932,  0.8820, -0.3152],
            [ 0.4488,  1.7583, -0.0939],
            [-1.0298,  0.8602, -0.5422]])
    

    transpose
    transpose ,交换维度

    In [39]: x = torch.randn(2, 3)
    In [40]: print(x)
    tensor([[ 1.5418,  0.8280, -0.8068],
            [-0.3803, -1.1618,  1.4929]])
    
    In [41]: x.transpose(0, 1)
    Out[41]: 
    tensor([[ 1.5418, -0.3803],
            [ 0.8280, -1.1618],
            [-0.8068,  1.4929]])
    
    In [42]: x.transpose(1, 0)
    Out[42]: 
    tensor([[ 1.5418, -0.3803],
            [ 0.8280, -1.1618],
            [-0.8068,  1.4929]])
    

    permute
    permute,适合多维数据,permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

    In [43]: x = torch.randn(2,3,4)
    In [44]: xp = x.permute(1, 0, 2)
    
    In [45]: print(x)
    tensor([[[-0.4044, -0.4237,  0.2973, -1.5864],
             [ 0.7312, -0.9954, -1.2718,  0.0916],
             [ 0.3418,  1.1162,  0.8982,  0.6203]],
    
            [[ 0.9823, -1.3540,  1.0551,  1.5960],
             [ 1.5930, -0.3035, -0.3781,  1.3462],
             [ 1.1224,  0.6163, -1.3140, -1.1987]]])
    
    In [46]: print(xp)
    tensor([[[-0.4044, -0.4237,  0.2973, -1.5864],
             [ 0.9823, -1.3540,  1.0551,  1.5960]],
    
            [[ 0.7312, -0.9954, -1.2718,  0.0916],
             [ 1.5930, -0.3035, -0.3781,  1.3462]],
    
            [[ 0.3418,  1.1162,  0.8982,  0.6203],
             [ 1.1224,  0.6163, -1.3140, -1.1987]]])
    

    squeeze 和 unsqueeze
    squeeze(dim_n), 压缩,即去掉元素数量为1的dim_n维度。同理unsqueeze(dim_n),增加dim_n维度,元素数量为1。

    # 定义张量
    import torch
    
    b = torch.Tensor(2,1)
    b.shape
    Out[28]: torch.Size([2, 1])
    
    # 不加参数,去掉所有为元素个数为1的维度
    b_ = b.squeeze()
    b_.shape
    Out[30]: torch.Size([2])
    
    # 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
    b_ = b.squeeze(0)
    b_.shape 
    Out[32]: torch.Size([2, 1])
    
    # 这样就可以了
    b_ = b.squeeze(1)
    b_.shape
    Out[34]: torch.Size([2])
    
    # 增加一个维度
    b_ = b.unsqueeze(2)
    b_.shape
    Out[36]: torch.Size([2, 1, 1])
    

    **self.scatter(dim, index, src) **
    从张量src中按照index张量中指定的索引位置写入self张量的值。对于一个三维张量,self更新为:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
    

    为了保证scatter填充的有效性,需要注意:
    (1)self张量在dim方向上的长度不小于source张量,且在其它轴方向的长度与source张量一般相同。这里的一般是指:scatter操作本身有broadcast机制。
    (2)index张量的shape一般与source ,从而定义了每个source元素的填充位置。这里的一般是指broadcast机制下的例外情况。

    import torch
    a = torch.arange(10).reshape(2,5).float()
    b = torch.zeros(3, 5))
    index = torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0]])
    b_= b.scatter(dim=0, index=index,src=a)
    print(b_)
    
    # tensor([[0, 6, 0, 0, 9],
    #        [0, 0, 2, 8, 0],
    #        [5, 1, 7, 0, 4]])
    
    a = torch.arange(10).reshape(2,5).float()
    #tensor([[0., 1., 2., 3., 4.],
    #        [5., 6., 7., 8., 9.]])
    ind = torch.LongTensor([[1, 2, 1, 1, 2]])
    c = b.scatter(0, ind, a)
    #tensor([[0., 0., 0., 0., 0.],
    #        [0., 0., 2., 3., 0.],
    #        [0., 1., 0., 0., 4.]])
    

    scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式,如:

    labels = torch.LongTensor([1,3])
    targets = torch.zeros(2, 5)
    targets.scatter(dim=1, index=labels.unsqueeze(-1), src=torch.tensor(1))
    # 注意dim=1,即逐样本的进行列填充
    # 返回值为 tensor([[0, 1, 0, 0, 0],
    #                 [0, 0, 0, 1, 0]])
    

    gather
    函数torch.gather(input, dim, index, out=None) → Tensor
    沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
    对一个 3 维张量,输出可以定义为:

    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
    

    Parameters:

    • input (Tensor) – 源张量
    • dim (int) – 索引的轴
    • index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
    • out (Tensor, optional) – 目标张量

    使用说明举例:

    1. dim = 1
    import torch
    a = torch.randint(0, 30, (2, 3, 5))
    print(a)
    '''
    tensor([[[ 18.,   5.,   7.,   1.,   1.],
             [  3.,  26.,   9.,   7.,   9.],
             [ 10.,  28.,  22.,  27.,   0.]],
    
            [[ 26.,  10.,  20.,  29.,  18.],
             [  5.,  24.,  26.,  21.,   3.],
             [ 10.,  29.,  10.,   0.,  22.]]])
    '''
    index = torch.LongTensor([[[0,1,2,0,2],
                              [0,0,0,0,0],
                              [1,1,1,1,1]],
                            [[1,2,2,2,2],
                             [0,0,0,0,0],
                             [2,2,2,2,2]]])
    print(a.size()==index.size())
    b = torch.gather(a, 1,index)
    print(b)
    '''
    True
    tensor([[[ 18.,  26.,  22.,   1.,   0.],
             [ 18.,   5.,   7.,   1.,   1.],
             [  3.,  26.,   9.,   7.,   9.]],
    
            [[  5.,  29.,  10.,   0.,  22.],
             [ 26.,  10.,  20.,  29.,  18.],
             [ 10.,  29.,  10.,   0.,  22.]]])
    可以看到沿着dim=1,也就是列的时候。输出tensor第一页内容,
    第一行分别是 按照index指定的,
    input tensor的第一页 
    第一列的下标为0的元素 第二列的下标为1元素 第三列的下标为2的元素,第四列下标为0元素,
    第五列下标为2元素
    index-->0,1,2,0,2    output--> 18.,  26.,  22.,   1.,   0.
    '''
    
    1. dim =2
    c = torch.gather(a, 2,index)
    print(c)
    '''
    tensor([[[ 18.,   5.,   7.,  18.,   7.],
             [  3.,   3.,   3.,   3.,   3.],
             [ 28.,  28.,  28.,  28.,  28.]],
    
            [[ 10.,  20.,  20.,  20.,  20.],
             [  5.,   5.,   5.,   5.,   5.],
             [ 10.,  10.,  10.,  10.,  10.]]])
    dim = 2的时候就安装 行 聚合了。参照上面的举一反三。
    '''
    
    1. dim = 0
    index2 = torch.LongTensor([[[0,1,1,0,1],
                              [0,1,1,1,1],
                              [1,1,1,1,1]],
                            [[1,0,0,0,0],
                             [0,0,0,0,0],
                             [1,1,0,0,0]]])
    d = torch.gather(a, 0,index2)
    print(d)
    '''
    tensor([[[ 18.,  10.,  20.,   1.,  18.],
             [  3.,  24.,  26.,  21.,   3.],
             [ 10.,  29.,  10.,   0.,  22.]],
    
            [[ 26.,   5.,   7.,   1.,   1.],
             [  3.,  26.,   9.,   7.,   9.],
             [ 10.,  29.,  22.,  27.,   0.]]])
    这个有点特殊,dim = 0的时候(三维情况下),是从不同的页收集元素的。
    这里举的例子只有两页。所有index在0,1两个之间选择。
    输出的矩阵元素也是按照index的指定。分别在第一页和第二页之间跳着选的。
    index [0,1,1,0,1]的意思就是。
    在第一页选这个位置的元素,在第二页选这个位置的元素,在第二页选,第一页选,第二页选。
    '''
    

    转载或参考链接:
    https://blog.csdn.net/excellent_sun/article/details/95175823
    https://www.cnblogs.com/yifdu25/p/9399047.html
    https://www.cnblogs.com/dogecheng/p/11938009.html
    https://www.jianshu.com/p/5d1f8cd5fe31

    相关文章

      网友评论

          本文标题:2020-05-14pytorch之stack、cat、tran

          本文链接:https://www.haomeiwen.com/subject/mlmxohtx.html