美文网首页
pyTorch学习笔记——gather函数详解

pyTorch学习笔记——gather函数详解

作者: 韧心222 | 来源:发表于2021-10-13 16:11 被阅读0次

    参考文献:图解PyTorch中的torch.gather函数 - 知乎 (zhihu.com)

    gather 函数的声明为:

    torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
    

    这个函数我大概研究了半个小时,虽然明白了基本的运算方法,但是其具体用法还理解的不够深入,如果以后有心得的话,再和大家来交流。

    根据gather函数的声明,可以看到gather函数主要由三个参数,分别是:

    • input
    • dim
    • index

    下面,我结合一个具体的例子,来给出其具体的计算方法:

    import torch
    
    torch.manual_seed(100)
    x=torch.randn(2,3)
    index = torch.LongTensor([[0,1,1]])
    a = torch.gather(x, 0, index)
    

    此时,x的值为:

    tensor([[ 0.3607, -0.2859, -0.3938],
            [ 0.2429, -1.3833, -2.3134]])
    

    a的值为:

    tensor([[ 0.3607, -1.3833, -2.3134]])
    

    下面,我们来看看其具体的计算过程:
    第一步,获得index的index
    index的值为[[0,1,1]],其每个元素对应的index为:
    (0, 0)、(0, 1)、(0, 2)

    第二步,看dim的值
    在调用gather函数的时候,需要指定dim的值,此处我们的dim值为0

    第三步,根据dim用index的值来替换第一步中得到的对应维度的值
    因为dim=0,因此我们用index的值,来代替第一步中得到的索引中第一个维度的值,替换后的值为:
    (0, 0)、(1, 1)、(1, 2)

    第四步,根据第三步得到的新的索引值,在input中进行取值
    input中的(0, 0)、(1, 1)、(1, 2)分别对应了 0.3607、 -1.3833、 -2.3134,而这正是我们得到的计算结果。

    相关文章

      网友评论

          本文标题:pyTorch学习笔记——gather函数详解

          本文链接:https://www.haomeiwen.com/subject/qmfeoltx.html