美文网首页
UDA 数据增强及复现结果- Unsupervised Data

UDA 数据增强及复现结果- Unsupervised Data

作者: 锦绣拾年 | 来源:发表于2021-05-23 09:50 被阅读0次

Unsupervised Data Augmentation for Consistency Training 概述
Unsupervised Data Augmentation for Consistency Training 是Google团队在2020年提出的一种数据增强的方法,即如何在小规模标注数据集+大规模无标签数据集上,实现较好的训练结果。

模型概述

模型的重点在于对loss的改进:


模型.png

最后的loss组成:

  • 正常的有监督交叉熵

  • 一致性损失。给定x,计算输出分布p_{\theta}(y|x)和一个噪音版本的输出p_{\theta}(y|x,\epsilon) 通过注入微弱的噪音\epsilon。噪音可以施加到x和输入层上。最小化两个分布间的(KL)散度 D( (p_{\theta}(y|x))||(p_{\theta}(y|x,\epsilon)) )

使用权重因子\lambda 平衡有监督交叉熵和无监督训练损失。整体训练目标如下:

\underset{\theta}{min}\mathcal{J(\theta)}=E_{x1 \sim p_{L(x)}}[-logp_{\theta}(f^*(x_1)|x_1)]+\lambda E_{x_2\sim p_{U(x)}}E_{\hat{x}\sim q(\hat{x}|x_2)}[CE(p_{\tilde{\theta}}(y|x_2)||p_{\theta}(y|\hat{x}))]

$$ \underset{\theta}{min}\mathcal{J(\theta)}=E_{x1 \sim p_{L(x)}}[-logp_{\theta}(f^*(x_1)|x_1)]+\lambda E_{x_2\sim p_{U(x)}}E_{\hat{x}\sim q(\hat{x}|x_2)}[CE(p_{\tilde{\theta}}(y|x_2)||p_{\theta}(y|\hat{x}))] $$

加号前的公式是正常的交叉熵。

加号后,q(\hat{x}|x)是数据增强变换。\hat{\theta}是当前参数的严格拷贝。

在大多数实验中\lambda设为了1。
在每一轮迭代中,在一个mini-batch的有监督数据中计算有监督学习的损失,一个mini-batch的无监督数据中计算无监督学习的一致性损失,后者的batch_size设置地大一点。

训练技巧

一致性损失

1.Confidence-based masking.

除去置信度较低的example。在每个minibatch中,一致性损失值计算最高概率大于\beta的例子。 Specifically, β is set to 0.8 for CIFAR-10 and SVHN and 0.5 for ImageNet.(在二分类中,所有的例子都满足>0.5,即正例<0.5,负例一定>0.5)

2.Sharpening Predictions

使用low softmax temperature \tau来锐化1中选择出的结果。(结果比上\tau后再经过softmax)

一致性损失改进

有监督部分

TSA Training Signal Annealing for Low-data Regime方法

标注数据和未标注数据之间的数量差可能很大,因此模型极易很快在标注数据上过拟合,而还未在未标注数据上拟合。

仅使用模型置信度小于给定阈值的标签数据。即如果正确分类的p_{\theta}(y^*|x)比给定阈值\eta_t高,我们从损失函数中移除它。假设K是类别数,\eta的值会逐渐由\frac{1}{K}增长到1。防止对任意标签的例子过度训练。

如果训练数据较少或者问题简单, exp-schedule比较合适。

如果模型不容易过拟合(标签数据多或者模型有好的正则化),log-schedule更好。


TSA

代码

模型pytorch代码可参考:
https://github.com/SanghunYun/UDA_pytorch/blob/master/main.py
https://github.com/zhoujx4/NLP-Series-text-cls/blob/main/run_uda.py

复现结果

复现数据,使用IMDB数据,训练集20个,无监督数据6w+
复现参数:有监督bs=8 无监督bs=24 λ=1.5
复现时遇到两个问题:
1)遇到unsup_loss不下降的问题
发现代码中两个问题,1是全部训练步数计算过小,全部训练步数用于计算TSA。
2是ori_logit没有放到 no_grad里,但是计算ori_logit前面步骤都在,而且ori_logit中间打印出来是没有梯度的。(很奇怪)改完这两个后有下降的趋势了。
2)遇到val_loss上升,val_acc也上升的情况,但是使用数据集,val_acc最高只有57%左右。
发现是max_len设置过小导致。IMDB数据集中,词汇数<128的只占17%左右,<512占90%,因此max_len设置512。

以下是复现结果,只训练了23k步左右,最高val_acc为86%左右:


final_loss.png sup_loss.png unsup_loss.png

每个epoch用验证集验证5次,因此实际训练步数是val_index ×(len(unsup)/5) ≈ val_index × 500


val_acc.png val_loss.png

相关文章

网友评论

      本文标题:UDA 数据增强及复现结果- Unsupervised Data

      本文链接:https://www.haomeiwen.com/subject/jgvajltx.html