torch.cat

作者: 三方斜阳 | 来源:发表于2021-02-04 16:44 被阅读0次

torch.cat 的作用是把两个 tensor 合并为一个 tensor
第一个参数是需要连接的tensor list , 第二个参数指定按照哪个维度进行拼接

import torch
A=torch.zeros(2,5) #2x5的张量(矩阵)                                     
print(A)
B=torch.ones(3,5)
print(B)
list=[]
list.append(A)
list.append(B)
C=torch.cat(list,dim=0)#按照行进行拼接,此时所有tensor的列数需要相同
print(C,C.shape)

>>
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]) torch.Size([5, 5])
>>

按照列进行拼接

import torch
A=torch.zeros(2,5) #2x5的张量(矩阵)                                     
print(A)
B=torch.ones(2,5)
print(B)
list=[]
list.append(A)
list.append(B)

C=torch.cat(list,dim=1)#按照列进行拼接,此时的tensor 行数必须一致
#C=torch.cat((A,B),dim=1)
print(C,C.shape)
>>
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]) torch.Size([2, 10])
>>

这个函数的应用是基础的,很多任务,例如要处理的数据是一行一行的,分别都转换为 tensor 之后,需要将全部的句子都拼接起来,然后再分成 batch 批量输入模型,所以需要用到 cat 的操作;

相关文章

网友评论

      本文标题:torch.cat

      本文链接:https://www.haomeiwen.com/subject/eoibtltx.html