美文网首页deeplearningpython
tf.gather和tf.gather_nd的详细用法--ten

tf.gather和tf.gather_nd的详细用法--ten

作者: Daniel开峰 | 来源:发表于2019-03-01 14:07 被阅读0次

    在numpy里取矩阵数据非常方便,比如:

    a = np.random.random((5, 4))
    indices = np.array([0,2,4])
    
    print(a)
    #array([[0.47122875, 0.37836802, 0.18210801, 0.341471  ],
    #      [0.56551837, 0.27328607, 0.50911876, 0.01179739],
    #       [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
    #      [0.09511502, 0.96345098, 0.6500849 , 0.04084285],
    #       [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
    print(a[indices])
    #array([[0.47122875, 0.37836802, 0.18210801, 0.341471  ],
    #      [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
    #       [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
    

    这样就把矩阵a中的1,3,5行取出来了。

    如果是只取某一维中单个索引的数据可以直接写成tensor[:, 2], 但如果要提取的索引不连续的话,在tensorflow里面的用法就要用到tf.gather.

    import tensorflow as tf
    sess = tf.Session()
    b = tf.gather(tf.constant(a), indices)                                                                                              
    
    sess.run(b)                                                                                                                         
    #Output
    array([[0.47122875, 0.37836802, 0.18210801, 0.341471  ],
           [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
           [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
    
    

    tf.gather_nd允许在多维上进行索引:
    matrix中直接通过坐标取数(索引维度与tensor维度相同):

        indices = [[0, 0], [1, 1]]
        params = [['a', 'b'], ['c', 'd']]
        output = ['a', 'd']
    

    取第二行和第一行:

        indices = [[1], [0]]
        params = [['a', 'b'], ['c', 'd']]
        output = [['c', 'd'], ['a', 'b']]
    

    3维tensor的结果:

        indices = [[1]]
        params = [[['a0', 'b0'], ['c0', 'd0']],
                  [['a1', 'b1'], ['c1', 'd1']]]
        output = [[['a1', 'b1'], ['c1', 'd1']]]
    
    
        indices = [[0, 1], [1, 0]]
        params = [[['a0', 'b0'], ['c0', 'd0']],
                  [['a1', 'b1'], ['c1', 'd1']]]
        output = [['c0', 'd0'], ['a1', 'b1']]
    

    另外还有tf.batch_gather的用法如下:
    tf.batch_gather(params, indices, name=None)
    Gather slices from params according to indices with leading batch dims.

    This operation assumes that the leading dimensions of indices are dense,
    and the gathers on the axis corresponding to the last dimension of indices.

    #tf.batch_gather按如下运算:
    result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]]
    

    Therefore params should be a Tensor of shape [A1, ..., AN, B1, ..., BM],
    indices should be a Tensor of shape [A1, ..., AN-1, C] and result will be
    a Tensor of size [A1, ..., AN-1, C, B1, ..., BM].

    如果索引是一维的tensor,结果和tf.gather 是一样的.

    相关文章

      网友评论

        本文标题:tf.gather和tf.gather_nd的详细用法--ten

        本文链接:https://www.haomeiwen.com/subject/dpnruqtx.html