美文网首页
PyTorch Gather 函数

PyTorch Gather 函数

作者: 数科每日 | 来源:发表于2022-02-14 23:52 被阅读0次

    PyTorch 的 Gather 函数很实用,但是理解起来有些困难,本文试图用图例和代码给出解释。 完整代码

    Gather 主要有三个参数

    • input: 源数据
    • index: 需要选取的数据的index
    • dim: 筛选数据的方式

    Gather 函数返回值和 index 相同

    Dim=0
    Dim=0
    dim = 0
    input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
    index = torch.tensor([[0, 1, 2], [1, 2, 0]])
                         
    output = torch.gather(input, dim, index)
    output
    

    Dim = 0 的时候, 从外层选择, 最内层的 list Tensor 会被拆开:

    image.png
    Dim=1
    Dim=1
    dim = 1
    input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
    index = torch.tensor([[0, 1], [1, 2], [2, 0]])
    
    output = torch.gather(input, dim, index)
    output
    

    Dim = 1 的时候, 从内层选择:

    image.png

    相关文章

      网友评论

          本文标题:PyTorch Gather 函数

          本文链接:https://www.haomeiwen.com/subject/tmrrlrtx.html