PyTorch 的 Gather 函数很实用,但是理解起来有些困难,本文试图用图例和代码给出解释。 完整代码
Gather 主要有三个参数
- input: 源数据
- index: 需要选取的数据的index
- dim: 筛选数据的方式
Gather 函数返回值和 index 相同
Dim=0
Dim=0dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])
output = torch.gather(input, dim, index)
output
Dim = 0 的时候, 从外层选择, 最内层的 list Tensor 会被拆开:
image.pngDim=1
Dim=1dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])
output = torch.gather(input, dim, index)
output
Dim = 1 的时候, 从内层选择:
image.png
网友评论