两者都是可以通过 index 从 tensor 中将 value 提取出来。Gather的功能比index select 更强大,包含了 index select 的能力。
index select
使用 index 从 tensor 中获取数据。但是,index 只能是 1-D, 这也就是说对于2维或者更多维的,只能获取的是同样 index 的数据。
import torch
def show_index_select():
x = torch.tensor([[1,2],[3,4]])
index = torch.tensor([1, 0])
y = torch.index_select(x, dim=-1, index=index)
return y
show_index_select()
tensor([[2, 1],
[4, 3]])
def show_index_select_2():
x = torch.tensor([1,2,3,4])
index = torch.tensor([1, 0])
y = torch.index_select(x, dim=-1, index=index)
return y
show_index_select_2()
tensor([2, 1])
那么,如果我想第一行中选择 而 第二行中选择 ,这种不同的 index, 那么, torch.index_select
就无法满足这个需求了。这时候,就需要使用更加灵活的 gather
gather
gather
与 index_select
可以传入相同的参数,但是 gather
的 index 可以是多维的,也就是说,可以每一行都是不同的。
def show_gather():
x = torch.tensor([[1,2],[3,4]])
index = torch.tensor([[0, 0], [1, 0]])
y = torch.gather(x, dim=-1, index=index)
return y
show_gather()
tensor([[1, 1],
[4, 3]])
def show_gather_same_index_select():
x = torch.tensor([[1,2],[3,4]])
index = torch.tensor([[1, 0], [1, 0]])
y = torch.gather(x, dim=-1, index=index)
return y
show_gather_same_index_select()
tensor([[2, 1],
[4, 3]])
应用场景
这两个函数有着非常广泛的应用。比如给出 label index, 那么获取每个label 的value, , 用 gather 来通过 label_index
获取 value. index_select
可以用来获取 span
.
网友评论