美文网首页
在解决样本不平衡性问题中不同loss实验对比

在解决样本不平衡性问题中不同loss实验对比

作者: 烛之文 | 来源:发表于2021-06-11 10:29 被阅读0次

在前几期分享中,提到focal_loss,dice_loss,它们都可以解决样本不平衡性的问题。虽然paper上实验结果都验证了其效果,但在自己的数据场景下表现究竟如何。带着这样的问题,本人使用了一个做分类任务的中文数据集,测试cross_entropy、focal_loss、dice_loss,以及带样本权重的weighted_cross_entropy,weighted_focal_loss五类损失函数的表现。项目代码已开源,可查看:modeling-data-imbalance-with-different-losses

About loss

在介绍这几类损失时,用二分类任务来展示它的计算方法,其他类型的任务可按此推广。

对于一训练集合D=\{X, Y\}, x_i为其中一个样本,对应的真实值为y_i=[ y_{i0}, y_{i1}]p_i=[ p_{i0}, p_{i1}]为两个类别的预测概率,其中y_{i0}, y_{i1} \in{0,1}, p_{i0},p_{i1} \in [0,1],前者为取值为0或1,后者取值范围为[0,1],样本数量为N

交叉熵(CE)
在上述定义的二分类任务上,其交叉熵定义为:
CE=- \frac{1}{N}\sum_{i=1}^{N}\sum_{j \in \{0,1\}} y_{ij}logp_{ij}
从公式上可以看出,每个x_i样本对目标函数权重是一样的,这对于imbalanced data来说是不合适的,这就引出带权重的交叉熵损失函数。

带权重交叉熵(WCE)
Weighted \quad CE=- \frac{1}{N}\sum_{i=1}^{N}\lambda_i\sum_{j \in \{0,1\}} y_{ij}logp_{ij}
其中\lambda_i为对应的样本权重,调整不同类型样本下在训练时的权重,一般样本数量越小,权重越大。

Focal Loss
Focal Loss 是一个在交叉熵(CE)基础上改进的损失函数,起初为了解决图像领域中目标检测任务中样本数量不平衡性的问题,以及样本的难易性问题。在应用过程中,关于变量\alpha有两种方案:一种是固定一个值,一种是根据标签种类的对应设置不同的值。
Focal \quad Loss=- \frac{1}{N}\sum_{i=1}^{N}\alpha_i\sum_{j \in \{0,1\}} (1-p_{ij})^{\gamma}logp_{ij}

Dice Loss
Dice Loss原理是与F1值是一致的,Dice Loss是由F1值导向的,用其代替Cross Entropy,可以降低样本不平衡性的影响。
Dice \quad Loss=\frac{1}{N}\sum_{i=1}^{N}(1-\frac{2(1-p_{i1})^\alpha y_{i1}+\gamma}{p_{i1}+y_{i1}+\gamma})

About data

实验用的是一个中文数据集,包含一个train.txt与test.txt文件,对应的样本数量分别为:9804,9832。label的类目数量为20,分布为:['Art', 'Literature', 'Education', 'Philosophy', 'History', 'Space', 'Energy', 'Electronics','Communication', 'Computer','Mine','Transport','Enviornment','Agriculture','Economy','Law','Medical','Military','Politics','Sports'],数据集存在极度样本不平衡性问题。训练集中数据统计详细见下表。

label the number of samples the weight of samples
Art 740 0.66
Literature 33 14.85
Education 59 8.31
Philosophy 44 11.14
History 466 1.05
Space 640 0.77
Energy 32 15.32
Electronics 27 18.16
Communication 25 19.61
Computer 1357 0.36
Mine 33 14.85
Transport 57 8.6
Enviornment 1217 0.4
Agriculture 1021 0.48
Economy 1600 0.31
Law 51 9.61
Medical 51 9.61
Military 74 6.62
Politics 1024 0.48
Sports 1253 0.39

从上面统计可以看出,在训练集中,有些label的样本数量很少,最少为'Communication',只有25个样本,最多为'Economy',有1600样本,呈现样本不平衡问题。样本的权重计算,是采用sklearn中compute_class_weight的balanced计算方法。数据集可以下载,链接,密码: 6yor

About training

部分参数说明,见下面。'normal','focal_loss'两种类型的损失函数有带权重的变体,一共有5种损失函数,详细可见code。此外,在训练中将数据集按random_seed固定随时,并取最后1000条最为验证集。

Hyperparameter value Description
loss_type str('normal','focal_loss','dice_loss') normal指的正常cross_entropy
use_weight bool(True,False) 代表是否要用样本权重进行损失计算
category_weight float(list) 对应各个label的权重值

训练: python run.py train

测试: python run.py test

About experiment

实验对比共5中类型损失,评价的指标有accuracy,precision,recall,f1-score,其训练与测试实验结果如下:

训练结果:

loss accuracy
cross_entropy(normal) 0.956
weight_cross_entropy 0.954
focal_loss 0.955
weight_focal_loss 0.944
dice_loss 0944

测试结果:

loss accuracy precision recall f1-score
cross_entropy(normal) 0.94 0.82 0.71 0.75
weight_cross_entropy 0.94 0.79 0.68 0.72
focal_loss 0.94 0.80 0.71 0.74
weight_focal_loss 0.94 0.80 0.72 0.75
dice_loss 0.94 0.75 0.76 0.75

结果显示,5类损失函数的accuracy值是一样的。因为本实验是分类任务,accuracy指标就具备足够的说服力。这样看来,该几类损失函数在训练效果上差距并不大,而整体来看,cross_entropy最好,不仅形式最为简单,而且最为稳定。使用带权重的方式去训练,对比来看,weight_cross_entropy在F1值上表现最差,weight_focal_loss相比focal_loss有所提升。dice_loss只是在recall指标上有明显提升,但整体没有表现很好的效果。

各个label的F1值测试结果:

label num cross_entropy weight_cross_entropy focal_loss weight_focal_loss dice_loss
Art 741 0.93 0.93 0.93 0.92 0.93
Literature 34 0.14 0.25 0.15 0.31 0.16
Education 61 0.65 0.70 0.72 0.69 0.64
Philosophy 45 0.67 0.69 0.64 0.61 0.46
History 468 0.91 0.92 0.91 0.88 0.89
Space 642 0.96 0.95 0.96 0.95 0.94
Energy 33 0.44 0.29 0.47 0.39 0.48
Electronics 28 0.51 0.28 0.32 0.43 0.50
Communication 27 0.62 0.68 0.67 0.60 0.64
Computer 1358 0.98 0.98 0.98 0.98 0.98
Mine 34 0.68 0.30 0.59 0.77 0.75
Transport 59 0.81 0.71 0.75 0.75 0.72
Enviornment 1218 0.97 0.97 0.97 0.96 0.96
Agriculture 1022 0.95 0.95 0.95 0.95 0.95
Economy 1601 0.94 0.95 0.95 0.95 0.95
Law 52 0.60 0.55 0.53 0.59 0.63
Medical 53 0.75 0.73 0.65 0.73 0.77
Military 76 0.63 0.59 0.67 0.59 0.66
Politics 1026 0.95 0.95 0.95 0.94 0.95
Sports 1254 0.99 0.99 0.99 0.99 0.99

从各个label的F1值来看,并没有那个loss表现的更好。在样本特别少的label(数量<100,有11个)中,相对来说,focal_loss,dice_loss稍微好一些,各自有3个label取得最佳。对比cross_entopy,其他损失函数地区在样本少的label上表现好些,但也不完全绝对,如"Transport"; 在样本多的label上,各个损失表现趋于稳定。

Conclusion

通过本次实验,个人总结有以下几点感受:

1.虽然paper展示出focal_loss,dice_loss表现的多么好,说带权重的损失更适合不平衡样本,但也要看数据集所在的环境,不同场景下,可能表现更槽糕;

2.本次数据集上,虽然focal_loss,dice_loss没有比cross_entropy表现多出色,也就是说并没有有效的解决不平衡性问题,但至少证明了它们跟cross_entropy一样是有效的;

3.样本不平衡也要分情况,如两个label的数量比为100000:1000与10000:100,虽都是相差100倍,但前者第二个label数量更多,更容易让模型去偏向。

4.解决样本不平衡问题,最好办法还是增加训练样本;实在没办法,也可以尝试下不同的损失函数。

Reference

1.Focal Loss for Dense Object Detection

2.Dice Loss for Data-imbalanced NLP Tasks

3.利用Dice Loss来解决NLP任务中样本不平衡性问题

更多文章可关注笔者公众号:自然语言处理算法与实践

相关文章

网友评论

      本文标题:在解决样本不平衡性问题中不同loss实验对比

      本文链接:https://www.haomeiwen.com/subject/cpvmeltx.html