美文网首页
pytorch的Tensor的操作

pytorch的Tensor的操作

作者: 术枚派 | 来源:发表于2021-10-18 21:18 被阅读0次

    我们整理一下tensor的常见的处理函数。包括拆分(Split)、合并(Cat)、Stack、Chunk

    合并(Cat)

    和TensorFlow的tf.concat类似。
    torch.cat([a , b] , dim),合并tensor a和b,dim指的是从哪个维度。其他维度需要保持一致,如果不一致会出错。

    batch_1 = torch.rand(2,3,28,28)
    batch_2 = torch.rand(5,3,28,28)
    torch.cat([batch_1,batch_2],dim=0).shape
    #torch.Size([7, 3, 28, 28])
    

    stack

    stack 与 concat 不同之处,会增加一个维度用于区分合并的不同 tensor。需要要合并两个 tensor 形状完全一致,而 dim=2 维度前添加一个维度。

    batch_1 = torch.rand(2,3,16,32)
    batch_2 = torch.rand(2,3,16,32)
    torch.stack([batch_1,batch_2],dim=2).shape
    #torch.Size([2, 3, 2, 16, 32])
    
    grp_1 = torch.rand(32,8)
    grp_2 = torch.rand(32,8)
    torch.stack([grp_1,grp_2],dim=0).shape
    # torch.Size([2, 32, 8])
    

    split

    c = torch.rand(3,32,8)
    grp_1,grp_2 = c.split([1,2],dim=0)
    print(grp_1.shape)
    print(grp_2.shape)
    
    #torch.Size([1, 32, 8])
    #torch.Size([2, 32, 8])
    
    c = torch.rand(4,32,8)
    grp_1,grp_2 = c.split([2,dim=0)
    print(grp_1.shape)
    print(grp_2.shape)
    
    #torch.Size([2, 32, 8])
    #torch.Size([2, 32, 8])
    

    chunk

    # chunk 按数量进行拆分
    grp_1,grp_2,grp_3 = c.chunk(3,dim=0)
    print(grp_1.shape)
    print(grp_2.shape)
    print(grp_3.shape)
    
    '''
    torch.Size([1, 32, 8])
    torch.Size([1, 32, 8])
    torch.Size([1, 32, 8])
    '''
    

    view

    类似于numpy中的resize,改变tensor的size。

    import torch
    tt1=torch.tensor([-0.3623,-0.6115,0.7283,0.4699,2.3261,0.1599])
    result=tt1.view(3,2)
    

    输出

    tensor([[-0.3623, -0.6115],
            [ 0.7283,  0.4699],
            [ 2.3261,  0.1599]])
    

    size

    Tensor.szie()可以获取tensor的形状。

    引用

    pytorch 合并和拆分

    相关文章

      网友评论

          本文标题:pytorch的Tensor的操作

          本文链接:https://www.haomeiwen.com/subject/ldznoltx.html