gather 函数的声明为:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
这个函数我大概研究了半个小时,虽然明白了基本的运算方法,但是其具体用法还理解的不够深入,如果以后有心得的话,再和大家来交流。
根据gather
函数的声明,可以看到gather函数主要由三个参数,分别是:
- input
- dim
- index
下面,我结合一个具体的例子,来给出其具体的计算方法:
import torch
torch.manual_seed(100)
x=torch.randn(2,3)
index = torch.LongTensor([[0,1,1]])
a = torch.gather(x, 0, index)
此时,x的值为:
tensor([[ 0.3607, -0.2859, -0.3938],
[ 0.2429, -1.3833, -2.3134]])
a的值为:
tensor([[ 0.3607, -1.3833, -2.3134]])
下面,我们来看看其具体的计算过程:
第一步,获得index的index
index的值为[[0,1,1]],其每个元素对应的index为:
(0, 0)、(0, 1)、(0, 2)
第二步,看dim的值
在调用gather
函数的时候,需要指定dim的值,此处我们的dim值为0
第三步,根据dim用index的值来替换第一步中得到的对应维度的值
因为dim=0,因此我们用index的值,来代替第一步中得到的索引中第一个维度的值,替换后的值为:
(0, 0)、(1, 1)、(1, 2)
第四步,根据第三步得到的新的索引值,在input中进行取值
input中的(0, 0)、(1, 1)、(1, 2)分别对应了 0.3607、 -1.3833、 -2.3134,而这正是我们得到的计算结果。
网友评论