美文网首页
pytorch的Tensor的操作

pytorch的Tensor的操作

作者: 术枚派 | 来源:发表于2021-10-18 21:18 被阅读0次

我们整理一下tensor的常见的处理函数。包括拆分(Split)、合并(Cat)、Stack、Chunk

合并(Cat)

和TensorFlow的tf.concat类似。
torch.cat([a , b] , dim),合并tensor a和b,dim指的是从哪个维度。其他维度需要保持一致,如果不一致会出错。

batch_1 = torch.rand(2,3,28,28)
batch_2 = torch.rand(5,3,28,28)
torch.cat([batch_1,batch_2],dim=0).shape
#torch.Size([7, 3, 28, 28])

stack

stack 与 concat 不同之处,会增加一个维度用于区分合并的不同 tensor。需要要合并两个 tensor 形状完全一致,而 dim=2 维度前添加一个维度。

batch_1 = torch.rand(2,3,16,32)
batch_2 = torch.rand(2,3,16,32)
torch.stack([batch_1,batch_2],dim=2).shape
#torch.Size([2, 3, 2, 16, 32])

grp_1 = torch.rand(32,8)
grp_2 = torch.rand(32,8)
torch.stack([grp_1,grp_2],dim=0).shape
# torch.Size([2, 32, 8])

split

c = torch.rand(3,32,8)
grp_1,grp_2 = c.split([1,2],dim=0)
print(grp_1.shape)
print(grp_2.shape)

#torch.Size([1, 32, 8])
#torch.Size([2, 32, 8])

c = torch.rand(4,32,8)
grp_1,grp_2 = c.split([2,dim=0)
print(grp_1.shape)
print(grp_2.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])

chunk

# chunk 按数量进行拆分
grp_1,grp_2,grp_3 = c.chunk(3,dim=0)
print(grp_1.shape)
print(grp_2.shape)
print(grp_3.shape)

'''
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
'''

view

类似于numpy中的resize,改变tensor的size。

import torch
tt1=torch.tensor([-0.3623,-0.6115,0.7283,0.4699,2.3261,0.1599])
result=tt1.view(3,2)

输出

tensor([[-0.3623, -0.6115],
        [ 0.7283,  0.4699],
        [ 2.3261,  0.1599]])

size

Tensor.szie()可以获取tensor的形状。

引用

pytorch 合并和拆分

相关文章

网友评论

      本文标题:pytorch的Tensor的操作

      本文链接:https://www.haomeiwen.com/subject/ldznoltx.html