torch.cat 的作用是把两个 tensor 合并为一个 tensor
第一个参数是需要连接的tensor list , 第二个参数指定按照哪个维度进行拼接
import torch
A=torch.zeros(2,5) #2x5的张量(矩阵)
print(A)
B=torch.ones(3,5)
print(B)
list=[]
list.append(A)
list.append(B)
C=torch.cat(list,dim=0)#按照行进行拼接,此时所有tensor的列数需要相同
print(C,C.shape)
>>
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]) torch.Size([5, 5])
>>
按照列进行拼接
import torch
A=torch.zeros(2,5) #2x5的张量(矩阵)
print(A)
B=torch.ones(2,5)
print(B)
list=[]
list.append(A)
list.append(B)
C=torch.cat(list,dim=1)#按照列进行拼接,此时的tensor 行数必须一致
#C=torch.cat((A,B),dim=1)
print(C,C.shape)
>>
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]]) torch.Size([2, 10])
>>
这个函数的应用是基础的,很多任务,例如要处理的数据是一行一行的,分别都转换为 tensor 之后,需要将全部的句子都拼接起来,然后再分成 batch 批量输入模型,所以需要用到 cat 的操作;
网友评论