美文网首页
[tf]tf.gather_nd的用法

[tf]tf.gather_nd的用法

作者: VanJordan | 来源:发表于2019-02-17 22:28 被阅读0次

函数原型,nd的意思是可以收集n dimensiontensor

tf.gather_nd(
    params,
    indices,
    name=None
)
  • 意思是要收集[params[0][0],params[1][1]]
    indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']
  • 意思是要收集[params[1],params[0]]
    indices = [[1], [0]]
    params = [['a', 'b'], ['c', 'd']]
    output = [['c', 'd'], ['a', 'b']]
  • 意思是要收集[params[1]]
    indices = [[1]]
    params = [[['a0', 'b0'], ['c0', 'd0']],
              [['a1', 'b1'], ['c1', 'd1']]]
    output = [[['a1', 'b1'], ['c1', 'd1']]]
  • 我们使用这个函数的一般是想完成这样一个功能:T是一个二维tensor,我们想要根据另外一个二维tensor value的最后一维最大元素的下标选出tensor T 中最后一维最大的元素,组成一个新的一维的tensor,那么就可以首先选出最后一维度的下标[1,2,3],然后将将其扩展成[[0,1],[1,2],[2,3]],然后使用这个函数选择即可。
max_indicies = tf.argmax(T, 1)
import tensorflow as tf

sess = tf.InteractiveSession()

values = tf.constant([[0, 0, 0, 1],
                      [0, 1, 0, 0],
                      [0, 0, 1, 0]])

T = tf.constant([[0, 1, 2 ,  3],
                 [4, 5, 6 ,  7],
                 [8, 9, 10, 11]])

max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0], 
                                            dtype=max_indices.dtype),
                                   max_indices),
                                  axis=1))
print(result.eval())

相关文章

网友评论

      本文标题:[tf]tf.gather_nd的用法

      本文链接:https://www.haomeiwen.com/subject/knineqtx.html