美文网首页
2. Pytorch中torch.index_select

2. Pytorch中torch.index_select

作者: yoyo9999 | 来源:发表于2021-03-22 21:54 被阅读0次

torch.index_select(input, dim, index, *, out=None) → Tensor

作用是:

Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.

返回按照相应维度的给定index的选取的元素,index必须是longtensor。

  • 按照dim=0时, x_0应该是在4块中选取第0和第1块赋给x_0.

  • dim=1时,应该是3行中选取第1和第2列给x_1.

x = torch.rand((4, 3, 2))
print('x:',x)
indices0 = torch.LongTensor([0, 1])
x_0 = torch.index_select(x, dim=0, index=indices0)
print("x_0:", x_0)
indices1 = torch.LongTensor([1, 2])
x_1 = torch.index_select(x, dim=1, index=indices1)
print("x_1:", x_1)

++++++++++++++++++++++++++++++++++++++++++
x: tensor([[[0.9854, 0.4894],
         [0.3774, 0.6066],
         [0.5971, 0.7116]],

        [[0.0447, 0.9854],
         [0.6996, 0.1671],
         [0.4965, 0.5742]],

        [[0.9878, 0.9571],
         [0.9090, 0.5475],
         [0.6792, 0.4184]],

        [[0.2394, 0.9625],
         [0.1951, 0.2918],
         [0.3154, 0.2175]]])
x_0: tensor([[[0.9854, 0.4894],
         [0.3774, 0.6066],
         [0.5971, 0.7116]],

        [[0.0447, 0.9854],
         [0.6996, 0.1671],
         [0.4965, 0.5742]]])
x_1: tensor([[[0.3774, 0.6066],
         [0.5971, 0.7116]],

        [[0.6996, 0.1671],
         [0.4965, 0.5742]],

        [[0.9090, 0.5475],
         [0.6792, 0.4184]],

        [[0.1951, 0.2918],
         [0.3154, 0.2175]]])


相关文章

网友评论

      本文标题:2. Pytorch中torch.index_select

      本文链接:https://www.haomeiwen.com/subject/jlsqhltx.html