[参考link]
-1- 话不多说,先上公式
其中:
*
N是样本集的样本量
*
表示第
条样本
*
为当前这条样本的真实label值
*
为模型预测该条样本label为1的概率P
-2- 二分类问题的Loss定义
我们知道,对于二分类问题,样本的标签为0 or 1,模型的预测概率为该条样本取1 label的概率,概率值的范围为。
-2.1- 对于正
样本,模型预测取label为1
的概率
对于正样本,我们希望模型的预测结果越大越好,越接近1.0越好,即
越大越好。
-2.2- 对于负
样本,模型预测取label为0
的概率
则当前样本取label为0的概率为:
对于负样本,我们希望模型的预测结果越小越好,越接近0.0越好,即
越大越好。
-2.3- 对于所有
样本,模型预测取该条样本真实label
的概率
由于y的取值只可能是0或1,经过巧妙的转化,上面两个式子可以改写成:
也即:
*
当真实样本标签为1时,模型预测label取1的概率就是
*
当真实样本标签为0时,模型预测label取0的概率就是
两种情况下概率表达式跟之前的完全一致,只不过我们把模型预测label取1的概率
、模型预测label取0的概率
两种情况整合在一起了,统一为模型预测取该样本真实label的概率
。
-2.4- 模型预测取该样本真实label
的概率P(y|x)
越大越好
现在不用纠结是P(y=1|x)越大越好还是P(y=0|x)越大越好,而是可以很直接地说,我们希望模型预测出样本真实label的概率越大越好,即我们希望的是概率越大越好,而且是对所有的样本整体而言,
越大越好
我们知道一个模型就是由模型结构各单元的参数矩阵来定义,则通过训练我们希望得到的模型就是:
* -2.4.1- 引入log()函数,连乘变连加
首先,我们对 P(y|x) 引入 log 函数,因为 log 运算并不会影响函数本身的单调性。而且log运算可以把概率的连乘操作改成连加操作,更适用于计算机计算,则有:
其中,可以拆成:
* -2.4.2- 损失函数
我们希望 P(y|x) 越大越好,反过来,只要取负值之后越小就行了。自然地,我们引入损失函数 即可:
至此就是完整的交叉熵损失函数的推导过程。
-3- 交叉熵损失函数的直观理解
通过以上的推导过程,我们知道了样本集上交叉熵的计算公式:
*
假设,样本集全是正样本,取,可得:
![](https://img.haomeiwen.com/i7715100/3828f3b9f7cbb97f.png)
可以看出,对于一条样本,预测概率越接近1,损失越小。
*
假设,样本集全是负样本,取,可得:
![](https://img.haomeiwen.com/i7715100/c59a3cb33bd4ca14.png)
可以看出,对于一条样本,预测概率越接近0,损失越小。
网友评论