Cat Stack 解析
import torch
cat
cat 不会改变维度,是将多个 tensor 按照指定维度链接
x = torch.tensor([1, 2, 3])
print(torch.cat((x, x), dim=0))
print(torch.cat((x, x), dim=-1))
xx = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(torch.cat((xx, xx), dim=0))
print(torch.cat((xx, xx), dim=-1))
tensor([1, 2, 3, 1, 2, 3])
tensor([1, 2, 3, 1, 2, 3])
tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])
tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6]])
stack
stack 按照指定维度将 多个 x 按照指定方向,并且增加一个维度放在一起。
stack_x = torch.stack((x, x))
stack_x
tensor([[1, 2, 3],
[1, 2, 3]])
torch.stack((x, x), dim=-1)
tensor([[1, 1],
[2, 2],
[3, 3]])
torch.stack((x, x), dim=0)
tensor([[1, 2, 3],
[1, 2, 3]])
网友评论