美文网首页Pytorch
torch.squeeze()和torch.unsqueeze(

torch.squeeze()和torch.unsqueeze(

作者: 西北小生_ | 来源:发表于2019-07-23 11:10 被阅读0次
    1. torch.squeeze(tensor)

    和numpy等库函数中的squeeze()函数作用一样,torch.squeeze()函数的作用是压缩一个tensor的维数为1的维度,使该tensor降维变成最紧凑的形式:

    In [1]: import numpy as np                                                      
    
    In [2]: import torch                                                            
    
    In [3]: a = torch.arange(9).view(3,1,3)                                         
    
    In [4]: a                                                                       
    Out[4]: 
    tensor([[[0, 1, 2]],
    
            [[3, 4, 5]],
    
            [[6, 7, 8]]])
    
    In [5]: a.size()                                                                
    Out[5]: torch.Size([3, 1, 3])
    
    In [6]: a.dim()                                                                 
    Out[6]: 3
    
    In [7]: b = torch.squeeze(a)                                                    
    
    In [8]: b                                                                       
    Out[8]: 
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    
    In [9]: b.size()                                                                
    Out[9]: torch.Size([3, 3])
    
    In [10]: b.dim()                                                                
    Out[10]: 2
    

    同样numpy中功能一样:

    In [11]: c = np.arange(9).reshape(1,3,1,3)                                      
    
    In [12]: c                                                                      
    Out[12]: 
    array([[[[0, 1, 2]],
    
            [[3, 4, 5]],
    
            [[6, 7, 8]]]])
    
    In [13]: c.shape, c.ndim                                                        
    Out[13]: ((1, 3, 1, 3), 4)
    
    In [14]: d = np.squeeze(c)                                                      
    
    In [15]: d                                                                      
    Out[15]: 
    array([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
    
    In [16]: d.shape, d.ndim                                                        
    Out[16]: ((3, 3), 2)
    
    2. torch.unsqueeze(tensor, dim)

    unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度,这个功能用view()函数也可以实现。这一功能尤其在神经网络输入单个样本时很有用,由于pytorch神经网络要求的输入都是mini-batch型的,维度为[batch_size, channels, w, h],而一个样本的维度为[c, w, h],此时用unsqueeze()增加一个维度变为[1, c, w, h]就很方便了。

    In [17]: b                                                                      
    Out[17]: 
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    
    In [18]: b.size(), b.dim()                                                      
    Out[18]: (torch.Size([3, 3]), 2)
    
    In [20]: b_un = torch.unsqueeze(b, 0)                                           
    
    In [21]: b_un                                                                   
    Out[21]: 
    tensor([[[0, 1, 2],
             [3, 4, 5],
             [6, 7, 8]]])
    
    In [22]: b_un.size(), b_un.dim()                                                
    Out[22]: (torch.Size([1, 3, 3]), 3)
    
    In [23]: b_un_un = torch.unsqueeze(b_un, 3)                                     
    
    In [24]: b_un_un                                                                
    Out[24]: 
    tensor([[[[0],
              [1],
              [2]],
    
             [[3],
              [4],
              [5]],
    
             [[6],
              [7],
              [8]]]])
    
    In [25]: b_un_un.size(), b_un_un.dim()                                          
    Out[25]: (torch.Size([1, 3, 3, 1]), 4)
    

    相关文章

      网友评论

        本文标题:torch.squeeze()和torch.unsqueeze(

        本文链接:https://www.haomeiwen.com/subject/wbxolctx.html