Tensorflow API Cross Entropy and

Author: gritsasa15 | Published: 2019-04-12 11:24

Differences and relationships between tf.nn.softmax_cross_entropy_with_logits(), tf.losses.softmax_cross_entropy(), tf.nn.sparse_softmax_cross_entropy_with_logits() and tf.losses.sparse_softmax_cross_entropy():
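For reference, here is the standard per-example softmax cross-entropy that all four functions compute (textbook definition, not taken from the post). $z$ is one row of logits, $y$ is its one-hot label, and $c$ is the index of the true class:

$$
\ell(z, y) = -\sum_{j} y_j \log \frac{e^{z_j}}{\sum_{k} e^{z_k}} = \log \sum_{k} e^{z_k} - z_c
$$

For the first row, $z = (1, 2, 3)$ and $c = 2$, so $\ell = \log(e^1 + e^2 + e^3) - 3 = \log(1 + e^{-1} + e^{-2}) \approx 0.4076$. For the third row, $z = (7, 8, 9)$ and $c = 0$, so $\ell = 2 + \log(1 + e^{-1} + e^{-2}) \approx 2.4076$. The dense functions read $c$ from the one-hot vector $y$; the sparse ones receive $c$ directly as an integer label.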

import numpy as np
import tensorflow as tf  # written for TensorFlow 1.x (graph mode, tf.Session)

# 5 samples, 3 classes: raw, unnormalized scores
logit_list = np.array([[1, 2, 3],
                       [5, 6, 4],
                       [7, 8, 9],
                       [2, 1, 0],
                       [3, 4, 5]], dtype=np.float32)

# onehot for [2, 1, 0, 0, 2]
onehot_list = np.array([[0, 0, 1],
                        [0, 1, 0],
                        [1, 0, 0],
                        [1, 0, 0],
                        [0, 0, 1]], dtype=np.float32)

# per-example losses, shape [5]
res1 = tf.nn.softmax_cross_entropy_with_logits(labels=onehot_list, logits=logit_list)
# scalar: mean of res1
res2 = tf.losses.softmax_cross_entropy(onehot_labels=onehot_list, logits=logit_list)
# scalar: 0.2 * mean of res1
res3 = tf.losses.softmax_cross_entropy(onehot_labels=onehot_list, logits=logit_list, weights=0.2)
# per-example losses from integer class indices, shape [5]
res4 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(onehot_list, axis=1), logits=logit_list)
# scalar: mean of res4
res5 = tf.losses.sparse_softmax_cross_entropy(labels=tf.argmax(onehot_list, axis=1), logits=logit_list)

sess = tf.Session()
print(sess.run(res1))                           # [0.40760595 0.40760595 2.407606   0.40760595 0.40760595]
print(sess.run(res2))                           # 0.8076059
print(sess.run(tf.reduce_mean(res1)))           # 0.8076059
print(sess.run(res3))                           # 0.1615212
print(sess.run(0.2 * tf.reduce_mean(res1)))     # 0.16152118
print(sess.run(res4))                           # [0.40760595 0.40760595 2.407606   0.40760595 0.40760595]
print(sess.run(res5))                           # mean of res4, about 0.8076 like res2
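The per-example values of res1 and res4 can be cross-checked without TensorFlow. The sketch below is a plain NumPy reimplementation written for this check; it is not TensorFlow's kernel. The helper names softmax_xent_dense and softmax_xent_sparse are hypothetical. It subtracts each row's maximum before exponentiating so that large logits do not overflow.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax along the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def softmax_xent_dense(onehot, logits):
    # Hypothetical helper: one-hot labels, like tf.nn.softmax_cross_entropy_with_logits.
    return -(onehot * log_softmax(logits)).sum(axis=1)

def softmax_xent_sparse(labels, logits):
    # Hypothetical helper: integer class indices, like the sparse variant.
    # Picks the log-probability of the true class in each row.
    return -log_softmax(logits)[np.arange(len(labels)), labels]

logit_list = np.array([[1, 2, 3], [5, 6, 4], [7, 8, 9],
                       [2, 1, 0], [3, 4, 5]], dtype=np.float32)
onehot_list = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0],
                        [1, 0, 0], [0, 0, 1]], dtype=np.float32)

dense = softmax_xent_dense(onehot_list, logit_list)
sparse = softmax_xent_sparse(onehot_list.argmax(axis=1), logit_list)
print(dense)         # approximately [0.4076 0.4076 2.4076 0.4076 0.4076]
print(sparse)        # same values as dense
print(dense.mean())  # approximately 0.8076
```

Only the third sample gets a large loss. Its true class (index 0, logit 7) has the smallest logit in its row.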

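tf.losses.softmax_cross_entropy(): implemented with tf.nn.softmax_cross_entropy_with_logits(). It computes the per-example losses and then reduces them to a scalar, which with the default weights is their mean (res2 equals tf.reduce_mean(res1)), and with weights=0.2 is scaled by 0.2 (res3).
tf.losses.sparse_softmax_cross_entropy(): implemented with tf.nn.sparse_softmax_cross_entropy_with_logits().
tf.losses.sparse_softmax_cross_entropy() is equivalent to tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(...)) when the default weights are used. Its labels input is a vector of integer class indices, not one-hot vectors.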
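The reduction step explains why res3 is exactly 0.2 times the mean. The NumPy sketch below assumes the TF 1.x default reduction for the tf.losses functions, tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS. Under that reduction, the weights are broadcast to the loss shape, the weighted losses are summed, and the sum is divided by the number of non-zero weights. The function reduce_sum_by_nonzero_weights is a hypothetical illustration, not a TensorFlow API.

```python
import numpy as np

def reduce_sum_by_nonzero_weights(losses, weights=1.0):
    # Hypothetical helper mirroring the assumed TF 1.x default reduction:
    # broadcast weights to the loss shape, sum the weighted losses,
    # then divide by how many weights are non-zero.
    w = np.broadcast_to(np.asarray(weights, dtype=losses.dtype), losses.shape)
    num_present = np.count_nonzero(w)
    if num_present == 0:
        return 0.0
    return float((w * losses).sum() / num_present)

# Per-example losses as printed for res1 / res4 above
losses = np.array([0.40760595, 0.40760595, 2.407606, 0.40760595, 0.40760595],
                  dtype=np.float32)

mean_loss = reduce_sum_by_nonzero_weights(losses)           # plain mean, like res2
scaled = reduce_sum_by_nonzero_weights(losses, 0.2)         # 0.2 * mean, like res3
masked = reduce_sum_by_nonzero_weights(losses, [1, 1, 0, 1, 1])
print(mean_loss, scaled, masked)
```

A scalar weight is broadcast to all five samples, so every sample still counts in the denominator and the result is just the scaled mean. A zero per-example weight, as in the last call, removes that sample from both the sum and the count.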

Article link: https://www.haomeiwen.com/subject/konqwqtx.html