美文网首页深入理解tensorflowTensorFlow
Tensorflow中的tf.argmax()函数

Tensorflow中的tf.argmax()函数

作者: WilloLee | 来源:发表于2017-04-24 20:34 被阅读4984次

    转载请注明出处:http://www.jianshu.com/p/469789141af7

    官方API定义


    tf.argmax(input, axis=None, name=None, dimension=None)

    Returns the index with the largest value across axes of a tensor.
    Args:

    • input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half.
    • axis: A Tensor. Must be one of the following types: int32, int64. int32, 0 <= axis < rank(input). Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
    • name: A name for the operation (optional).

    Returns:

    • A Tensor of type int64.

    关于axis


    定义中的axis与numpy中的axis是一致的,下面通过代码进行解释

    import numpy as np
    import tensorflow as tf
    
    sess = tf.session()
    m = sess.run(tf.truncated_normal((5,10), stddev = 0.1) )
    print type(m)
    print m
    
    -------------------------------------------------------------------------------
    <type 'numpy.ndarray'>
    [[ 0.09957541 -0.0965599   0.06064715 -0.03011306  0.05533558  0.17263047
      -0.02660419  0.08313394 -0.07225946  0.04916157]
     [ 0.11304571  0.02099175  0.03591062  0.01287777 -0.11302195  0.04822164
      -0.06853487  0.0800944  -0.1155676  -0.01168544]
     [ 0.15760773  0.05613248  0.04839646 -0.0218203   0.02233066  0.00929849
      -0.0942843  -0.05943     0.08726917 -0.059653  ]
     [ 0.02553608  0.07298559 -0.06958302  0.02948747  0.00232073  0.11875584
      -0.08325859 -0.06616175  0.15124641  0.09522969]
     [-0.04616683  0.01816062 -0.10866459 -0.12478453  0.01195056  0.0580056
      -0.08500613  0.00635608 -0.00108647  0.12054099]]
    

    m是一个5行10列的矩阵,类型为numpy.ndarray

    #使用tensorflow中的tf.argmax()
    col_max = sess.run(tf.argmax(m, 0) )  #当axis=0时返回每一列的最大值的位置索引
    print col_max
    
    row_max = sess.run(tf.argmax(m, 1) )  #当axis=1时返回每一行中的最大值的位置索引
    print row_max
    
    array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
    array([5, 0, 0, 8, 9])
    
    -------------------------------------------------------------------------------
    #使用numpy中的numpy.argmax
    row_max = m.argmax(0)
    print row_max
    
    col_max = m.argmax(1)
    print col_max
    
    array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
    array([5, 0, 0, 8, 9])
    

    可以看到tf.argmax()与numpy.argmax()方法的用法是一致的

    • axis = 0的时候返回每一列最大值的位置索引
    • axis = 1的时候返回每一行最大值的位置索引
    • axis = 2、3、4...,即为多维张量时,同理推断

    参考


    1. Tensorflow官方API tf.argmax说明
    2. Numpy官方AIP numpy.argmax说明

    相关文章

      网友评论

        本文标题:Tensorflow中的tf.argmax()函数

        本文链接:https://www.haomeiwen.com/subject/ewopzttx.html