1.定义
softmax运算将输出变换为一个合法的概率分布;
对于真实标签,也可以用类别分布表达:
对于样本i,仅样本i的类别的离散数值为1,其余为0.
image.png
为什么不用平方损失函数?
image.png
因此,
改善上述问题的⼀个⽅法是使⽤更适合衡量两个概率分布差异的测量函数。
其中,交叉熵(cross entropy)是⼀个常⽤的衡量⽅法:
image.png
其实,就是熵的定义公式.
image.png
假设训练数据集的样本数为n,交叉熵损失函数定义为
image.png
最小化交叉熵损失函数等价于最⼤化训练数据集所有标签类别的联合预测概率
image.png
2.交叉熵损失函数的实现
为了得到标签的预测概率,我们可以使⽤pick函数。
# y是两个样本的标签类别,分别是0,2
y = nd.array([0, 2], dtype='int32')
# y_hat是两个样本在3个类别的预测概率
y_hat = nd.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(nd.pick(y_hat, y))
image.png
第一个样本,0类别的预测概率是0.1;
第二个样本,2类别的预测概率是0.5.
交叉熵损失函数:
def cross_entropy(y_hat, y):
return -nd.pick(y_hat, y).log()
参考:
动手学深度学习
网友评论