Let's go over the common tensor manipulation functions: split, cat, stack, and chunk, followed by view and size.
Concatenation (cat)
Similar to tf.concat in TensorFlow.
torch.cat([a, b], dim) concatenates tensors a and b, where dim specifies the dimension along which to concatenate. All other dimensions must match, otherwise an error is raised.
batch_1 = torch.rand(2,3,28,28)
batch_2 = torch.rand(5,3,28,28)
torch.cat([batch_1,batch_2],dim=0).shape
#torch.Size([7, 3, 28, 28])
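As an additional illustration (a minimal sketch with made-up tensors img_a and img_b, assuming import torch has already been run), cat can also join tensors along a non-batch dimension, as long as every other dimension matches:
img_a = torch.rand(4,3,28,28)
img_b = torch.rand(4,1,28,28)
torch.cat([img_a,img_b],dim=1).shape   # concatenate along the channel dimension
# torch.Size([4, 4, 28, 28])
# torch.cat([img_a, torch.rand(4,1,32,32)], dim=1) raises a RuntimeError because the other dims differ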
stack
Unlike cat, stack adds a new dimension that distinguishes the stacked tensors, so the tensors being combined must have exactly the same shape. The dim argument says where the new dimension goes; for example, dim=2 inserts the new dimension in front of the original dimension 2.
batch_1 = torch.rand(2,3,16,32)
batch_2 = torch.rand(2,3,16,32)
torch.stack([batch_1,batch_2],dim=2).shape
#torch.Size([2, 3, 2, 16, 32])
grp_1 = torch.rand(32,8)
grp_2 = torch.rand(32,8)
torch.stack([grp_1,grp_2],dim=0).shape
# torch.Size([2, 32, 8])
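To make the difference from cat concrete, here is a small sketch (a and b are just placeholder tensors): stacking along dim=0 gives the same result as unsqueezing each tensor and then cat-ing them.
a = torch.rand(32,8)
b = torch.rand(32,8)
stacked = torch.stack([a,b],dim=0)                        # inserts a new dim 0 of size 2
catted = torch.cat([a.unsqueeze(0),b.unsqueeze(0)],dim=0) # equivalent construction
print(torch.equal(stacked,catted))
# True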
split
split divides a tensor along a given dimension: pass a list of section sizes (e.g. [1, 2]) or a single chunk size.
c = torch.rand(3,32,8)
grp_1,grp_2 = c.split([1,2],dim=0)
print(grp_1.shape)
print(grp_2.shape)
#torch.Size([1, 32, 8])
#torch.Size([2, 32, 8])
c = torch.rand(4,32,8)
grp_1,grp_2 = c.split(2,dim=0)   # a single int splits into chunks of that size
print(grp_1.shape)
print(grp_2.shape)
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
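When a single chunk size does not divide the dimension evenly, the last piece is simply smaller. A quick sketch (the shape 5x32x8 is chosen arbitrarily):
c = torch.rand(5,32,8)
parts = c.split(2,dim=0)           # pieces of size 2 until the tensor runs out
print([p.shape for p in parts])
# [torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([1, 32, 8])]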
chunk
# chunk splits a tensor into a given number of chunks
c = torch.rand(3,32,8)   # redefine c so that three size-1 chunks come out, matching the output below
grp_1,grp_2,grp_3 = c.chunk(3,dim=0)
print(grp_1.shape)
print(grp_2.shape)
print(grp_3.shape)
'''
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
'''
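Note that chunk takes the number of chunks, not the size of each chunk; when the dimension is not evenly divisible the last chunk ends up smaller. A sketch with an arbitrary 5x32x8 tensor d:
d = torch.rand(5,32,8)
parts = d.chunk(3,dim=0)           # ask for 3 chunks along dim 0
print([p.shape for p in parts])
# [torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([1, 32, 8])]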
view
Similar to reshape in NumPy: view changes the shape of a tensor without changing its data, so the total number of elements must stay the same.
import torch
tt1=torch.tensor([-0.3623,-0.6115,0.7283,0.4699,2.3261,0.1599])
result = tt1.view(3,2)
print(result)
Output:
tensor([[-0.3623, -0.6115],
[ 0.7283, 0.4699],
[ 2.3261, 0.1599]])
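view can also infer one dimension when you pass -1, and it fails if the requested shape does not match the element count. A small sketch reusing tt1 from above (tt2 is just an illustrative name):
tt2 = tt1.view(2,-1)   # -1 lets PyTorch infer the remaining dimension (here 3)
print(tt2.shape)
# torch.Size([2, 3])
# tt1.view(4,2) raises a RuntimeError: 6 elements cannot be viewed as shape (4, 2)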
size
Tensor.size() returns the shape of a tensor.
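A quick sketch (the tensor name t is arbitrary) showing size(), size(dim), and the equivalent shape attribute:
t = torch.rand(2,3,28,28)
print(t.size())    # torch.Size([2, 3, 28, 28])
print(t.size(0))   # 2, the length along dim 0
print(t.shape)     # same as t.size()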