Unsupervised Data Augmentation for Consistency Training 概述
Unsupervised Data Augmentation for Consistency Training 是Google团队在2020年提出的一种数据增强的方法,即如何在小规模标注数据集+大规模无标签数据集上,实现较好的训练结果。
模型概述
模型的重点在于对loss的改进:

最后的loss组成:
-
正常的有监督交叉熵
-
一致性损失。给定x,计算输出分布
和一个噪音版本的输出
通过注入微弱的噪音
。噪音可以施加到x和输入层上。最小化两个分布间的(KL)散度
使用权重因子 平衡有监督交叉熵和无监督训练损失。整体训练目标如下:
$$ \underset{\theta}{min}\mathcal{J(\theta)}=E_{x1 \sim p_{L(x)}}[-logp_{\theta}(f^*(x_1)|x_1)]+\lambda E_{x_2\sim p_{U(x)}}E_{\hat{x}\sim q(\hat{x}|x_2)}[CE(p_{\tilde{\theta}}(y|x_2)||p_{\theta}(y|\hat{x}))] $$
加号前的公式是正常的交叉熵。
加号后,是数据增强变换。
是当前参数的严格拷贝。
在大多数实验中设为了1。
在每一轮迭代中,在一个mini-batch的有监督数据中计算有监督学习的损失,一个mini-batch的无监督数据中计算无监督学习的一致性损失,后者的batch_size设置地大一点。
训练技巧
一致性损失
1.Confidence-based masking.
除去置信度较低的example。在每个minibatch中,一致性损失值计算最高概率大于的例子。 Specifically, β is set to 0.8 for CIFAR-10 and SVHN and 0.5 for ImageNet.(在二分类中,所有的例子都满足>0.5,即正例<0.5,负例一定>0.5)
2.Sharpening Predictions
使用low softmax temperature 来锐化1中选择出的结果。(结果比上
后再经过softmax)

有监督部分
TSA Training Signal Annealing for Low-data Regime方法
标注数据和未标注数据之间的数量差可能很大,因此模型极易很快在标注数据上过拟合,而还未在未标注数据上拟合。
仅使用模型置信度小于给定阈值的标签数据。即如果正确分类的比给定阈值
高,我们从损失函数中移除它。假设K是类别数,
的值会逐渐由
增长到1。防止对任意标签的例子过度训练。
如果训练数据较少或者问题简单, exp-schedule比较合适。
如果模型不容易过拟合(标签数据多或者模型有好的正则化),log-schedule更好。

代码
模型pytorch代码可参考:
https://github.com/SanghunYun/UDA_pytorch/blob/master/main.py
https://github.com/zhoujx4/NLP-Series-text-cls/blob/main/run_uda.py
复现结果
复现数据,使用IMDB数据,训练集20个,无监督数据6w+
复现参数:有监督bs=8 无监督bs=24 λ=1.5
复现时遇到两个问题:
1)遇到unsup_loss不下降的问题
发现代码中两个问题,1是全部训练步数计算过小,全部训练步数用于计算TSA。
2是ori_logit没有放到 no_grad里,但是计算ori_logit前面步骤都在,而且ori_logit中间打印出来是没有梯度的。(很奇怪)改完这两个后有下降的趋势了。
2)遇到val_loss上升,val_acc也上升的情况,但是使用数据集,val_acc最高只有57%左右。
发现是max_len设置过小导致。IMDB数据集中,词汇数<128的只占17%左右,<512占90%,因此max_len设置512。
以下是复现结果,只训练了23k步左右,最高val_acc为86%左右:



每个epoch用验证集验证5次,因此实际训练步数是val_index ×(len(unsup)/5) ≈ val_index × 500


网友评论