美文网首页
TensorFlow(5)常用函数

TensorFlow(5)常用函数

作者: 操作系统 | 来源:发表于2017-08-05 09:41 被阅读0次

    tf.argmax(actv,1)

    tf.argmax(input, axis=None, name=None, dimension=None)
    此函数是对矩阵按行或列计算最大值
    参数:

    • input:输入Tensor
    • axis:0表示按列,1表示按行
    • name:名称
    • dimension:和axis功能一样,默认axis取值优先。新加的字段
      返回:Tensor 一般是行或列的最大值下标向量
    import tensorflow as tf    
    a=tf.get_variable(name='a',  
                      shape=[3,4],  
                      dtype=tf.float32,  
                      initializer=tf.random_uniform_initializer(minval=-1,maxval=1))  
    b=tf.argmax(input=a,axis=0)  
    c=tf.argmax(input=a,dimension=1)   #此处用dimesion或用axis是一样的  
    sess = tf.InteractiveSession()  
    sess.run(tf.initialize_all_variables())  
    print(sess.run(a))  
    #[[ 0.04261756 -0.34297419 -0.87816691 -0.15430689]  
    # [ 0.18663144  0.86972666 -0.06103253  0.38307118]  
    # [ 0.84588599 -0.45432305 -0.39736366  0.38526249]]  
    print(sess.run(b))  
    #[2 1 1 2]  
    print(sess.run(c))  
    #[0 1 0]  
    

    tf.cast(x, dtype, name=None)

    将x的数据格式转化成dtype
    例如,原来x的数据格式是bool, 那么将其转化成float以后,就能够将其转化成0和1的序列。反之也可以。

    a = tf.Variable([1,0,0,1,1])
    b = tf.cast(a,dtype=tf.bool)
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    print(sess.run(b))
    #[ True False False  True  True]
    c = tf.cast(b,dtype=tf.float32)
    print(sess.run(c))
    #[1,0,0,1,1]
    sess.close()
    

    tf.equal

    tf.equal(A, B)是对比这两个矩阵或者向量的相等的元素,如果是相等的那就返回True,反正返回False,返回的值的矩阵维度和A是一样的

    import tensorflow as tf  
    import numpy as np  
    A = [[1,3,4,5,6]]  
    B = [[1,3,4,3,2]]  
    with tf.Session() as sess:  
        print(sess.run(tf.equal(A, B)))  
    

    tf.reduce_mean()

    tf.reduce_mean(input_tensor, reduction_indices=None, keep_dims=False, name=None)

    import tensorflow as tf
    x = tf.constant([[1,2,3],[4,5,6]],dtype=tf.float32)
    a = tf.reduce_mean(x)  
    # 3.5
    b = tf.reduce_mean(x, 0) 
    # [2.5, 3.5, 4.5]
    c = tf.reduce_mean(x, 1) 
    # [2.,  5.]
    sess = tf.Session()
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(c))
    

    相关文章

      网友评论

          本文标题:TensorFlow(5)常用函数

          本文链接:https://www.haomeiwen.com/subject/amqblxtx.html