
今天说一说 tensor 合并和拆分,下面是 pytorch 中用于对 tensor 进行拆分和合并 Api。Api 的设计也是一门艺术,好的 Api 不但用起来简单明了,看起来也简单明了,pytorch 的 Api 设计上作者一定也花费一些心思,所以用起来会感受她的美。

import torch
- Cat
- Stack
- Split
- Chunk
# concate
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
torch.cat([a,b],dim=0).shape
torch.Size([9, 32, 8])
# 后两个维度需要保持一致,需要 dim 一致,而且除了 concate 以外其他维度需要大小一致
concat
concat 用于将向量进行合并,concat 可以将两个 tensor 在某一个维度进行合并,这里有前提条件的,就是要合并的两个 tensor 在除了要合并的维度外的形状(维度)保持一致。
batch_1 = torch.rand(2,3,28,28)
batch_2 = torch.rand(5,3,28,28)
Api cat 表示 pytorch 也是简约型,api 设计简单明了,第一个参数是要合并的 tensor 集合,输入元素为 tensor 的list dim 表示要和并的维度,返回一个 tensor
torch.cat([batch_1,batch_2],dim=0).shape
torch.Size([7, 3, 28, 28])
在 batch_2 因为在 dim = 1 维度没有保持一致而无法合并。
batch_2 = torch.rand(2,1,28,28)
torch.cat([batch_1,batch_2],dim=0).shape
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-8-bb7d5074e41a> in <module>
----> 1 torch.cat([batch_1,batch_2],dim=0).shape
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /Users/distiller/project/conda/conda-bld/pytorch_1579022036889/work/aten/src/TH/generic/THTensor.cpp:612
torch.cat([batch_1,batch_2],dim=1).shape
torch.Size([2, 4, 28, 28])
batch_1 = torch.rand(2,3,16,32)
batch_2 = torch.rand(2,3,16,32)
torch.cat([batch_1,batch_2],dim=2).shape
torch.Size([2, 3, 32, 32])
stack
stack 与 concat 不同之处,会增加一个维度用于区分合并的不同 tensor。需要要合并两个 tensor 形状完全一致,而 dim=2 维度前添加一个维度。
# stack 添加一个新的维度 当 0 时为 batch_1 当为 1 时为 batch_2
# 创建一个新的维度,对于 stack 需要两个 tensor 形状完全一致
torch.stack([batch_1,batch_2],dim=2).shape
torch.Size([2, 3, 2, 16, 32])
grp_1 = torch.rand(32,8)
grp_2 = torch.rand(32,8)
torch.stack([grp_1,grp_2],dim=0).shape
torch.Size([2, 32, 8])
# split
在 0 维之前添加一个用于区分 tensor 的维度
b = torch.rand(32,8)
b.shape
torch.Size([32, 8])
# 拆分根据长度进行拆分,
# 给定拆分数量,根据数量(num)进行拆分,根据长度(len)进行拆分
c = torch.stack([grp_1,grp_2],dim=0)
c.shape
torch.Size([2, 32, 8])
split
使用 split 可以对 tensor 进行任意拆分,将形状为 (1, 32, 8) 的 tensor 使用 split 进行拆分 1,2 块,而且总和需要和原来 tensor 在维度 0 上数量保持一致。
grp_1,grp_2 = c.split([1,1],dim=0)
grp_1.shape
torch.Size([1, 32, 8])
c = torch.rand(3,32,8)
c.shape
torch.Size([3, 32, 8])
grp_1,grp_2 = c.split([1,2],dim=0)
print(grp_1.shape)
print(grp_2.shape)
torch.Size([1, 32, 8])
torch.Size([2, 32, 8])
如果只给一个参数 2 表示拆分为 2 一组 tensor,剩余 tensor 为一组。
grp_1,grp_2 = c.split(2,dim=0)
print(grp_1.shape)
print(grp_2.shape)
torch.Size([2, 32, 8])
torch.Size([1, 32, 8])
split 好处就是我们可以将 tensor 在某一个维度任意分组。
grp_1,grp_2,grp_3 = c.split([1,1,1],dim=0)
print(grp_1.shape)
print(grp_2.shape)
print(grp_3.shape)
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
chunk
chunk 是数量进行拆分,将形状为 (1, 32, 8) 的 tensor 使用 chunk 进行拆分,3 表示将 tensor 在维度 0 上等分 3
# chunk 按数量进行拆分
grp_1,grp_2,grp_3 = c.chunk(3,dim=0)
print(grp_1.shape)
print(grp_2.shape)
print(grp_3.shape)
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
网友评论