美文网首页Pytorch
1. Pytorch 计算中的维度

1. Pytorch 计算中的维度

作者: yoyo9999 | 来源:发表于2021-03-22 13:57 被阅读0次

    1. dim

    1. torch.mean()

    Pytorch中维度从前往后依次为0, 1,...;反向为-1, -2,...。
    torch.mean(input, dim, keepdim=False, *, out=None) → Tensor中, 如果dim=-1, 就是用最后一个维度作为特征,来计算均值。 keepdim=True代表不进行维度缩减

    import torch
    
    x = torch.rand((1, 2, 4, 3))
    print(x)
    print(x.shape, x.shape[0], x.shape[1], x.shape[-1])
    
    x_mean = x.mean(dim=-1, keepdim=True)
    print(x_mean)
    
    ###########################################################
    tensor([[[[0.9096, 0.3786, 0.2141],
              [0.7117, 0.7631, 0.9907],
              [0.1136, 0.3949, 0.7348],
              [0.4267, 0.6413, 0.2369]],
    
             [[0.3865, 0.5671, 0.6057],
              [0.9800, 0.7702, 0.8712],
              [0.5558, 0.0272, 0.1465],
              [0.0296, 0.2917, 0.4655]]]])
    torch.Size([1, 2, 4, 3]) 1 2 3
    tensor([[[[0.5008],
              [0.8218],
              [0.4144],
              [0.4350]],
    
             [[0.5198],
              [0.8738],
              [0.2432],
              [0.2623]]]])
    
    

    2. torch.view(), torch.transpose() 和 torch.reshape()

    再来看看 torch.view(), 哪个维度是-1就是推断出来的维度,其值等于总的维度除已经指定的维度。比如下面例子中, \frac{1\times2\times4\times3}{8}=3

    x_view = x.view([-1, 8])
    print(x_view)
    print(x_view.shape)
    
    ###########################################################
    tensor([[0.7176, 0.1339, 0.8430, 0.7403, 0.8586, 0.4462, 0.2916, 0.8481],
            [0.6150, 0.2353, 0.8000, 0.8882, 0.1117, 0.3792, 0.6269, 0.0883],
            [0.8065, 0.9425, 0.5607, 0.0641, 0.7079, 0.5646, 0.5847, 0.0929]])
    torch.Size([3, 8])
    

    torch.transpose(input, dim0, dim1) -> Tensor 是交换轴,比如1, 2就是交换第2个维度到第一个维度。它与view的区别是,会重新拷贝一份,所以内存会改变,不再是同一个tensor。view是在原来的内存上重新view新的形状。所以下面x_view 和 x_trans虽然形状一样,但id不一样,tensor也不相等。

    print(x.shape, id(x))
    x_trans = x.transpose(1, 2)
    print(x_trans)
    print(x_trans.shape)
    x_view = x.view((1, 4, 2, 3))
    print(x_view.shape)
    print(id(x_view))
    print(torch.equal(x_trans, x_view))
    
    
    ########################################
    torch.Size([1, 2, 4, 3]) 1852810812928
    tensor([[[[0.1892, 0.3858, 0.4969],
              [0.3069, 0.2180, 0.3057]],
    
             [[0.7929, 0.8586, 0.7825],
              [0.6899, 0.3481, 0.2375]],
    
             [[0.5666, 0.1618, 0.3767],
              [0.4373, 0.0928, 0.4095]],
    
             [[0.7004, 0.7312, 0.6594],
              [0.6045, 0.8708, 0.0987]]]])
    torch.Size([1, 4, 2, 3])
    torch.Size([1, 4, 2, 3])
    1852852438080
    False
    

    torch.reshape(input, shape) → Tensor跟view的使用一样, reshape是后面引入的,两者区别在于view要求tensor在内存分布上是contiguous。什么是 contiguous看Inference [1]。

    x_reshape = x.reshape((-1, 8))
    print(x_reshape.shape)
    ##############################
    torch.Size([3, 8])
    

    3. tensor.permute()

    tensor.permute()是按照索引来重新,排列tensor。如下面的从 1, 2, 3, 4其索引为 0, 1, 2, 3如果要按照索引 1, 0, 3, 2排,就变成维度为 2, 1, 3, 4了.

    torch.transpose()区别是:

    • 它只能是 tensor.permute(),不能是 torch.permute()
    • 可以交换多个维度, torch.transpose()只能交换两个
    print(x.shape)
    x_per = x.permute((1, 0, 3, 2))
    print(x_per.shape)
    
    ##########################
    torch.Size([1, 2, 4, 3])
    torch.Size([2, 1, 3, 4])
    

    4. torch.squeeze()/ unsqueeze()

    torch.squeeze(input, dim=None, *, out=None) → Tensor 压缩指定维度中为1的维度, 如果不指定就把所有为1的都压缩掉。

    x = x.reshape((1, 2, 1, 3, 4, 1))
    print(x.shape)
    x_squeeze = x.squeeze()
    print(x_squeeze.shape)
    x_squeeze_0 = x.squeeze(0)
    print(x_squeeze_0.shape)
    
    #################################
    torch.Size([1, 2, 1, 3, 4, 1])
    torch.Size([2, 3, 4])
    torch.Size([2, 1, 3, 4, 1])
    
    

    torch.unsqueeze(input, dim) → Tensor 是在指定维度之前插入1的维度,来满足运算需要。x_squeeze索引为1的维度为3, 那么在其前面加上1 dim.

    print(x_squeeze.shape)
    x_uns_1 = x_squeeze.unsqueeze(1)
    print(x_uns_1.shape)
    ###################################
    torch.Size([2, 3, 4])
    torch.Size([2, 1, 3, 4])
    

    Inference

    [1] pytorch-contiguous
    [2] Pytorch之permute函数
    [3] PyTorch的 transpose、permute、view、reshape

    相关文章

      网友评论

        本文标题:1. Pytorch 计算中的维度

        本文链接:https://www.haomeiwen.com/subject/zuuzcltx.html