美文网首页
torch.gather

torch.gather

作者: 菌子甚毒 | 来源:发表于2022-07-02 19:46 被阅读0次

    https://pytorch.org/docs/stable/generated/torch.gather.html

    一个简单的例子:

    t = torch.rand(2,3)
    """
    tensor([[0.8133, 0.5586, 0.7917],
            [0.0551, 0.2322, 0.9087]])
    """
    t.gather(dim=0,index=torch.tensor([[0,1,0],[1,0,1]]))
    """
    tensor([[0.8133, 0.2322, 0.7917],
            [0.0551, 0.5586, 0.9087]])
    """
    
    • dim = 0,说明index中所有索引均是索引行。
    • 关于index的shape:dim=1时,input.shape[0]=index.shape[0], 同理, 可推dim=0时,input.shape[1]=index.shape[1]
    # 常用于以下需求:
    # celoss = torch.tensor([i_s[i_t] for i_s,i_t in zip(softmax,target)])
    
    input = torch.randn(3, 5, requires_grad=True) # (3,5)
    
    n_samples = input.shape[0] # 注意dim=1时,input.shape[0]=index.shape[0], 同理, 可推dim=0时,input.shape[1]=index.shape[1]
    channel = 6
    
    idx = torch.randint(low=0,high=5,size=(n_samples*channel,)).reshape(n_samples,channel)
    """
    tensor([[0, 0, 4, 2, 3, 1], 第一行取第0个,第0个,第4个...
            [3, 3, 1, 0, 2, 2], 第二行取第3个,第3个,第1个...
            [4, 4, 4, 2, 1, 3]]) ...
    """
    input.gather(dim=1,index=idx) # torch.Size([3, 6])
    

    相关文章

      网友评论

          本文标题:torch.gather

          本文链接:https://www.haomeiwen.com/subject/qwdkbrtx.html