定义
def gather_nd(params, indices, name=None)
功能
根据indeces描述的索引,在params中提取元素,重新组成一个tansor
举例
![](https://img.haomeiwen.com/i12895795/7f5002218f00c4ef.png)
data shape is (3, 2, 3)
data rank is 3
indices = np.array([[0, 1], [1, 0]])
indices shape is (2, 2)
最后的切片的结果是indices中表示索引的部分被提取到的值替换后得到的结构。
以上面的例子说明这个思路:
![](https://img.haomeiwen.com/i12895795/6a266d22f87f95c4.png)
[0, 1]索引得到[2, 2, 2]
[1, 0]索引得到[3, 3, 3]
把索引的结果替换到indices中得到:[[2, 2, 2], [3, 3, 3]]
当索引indices为 [[[[1,1]]]]时,
先找出[1, 1]的索引结果为[4,4,4]
替换到上面结构中得到 [[[[4, 4, 4]]]]
举例
nn_pts = tf.gather_nd(pts, indices, name=tag + 'nn_pts') # (N, P, K, 3)
其中:
nn_pts.shape is (32, 1024, 3)
indices.shape is (32, 512, 32, 2)
output.shape is (32, 512, 32, 3)
已知 nn_pts的最小component是某个点的坐标(x, y, z),即 3 代表的含义。
indices的最小component是(a, b), 就是说要取nn_pts第0维的第 a 个,第1维的第b个值,取出来的这个值是一个point。
一共取了(32, 512, 32)这么多的point,所以最后的output形状为(32, 512, 32, 3)
网友评论