1. dim
1. torch.mean()
In PyTorch, dimensions are indexed from front to back as 0, 1, ...; counting from the back they are -1, -2, ....
For example, in torch.mean(input, dim, keepdim=False, *, out=None) → Tensor, dim=-1 means the mean is computed over the last dimension, treating it as the feature dimension. keepdim=True means the reduced dimension is kept (with size 1) instead of being dropped.
import torch
x = torch.rand((1, 2, 4, 3))
print(x)
print(x.shape, x.shape[0], x.shape[1], x.shape[-1])
x_mean = x.mean(dim=-1, keepdim=True)  # average over the last dim; keep it as size 1
print(x_mean)
###########################################################
tensor([[[[0.9096, 0.3786, 0.2141],
[0.7117, 0.7631, 0.9907],
[0.1136, 0.3949, 0.7348],
[0.4267, 0.6413, 0.2369]],
[[0.3865, 0.5671, 0.6057],
[0.9800, 0.7702, 0.8712],
[0.5558, 0.0272, 0.1465],
[0.0296, 0.2917, 0.4655]]]])
torch.Size([1, 2, 4, 3]) 1 2 3
tensor([[[[0.5008],
[0.8218],
[0.4144],
[0.4350]],
[[0.5198],
[0.8738],
[0.2432],
[0.2623]]]])
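For contrast, with keepdim=False (the default) the reduced dimension is dropped entirely. A minimal check, reusing the x and x_mean from above:
x_mean_flat = x.mean(dim=-1)  # shape (1, 2, 4) instead of (1, 2, 4, 1)
print(x_mean_flat.shape)  # torch.Size([1, 2, 4])
print(torch.equal(x_mean_flat.unsqueeze(-1), x_mean))  # True: same values, different shape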
2. torch.view(), torch.transpose() and torch.reshape()
Next, look at torch.view(). A dimension given as -1 is inferred: its size equals the total number of elements divided by the product of the other specified sizes. For example, in the code below,
x_view = x.view([-1, 8])
print(x_view)
print(x_view.shape)
###########################################################
tensor([[0.7176, 0.1339, 0.8430, 0.7403, 0.8586, 0.4462, 0.2916, 0.8481],
[0.6150, 0.2353, 0.8000, 0.8882, 0.1117, 0.3792, 0.6269, 0.0883],
[0.8065, 0.9425, 0.5607, 0.0641, 0.7079, 0.5646, 0.5847, 0.0929]])
torch.Size([3, 8])
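A quick sanity check of the -1 inference (assuming the same x of shape (1, 2, 4, 3) as above):
print(x.numel())  # 24 elements in total
print(x.view(-1, 8).shape)  # torch.Size([3, 8]): the -1 is inferred as 24 // 8 = 3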
torch.transpose(input, dim0, dim1) → Tensor swaps two axes; for example, transpose(1, 2) exchanges dimension 1 and dimension 2. The difference from view is that transpose changes the logical order of the elements: it returns a new tensor object with swapped strides (no data is copied, and the result is generally non-contiguous), while view reinterprets the same memory under a new shape without reordering anything. So although x_view and x_trans below have the same shape, they are different Python objects (different id) and enumerate the elements in different orders, so torch.equal returns False.
print(x.shape, id(x))
x_trans = x.transpose(1, 2)  # swap dims 1 and 2: (1, 2, 4, 3) -> (1, 4, 2, 3)
print(x_trans)
print(x_trans.shape)
x_view = x.view((1, 4, 2, 3))  # same memory, just regrouped: element order unchanged
print(x_view.shape)
print(id(x_view))
print(torch.equal(x_trans, x_view))
########################################
torch.Size([1, 2, 4, 3]) 1852810812928
tensor([[[[0.1892, 0.3858, 0.4969],
[0.3069, 0.2180, 0.3057]],
[[0.7929, 0.8586, 0.7825],
[0.6899, 0.3481, 0.2375]],
[[0.5666, 0.1618, 0.3767],
[0.4373, 0.0928, 0.4095]],
[[0.7004, 0.7312, 0.6594],
[0.6045, 0.8708, 0.0987]]]])
torch.Size([1, 4, 2, 3])
torch.Size([1, 4, 2, 3])
1852852438080
False
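A minimal check that transpose returns a view rather than a copy (reusing x and x_trans from above; data_ptr() returns the address of a tensor's first element):
print(x.data_ptr() == x_trans.data_ptr())  # True: both point into the same storage
print(x_trans.is_contiguous())  # False: only the strides were swapped, not the data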
torch.reshape(input, shape) → Tensor is used the same way as view. reshape was introduced later; the difference is that view requires the tensor's memory layout to be contiguous, while reshape copies the data when it has to. For what contiguous means, see Reference [1].
x_reshape = x.reshape((-1, 8))
print(x_reshape.shape)
##############################
torch.Size([3, 8])
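A short sketch of that contiguity requirement, using the non-contiguous x_trans from the previous section:
try:
    x_trans.view(-1)  # fails: view cannot flatten a tensor with swapped strides
except RuntimeError as e:
    print(e)
print(x_trans.reshape(-1).shape)  # torch.Size([24]): reshape copies when needed
print(x_trans.contiguous().view(-1).shape)  # equivalent manual fix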
3. tensor.permute()
tensor.permute() rearranges a tensor's dimensions according to the given index order. For example, a tensor of shape (1, 2, 4, 3) has dimension indices 0, 1, 2, 3; permuting with the order (1, 0, 3, 2) produces shape (2, 1, 3, 4).
Differences from torch.transpose():
- It used to be available only as the method tensor.permute(); recent PyTorch versions also provide the function torch.permute().
- It can rearrange several dimensions at once, while torch.transpose() only swaps two.
print(x.shape)
x_per = x.permute((1, 0, 3, 2))  # reorder dims: (1, 2, 4, 3) -> (2, 1, 3, 4)
print(x_per.shape)
##########################
torch.Size([1, 2, 4, 3])
torch.Size([2, 1, 3, 4])
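Since a permutation can be decomposed into pairwise swaps, the same reordering can be built from two transpose calls. A small check, reusing x and x_per from above:
x_two_swaps = x.transpose(0, 1).transpose(2, 3)  # (1,2,4,3) -> (2,1,4,3) -> (2,1,3,4)
print(torch.equal(x_per, x_two_swaps))  # True: identical values and shape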
4. torch.squeeze() / torch.unsqueeze()
torch.squeeze(input, dim=None, *, out=None) → Tensor removes the given dimension if its size is 1; if dim is not specified, all size-1 dimensions are removed.
x = x.reshape((1, 2, 1, 3, 4, 1))
print(x.shape)
x_squeeze = x.squeeze()  # drop every size-1 dim
print(x_squeeze.shape)
x_squeeze_0 = x.squeeze(0)  # drop only dim 0 (its size is 1)
print(x_squeeze_0.shape)
#################################
torch.Size([1, 2, 1, 3, 4, 1])
torch.Size([2, 3, 4])
torch.Size([2, 1, 3, 4, 1])
torch.unsqueeze(input, dim) → Tensor inserts a dimension of size 1 at the given position, which is often needed to make shapes line up for an operation. In x_squeeze, dimension 1 has size 3; unsqueeze(1) inserts a size-1 dimension in front of it.
print(x_squeeze.shape)
x_uns_1 = x_squeeze.unsqueeze(1)
print(x_uns_1.shape)
###################################
torch.Size([2, 3, 4])
torch.Size([2, 1, 3, 4])
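A typical use of unsqueeze is inserting a size-1 dimension so that two tensors broadcast. A minimal sketch with hypothetical tensors a and b (not part of the example above):
a = torch.ones(2, 3, 4)  # hypothetical data tensor
b = torch.arange(3.0)  # hypothetical per-row values, shape (3,)
out = a + b.unsqueeze(-1)  # (3,) -> (3, 1), which broadcasts against (2, 3, 4)
print(out.shape)  # torch.Size([2, 3, 4])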
References
[1] pytorch-contiguous
[2] Pytorch之permute函数
[3] PyTorch的 transpose、permute、view、reshape