美文网首页
Index Select 和 Gather 解析

Index Select 和 Gather 解析

作者: 潘旭 | 来源:发表于2020-07-04 10:53 被阅读0次

    两者都是可以通过 index 从 tensor 中将 value 提取出来。Gather的功能比index select 更强大,包含了 index select 的能力。

    index select

    使用 index 从 tensor 中获取数据。但是,index 只能是 1-D, 这也就是说对于2维或者更多维的,只能获取的是同样 index 的数据。

    import torch
    
    def show_index_select():
        x = torch.tensor([[1,2],[3,4]])
    
        index = torch.tensor([1, 0])
    
        y = torch.index_select(x, dim=-1, index=index)
        return y
    show_index_select()
    
    tensor([[2, 1],
            [4, 3]])
    
    def show_index_select_2():
        x = torch.tensor([1,2,3,4])
    
        index = torch.tensor([1, 0])
    
        y = torch.index_select(x, dim=-1, index=index)
        return y
    show_index_select_2()
    
    tensor([2, 1])
    

    那么,如果我想第一行中选择 [0, 0] 而 第二行中选择 [1, 0],这种不同的 index, 那么, torch.index_select 就无法满足这个需求了。这时候,就需要使用更加灵活的 gather

    gather

    gatherindex_select 可以传入相同的参数,但是 gather 的 index 可以是多维的,也就是说,可以每一行都是不同的。

    def show_gather():
        x = torch.tensor([[1,2],[3,4]])
    
        index = torch.tensor([[0, 0], [1, 0]])
    
        y = torch.gather(x, dim=-1, index=index)
        return y
    
    show_gather()
    
    tensor([[1, 1],
            [4, 3]])
    
    def show_gather_same_index_select():
        x = torch.tensor([[1,2],[3,4]])
    
        index = torch.tensor([[1, 0], [1, 0]])
    
        y = torch.gather(x, dim=-1, index=index)
        return y
    
    show_gather_same_index_select()
    
    tensor([[2, 1],
            [4, 3]])
    

    应用场景

    这两个函数有着非常广泛的应用。比如给出 label index, 那么获取每个label 的value, shape = (BatchSize, SeqLen), 用 gather 来通过 label_index 获取 value. index_select 可以用来获取 span.

    相关文章

      网友评论

          本文标题:Index Select 和 Gather 解析

          本文链接:https://www.haomeiwen.com/subject/detuqktx.html