美文网首页
Tensorflow的几个常见函数(1)

Tensorflow的几个常见函数(1)

作者: YANWeichuan | 来源:发表于2018-12-27 15:19 被阅读0次
  • tf.argmax: 返回tensor中行或者列的最大值的下标。通过axis只能按行还是按列
    如2行3列矩阵:
    [[1, 5, 9]
    [8, 6, 2]]
    tf.argmax(c, axis=0)按照列比较返回:[1 1 0],第一列8大下标1,第二列6大下标1,第三列9大下标0。
    tf.argmax(c, axis=1)按照行比较返回:[2 0],第一行9最大下标2,第二行8最大下标0
  • tf.reduce_mean:按照tensor中的行或者列求平均值,通过reduction_indices指定按行还是按列
    如2行3列矩阵:
    [[1, 5, 9]
    [8, 6, 2]]
    tf.reduce_mean(c, reduction_indices = 0)按照列求平均值返回一个三列[4.5 5.5 5.5]值
    tf.reduce_mean(c, reduction_indices = 1)按照行求平均值返回一个两列[5. 5.3333335]值
  • tf.reduce_sum:同reduce_mean类似,按照tensor中的行或者列求和,通过reduction_indices指定按行还是按列
  • tf.equal:对两个tensor中对每个对应的元素进行比较,返回true/false新的tensor
  • tf.cast:转换tensor到指定的类型

示例代码:

import tensorflow as tf

a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]], tf.int32)
b = tf.constant([[9, 8, 7], [4, 5, 6], [3, 2, 1]], tf.int32)
c = tf.constant([[1, 5, 9], [8, 6, 2]], tf.float32)

with tf.Session() as sess:
    print(sess.run(c))
    # tf.argmax
    print("tf.argmax:")
    print(sess.run(tf.argmax(c, axis=0)))
    print(sess.run(tf.argmax(c, axis=1)))

    # tf.reduce_mean
    print("tf.reduce_mean:")
    print(sess.run(tf.reduce_mean(c, reduction_indices = 0)))
    print(sess.run(tf.reduce_mean(c, reduction_indices = 1)))

    # tf.reduct_sum
    print("tf.reduce_sum:")
    print(sess.run(tf.reduce_sum(c, reduction_indices = 0)))
    print(sess.run(tf.reduce_sum(c, reduction_indices = 1)))

    # tf.equal
    print("tf.equal:")
    print(sess.run(tf.equal(a,b)))

    # tf.cast
    print("tf.cast:")
    print(sess.run(tf.cast(tf.equal(a, b), tf.float32)))

输出:

[[1. 5. 9.]
 [8. 6. 2.]]
tf.argmax:
[1 1 0]
[2 0]
tf.reduce_mean:
[4.5 5.5 5.5]
[5.        5.3333335]
tf.reduce_sum:
[ 9. 11. 11.]
[15. 16.]
tf.equal:
[[False False False]
 [ True  True  True]
 [False False False]]
tf.cast:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

相关文章

网友评论

      本文标题:Tensorflow的几个常见函数(1)

      本文链接:https://www.haomeiwen.com/subject/ncwrlqtx.html