torch.cat

作者: 三方斜阳 | 来源:发表于2021-02-04 16:44 被阅读0次

    torch.cat 的作用是把两个 tensor 合并为一个 tensor
    第一个参数是需要连接的tensor list , 第二个参数指定按照哪个维度进行拼接

    import torch
    A=torch.zeros(2,5) #2x5的张量(矩阵)                                     
    print(A)
    B=torch.ones(3,5)
    print(B)
    list=[]
    list.append(A)
    list.append(B)
    C=torch.cat(list,dim=0)#按照行进行拼接,此时所有tensor的列数需要相同
    print(C,C.shape)
    
    >>
    tensor([[0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.]])
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    tensor([[0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]]) torch.Size([5, 5])
    >>
    

    按照列进行拼接

    import torch
    A=torch.zeros(2,5) #2x5的张量(矩阵)                                     
    print(A)
    B=torch.ones(2,5)
    print(B)
    list=[]
    list.append(A)
    list.append(B)
    
    C=torch.cat(list,dim=1)#按照列进行拼接,此时的tensor 行数必须一致
    #C=torch.cat((A,B),dim=1)
    print(C,C.shape)
    >>
    tensor([[0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.]])
    tensor([[1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]])
    tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
            [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]) torch.Size([2, 10])
    >>
    
    

    这个函数的应用是基础的,很多任务,例如要处理的数据是一行一行的,分别都转换为 tensor 之后,需要将全部的句子都拼接起来,然后再分成 batch 批量输入模型,所以需要用到 cat 的操作;

    相关文章

      网友评论

          本文标题:torch.cat

          本文链接:https://www.haomeiwen.com/subject/eoibtltx.html