Unsupervised Data Augmentation for Consistency Training
文章地址:https://arxiv.org/pdf/1904.12848v2.pdf
相关代码:https://github.com/google-research/uda
摘要
深度学习的模型训练通常依赖大量的标签数据,在只有少量数据上通常表现不好。在这种背景下, 数据增强应运而生,但是,之前的数据增强相关的研究都是基于监督数据的,并且效果也不是很好。本文提出了一种基于无监督数据的数据增强方式UDA(Unsupervised Data Augmentation)。UDA方法生成无监督数据与原始无监督数据具备分布的一致性,而以前的方法通常只是应用高斯噪声和dropout噪声(无法保证一致性)。UDA方法利用了一种目前为止最优的方法生成更加“真实”的数据。使用这种数据增强方法,在极少量数据集上,六种语言任务和三种视觉任务都得到了明显的提升。例如:IMDb数据分类任务上,仅仅使用20个带标签数据加UDA方法,就超过了25000个带标签数据的训练模型,错误率达到了4.2%。在CIFAR-10 上仅用4000张标签图片就达到了2.7%的错误率。在 SVHN任务上,仅仅用250个标签数据就达到了2.85%的错误率,这相当于用全数据集才能达到的正确率,而它们的数量级差别达到了1或2(差10倍或100倍)。在大量标签数据集上,UDA同样表现优秀,在ImageNet任务上,使用10%带标签数据,UDA方法就将Top1和Top5的准确率分别由55.1%提高到77.3%,68.7提高到88.5%。在全数据集上,则分别由78.3%提高到94.4%,79%提高到94.5%。
1 背景介绍
深度学习通常需要大量的训练数据才能够取得比较好的效果(可看看Bert,XLNet的数据量)。但是,数据标注是一项耗时耗力的工作。所以,充分利用未标注数据是一个很有意义的研究方向。而半监督方法,是最有前景的方法之一,当前半监督方法可归结为三类:(1)基于图卷积和图嵌入的图标签传播方法。(2)将目标数据作为潜变量进行预测。(3)强制一致/平滑。第三种方法在许多任务中被证明具有较好的效果。
简单而言,强制平滑方法只是使得模型对应较小的噪声不那么敏感。常用方法就是:对于一个样本,添加一些噪声(例如高斯噪声)然后强制让模型对于加噪和不加噪的数据的输出尽量的相似。直观而言就是一个好的模型,应该能够适应各种小的、不改变样本性质的扰动。通常由于扰动函数的不同会有各种不同的方案。
本文在Sajjadi、 Laine等人的研究的基础上,从有监督数据中学习扰动函数,从而得到最优的数据增强方法。良好的数据增强方法能够大大提高模型的结果,并且数据增强方法能应用于各领域。本文使用的优化方法是最小化增强数据与真实数据之间的KL散度。虽然有监督数据的数据增强取得了很多成功,但是大量的无监督数据使得UDA这种无监督数据增强方法拥有更广阔前景。
本文主要贡献如下:
(1)提出一种TSA方法,该方法能够在无标签数据大于标签数据的时候防止过拟合。
(2)证明有针对性的数据增强效果明显优于无针对性的数据增强。
(3)验证了本文方法在NLP任务上(例如Bert)上的有效性。
(4)在视觉任务中,本文方法也表现优异。
(5)研究一种能应用于分类数据中有标签数据和无标签数据不匹配情况的方法。
2 无监督数据增强(UDA)
2.1 有监督数据增强
所谓数据增强就是在保持标签一致的情况下,通过某种转换方法扩充出类似于真实数据的训练数据。简单而言就是,有一个样本x,通过转换函数q(x)生成新数据x_a,新旧数据有相同的数据标签。通常为了得到的增强数据与原始数据相似,使用的是最大似然估计方法。
数据增强方法可以看成是从有标签数据中扩充出更多的有标签数据,然后用扩充数据进行模型训练。因此,扩充数据相对于原始数据必须是有效的变换(例如图片缩放对图片识别可能有效,图片旋转可能无效)。也因此,如何设计转换函数至关重要。
目前,针对NLP任务的有监督数据增强方法已经取得了很大进展。虽然有成果,但是它通常被比喻成“蛋糕上的樱桃”,只是提高有限的性能,这是由于监督数据通常都是少量的。因此,本文研究了一种基于大量数据的无监督数据增强方法。
2.2 无监督数据增强
本文研究了一种基于强制平滑方法的改进方法。类似VAT,本文使用最小化无标签数据增广数据和无标签数据的KL散度。公式如下:
公式1 最小化KL散度其中q(x_hat|x)是数据增强变换,θ_hat是当前参数θ的固定副本,表明梯度像Miyato等人所建议的那样,不是通过θ_hat传播的。这里使用的数据增强转换与监督数据增强(如文本的反翻译和图像的随机裁剪)中使用的增强方法相同。由于数据增强耗时比较大,所以数据增强是离线生成的,单个原始样本会生成多个增强样本。
为了同时使用带标签数据和无标签数据,作者在计算带标签数据时上加上交叉熵损失和λ为权重的公式1,最终公式如公式2所示,模型结构如图1所示。
公式2 最终的一致性损失函数通过最小化一致性损失,UDA能够使得标签信息从标签数据引入无标签数据。对于大多数实验,将λ设置为1,并对标签数据和无标签数据使用不同的批处理大小。并发现,在一些数据集上,对无标签数据使用更大的batch会带来更好的性能。
图1 UDA训练模型架构,其中M表示模型相比较于传统的加噪方法,例如:高斯噪声、dropout噪声、或者简单的仿射变换,对不同任务进行针对性的数据增强能够生成更加有效的噪声。它具有以下优点:
(1)扰动的有效性:让增强数据和原始数据使用相同标签在监督学习中取得了良好性能,因此,对于无标签数据的增强也是类似有效的。
(2)扰动的多样性:由于可以对输入数据进行多种方式的改动而不改变数据标签,所以数据增强具有更强的多样性,而例如高斯噪声和贝努力噪声改变了局部信息,因此多样性不好。另外,由于是在一组增强数据集上进行平滑操作,所以数据增强拥有较高的效率。
(3)定向归纳偏差:不同的任务需要不同的归纳偏差。如自动增强,数据增强策略可以直接优化以提高验证性能 每项任务。这种面向性能的增强策略可以学会在原始标记集中找出缺少的或最想要的归纳偏差。虽然自动数据增强策略是应用于监督学习任务中的,但是在本文半监督数据增强中,同样有效。
从后面消融实验可以看出,以上三点是显著提高性能的关键所在。
2.3 不同任务的增强策略
正如在第2.2节中所讨论的,针对不同任务有不同数据增强策略。在本节中,将讨论针对不同任务的三种不同的增强方法,并讨论多样性和有效性的平衡问题。主要内容如下:
(1)用于图像分类的自动增强:对于图像分类,“自动增强”,即使用强化学习来搜索图像增强的“最优”组合,其性能明显优于任何人工设计的优化方法。作者使用已发现的增强策略, 在CIFAR-10, SVHN 和 ImageNet上进行了实验,并在 CIFAR-10, SVHN上组合应用了Cutout技术。
(2)应用于文本分类的反向翻译:反向翻译能够在保存语义不变的情况下,生成多样的句式。实验证明,在QANet上,这种策略取得了良好的效果。因此作者在情感分类问题等数据集,如IMDb, Yelp-2, Yelp-5, Amazon-2 、 Amazon-5上采用了这种策略,同时,他们发现,句式的多样性比有效性更重要。因此采用随机抽样代替集束搜索策略(一种贪心策略)。具体而言,作者使用WMT14语料库来训练英语到法语和法语到英语的翻译模型,并对每个句子执行反翻译,而不是整个段落,因为 WMT14中的并行数据是用于句子级翻译,而情感分类语料库中的输入类型是段落。相关结果如图2所示。
图2 文本数据及图片数据增强样例(3)应用于文本分类的基于TD-IDF的词替代方法:虽然反向翻译能够很好的进行数据扩充,但是它并不能保证扩充的句子包含关键词。而对于某些任务,如DBPedia任务,它的目标是预测某些句子属于维基百科的哪个词条。因此关键字非常重要,本文研究了一种替代TD-IDF分数低的非关键字而保留富含信息词汇的词替代扩充方案,详细见论文附录B。
2.4 数据增强在多样性和有效性上的平衡
虽然在一些非常优秀的数据增强方法中,能够得到很好的多样性和有效性。但是,由于多样性是通过改变原始数据得到的,所以,它存在改变数据类别的风险,所以,多样性和有效性是存在一定矛盾的。
在图像分类中,自动增强算法根据监督设置中的验证集性能进行优化,自动发现了多样性和有效性之间的最佳平衡点。对于文本分类,作者调整随机抽样的强度。当设置为0时,随机抽样解码就会退化为贪婪方法,从而产生最有效的样本。 当设置为1时,生成非常多样但是可能完全无意义的样本(这里感觉作者论文写错了,应该是设置0时,生成的才是多样性的,此处待修正)。研究发现设置Softmax温度控制值为 0.7、0.8或0.9时,效果最好。
3 训练技巧
本节主要介绍一些针对不同问题,不同场景下的训练技巧。
3.1 Training Signal Annealing(TSA)
TSA主要是针对标签数据与未标签数据不平衡时的场景。由于有大量的未标签数据需要UDA处理,所以需要一个较大模型,但是由于较大模型很容易在少量标签数据下过拟合,所以,提出了本方法用于解决该问题,即Training Signal Annealing( 简称TSA)。
TSA的基本原理就是在训练过程中,随着未标签数据的增加,逐步去除带标签数据,从而避免模型过拟合到带标签的训练数据。具体而言,就是在训练的t时刻,设置一个阈值 ηt,当1/K ≤ ηt ≤ 1,其中,K是类别数。 当某个标签数据计算的pθ(y∗| x)大于阈值ηt,就将该标签数据移除出计算损失的过程,而只计算miniBatch里面的其余数据。假定miniBatch样本记作B,那么该策略计算损失如下:
公式3 TSA损失计算方式 公式4 过滤后的样本集合其中,I是指示函数。阈值 ηt 用于防止模型过拟合到标签数据。随着 ηt 向1靠近,模型只能缓慢地从标注的实例中得到监督,大大缓解了过拟合问题。假设T是总训练步数,t是当前的训练步数。为了考虑未标记数据和标记数据的不同比率,有以下三种ηt更新计算方式:
(1)对数形式:ηt 在训练开始阶段增长快,后期慢,公式如下:
公式5 对数形式ηt 更新方式(2)线性形式:ηt 在训练过程线性增长,公式如下:
公式6 线性形式 ηt 更新方式(3)指数形式:ηt 在训练最后阶段增长快,前期慢,公式如下:
公式6 指数形式 ηt 更新方式对于数据量少,容易过拟合的情况,使用指数形式比较好。对于标签数据不容易过拟合的情况,比如标签数据比较多或者使用了有效的正则化手段时,使用对数形式会比较好。
图3 三种形式的ηt 在训练中的变化情况其中,图中设置,ηt 如下:
公式7 ηt对应图3的公式3.2 Sharpening Predictions
当标签数据很少时,未标签数据和预测的未标签数据分布会很平坦。因此,在计算KL散度时,主要贡献的部分来自于标签数据。例如在Imagenet任务中,使用10%标签数据下,未标签数据的分布明显比标签数据的分布更加平坦。而比较丰富的数据分布是比较有利于模型训练的,因此,提出以下三种锐化方案:
(1)基于置信度的mask:对模型预测效果不好的,预测的概率小于一定阈值的标签,不计算一致性损失。
(2)最小化熵:最小化熵就是使得预测的增广数据能够拥有一个较低的熵,因此,需要在计算损失时,加上熵的计算。
(3)Softmax控制:通过调整Softmax控制输出, pθ˜(y|x)通过Softmax(l(X)/τ)计算,其中l(X)表示结果逻辑分布概率,τ表示温度。τ越小,分布越Sharper。
在实验中发现,(1)(3)方案结合使用对于非常少量的标签数据情况很有效,(2)方案适合于标签数据较多的情况。
3.3 Domain-relevance Data Filtering
通常,作者希望能够运用领域外的数据,因为它比较容易获取。但是, 一般领域外的数据和领域内的数据不匹配。 由于数据分布的不匹配,使用领域外的数据往往对模型是有负面影响的。为了获取与当前任务相关的域数据,本文采用一种通用的检测领域外数据的技术。作者用领域内的数据训练了一个模型,让后用它去评估领域外的数据,然后过滤掉置信度低的数据。具体说就是,对于分类任务,对所有领域外数据进行概率计算,只使用其中分类正确且概率高的数据。
4 实验
本文对文本分类和视觉相关任务,运用UDA进行了实验。包括六项文本分类任务和三项图片分类任务。
4.1 文本分类实验
实验是基于Bert进行的,因为它在许多NLP任务中表现都很好。具体实验设置请看原始论文,实验结果如下:
表1 基于Bert的文本分类对比实验结果实验结果表明,运用UDA后,基本都取得了较大的提高。同时,作者还实验了不同标签数据量都结果的影响,结果如下:
图4 不同标签数据下的实验对比结果4.2 与半监督学习方法的对比实验
作者实验对比了UDA与半监督方法,具体实验设置请看原始论文,相关结果如图5。结果显示,UDA结果明显更优。
图5 不同标签数据量下与半监督方式的对比结果同时,作者还对比实验了不同模型的情况,结果如图6。
表2 不同模型下的对比实验结果4.3 ImageNet相关实验
ImageNet之所以要单独拿出来,是因为它是一个很有挑战性的任务,而且数据量很大。作者使用10%标签数据和全数据分别做了对比使用,作者做的实验结果如下:
表3 10%标签数据,图片尺寸224,ImageNet对比实验结果 表4 全数据,图片尺寸331,ImageNet对比实验结果作者对比实验了使用不同训练策略下的情况,结果如下:
表5 TSA对比实验结果最后,作者做了消融实验,对比不同策略的重要性,结果如下:
表6 不同模块的消融实验结果5 相关工作
作者做的一些相关工作,在这里省略了,具体查看原始论文附录。
6 结束语
略,主要内容就是以上内容的概括。
参考文献
部分参考文献如下:
[1] Ben Athiwaratkun, Marc Finzi, Pavel Izmailov, and Andrew Gordon Wilson. There are many consistent explanations of unlabeled data: Why you should average. 2018.
[2] Philip Bachman, Ouais Alsharif, and Doina Precup. Learning with pseudo-ensembles. In Advances in Neural Information Processing Systems, pages 3365–3373, 2014.
[3] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, andColin Raffel. Mixmatch: A holistic approach to semi-supervised learning. arXiv preprintarXiv:1905.02249, 2019.
[4] Olivier Chapelle, Bernhard Scholkopf, and Alexander Zien. Semi-supervised learning (chapelle,o. et al., eds.; 2006)[book reviews]. IEEE Transactions on Neural Networks, 20(3):542–542,2009.
[5] Yong Cheng, Wei Xu, Zhongjun He, Wei He, Hua Wu, Maosong Sun, and Yang Liu. Semi supervised learning for neural machine translation. arXiv preprint arXiv:1606.04596, 2016.
[6] François Chollet. Xception: Deep learning with depthwise separable convolutions. In Proceed ings of the IEEE conference on computer vision and pattern recognition, pages 1251–1258,2017.
网友评论