Label smoothing是机器学习中的一种正则化方法,其全称是 Label Smoothing Regularization(LSR),即标签平滑正则化。其应用场景必须具备以下几个要素:
- 标签是one-hot向量;
- 损失函数是交叉熵损失函数。
其作用对象是真实标签,如果将其视为一个函数,即 LSR(grandtruth_targets)。
为了说明这个方法的作用,我们追溯一下真实标签的使用场景。
在神经网络训练中,真实标签主要用于两个方面:1)计算loss; 2)计算accuracy。
计算accuracy时只拿真实标签值和预测索引值做比较,如果二者相等,则说明预测准确,此时真实标签并不参与计算。
计算loss时,用到的交叉熵损失函数如下:
其中就是本文的主角——真实标签,其为one-hot向量形式,是经softmax层输出的预测概率。可以看出,损失函数只与预测值和真实值有关,所以此处真实值的形式和数值大小对损失函数的计算影响非常大。
常用的真实标签的形式是one-hot向量,其值非0即1,计算Loss时如果是1,Loss中该项的就得到保留;如果是0,该项的就彻底抛弃。即:
这样就使得模型过分相信标签的标注,只要是标签为1的项就保留,标签为0的项就统统抛弃,那万一标签标注错了岂不是错杀好人?标签是人为标注的,如果人在标注的时候一个不留神标错了,而模型的判决又“充分”信任人为标注的标签,它作为模型判断保留或丢弃某项数据的标准,这就会使错误标签在模型训练中产生较大的影响。
标签平滑做的工作就是使真实标签不那么极端化,给予标签一定的容错概率。思想如下:
对照上面计算Loss的交叉熵公式,真实标签可视作加在预测值上的权重,也即模型用来评判保留或丢弃某项预测信息的概率,按one-hot形式的标签,模型只保留标签为1对应的中的值。
我们假设标签有的概率标注出错,即某个标签为真的概率为,那其它为假的项也有的概率是真,所以计算Loss时原来标签为0的值对应的项不应该完全抹杀,也要以的概率保留下来,即。
在多分类任务中(假设有n类),我们用来标记真实类别; 其余n-1项共同拥有的概率标记出错,这里假设这n-1项服从均匀分布,即标记出错的概率与样本类别无关,所以每一项出错的概率为。
此时的标签形式为:
即用替换0,用代替1。其实在实际的实现中,分母直接采用类别数量n,而不是n-1。
Loss为:
这样,就使得预测的所有信息都得到一定程度的保留,提高了模型对标记错误的数据的容忍度,提高了模型的泛化能力。
以下举一个例子阐述个人的理解,如果出错还望指出。
在模型训练时,标签在引导模型的预测不断接近自己的标准,即如果预测值和标签一致,loss就会减小,否则增大。Loss变大则梯度变大,梯度变大则参数更新幅度变大。举个例子:
如果模型的预测值为: [0.1 0.6 0.2 0.1]
而标注的one-hot标签为:[0 1 0 0]
此时
采用LSR时假设,则标签为[]
此时
如果标签错误标注成了[0 0 1 0],此时模型的预测其实是对的,只是标签被人为标注错了,
此时
采用LSR时的标签为[]
则
在这种情况下,如果采用one-hot标签,那么两种情况的Loss差距为 1.61-0.51=1.1
如果采用LSR标签,则两种情况的差距为 1.62-0.67=0.95
这种情况通俗的讲就是,在one-hot形式下,本来参数应该前进5步,却前进了16步,多前进了11步,离最优值距离进一步拉大。
在LSR形式下,参数本应前进7步,却因为一个人为的错误前进了16步,多前进了9步,虽然离最优值的距离也拉大了,但相比one-hot,它犯错误的程度减小了,就好比同样是做错题,一个做错了11道,一个做错了9道。即LSR将错误标签对模型训练产生的反面影响降低了,进而提高了模型泛化能力。
在应用中,的常用值为0.1
其代码实现如下:
def label_smoothing(inputs, epsilon=0.1):
K = inputs.get_shape().as_list()[-1] # number of channels
return ((1-epsilon) * inputs) + (epsilon / K)
输入参数inputs是one-hot形式的真实标签。
在tensorflow中,标签平滑已经封装在了
tf.losses.softmax_cross_entropy(onehot_labels=new_onehot_labels, logits=out, label_smoothing=0)
函数中,设置其中的label_smoothing的值()即可。
Pytorch中标签平滑的应用场景不多,需要自行构建函数,因为pytorch中常采用的标签就是0,1,2,3,4……形式的,并非one-hot向量,计算loss时常用的损失函数的输入值也直接是网络的预测值(logits)和0,1,2,3,4……形式的真实值标签,而不是softmax输出值和one-hot标签,损失函数内部会自动将标签转换为one-hot向量进行计算。pytorch实现LSR的方法待更……
网友评论