torch.index_select(input, dim, index, *, out=None) → Tensor
作用是:
Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.
返回按照相应维度的给定index的选取的元素,index必须是longtensor。
-
按照dim=0时, x_0应该是在4块中选取第0和第1块赋给x_0.
-
dim=1时,应该是3行中选取第1和第2列给x_1.
x = torch.rand((4, 3, 2))
print('x:',x)
indices0 = torch.LongTensor([0, 1])
x_0 = torch.index_select(x, dim=0, index=indices0)
print("x_0:", x_0)
indices1 = torch.LongTensor([1, 2])
x_1 = torch.index_select(x, dim=1, index=indices1)
print("x_1:", x_1)
++++++++++++++++++++++++++++++++++++++++++
x: tensor([[[0.9854, 0.4894],
[0.3774, 0.6066],
[0.5971, 0.7116]],
[[0.0447, 0.9854],
[0.6996, 0.1671],
[0.4965, 0.5742]],
[[0.9878, 0.9571],
[0.9090, 0.5475],
[0.6792, 0.4184]],
[[0.2394, 0.9625],
[0.1951, 0.2918],
[0.3154, 0.2175]]])
x_0: tensor([[[0.9854, 0.4894],
[0.3774, 0.6066],
[0.5971, 0.7116]],
[[0.0447, 0.9854],
[0.6996, 0.1671],
[0.4965, 0.5742]]])
x_1: tensor([[[0.3774, 0.6066],
[0.5971, 0.7116]],
[[0.6996, 0.1671],
[0.4965, 0.5742]],
[[0.9090, 0.5475],
[0.6792, 0.4184]],
[[0.1951, 0.2918],
[0.3154, 0.2175]]])
网友评论