美文网首页
机器学习中用来防止过拟合的方法有哪些?

机器学习中用来防止过拟合的方法有哪些?

作者: yuanCruise | 来源:发表于2019-03-03 22:48 被阅读8次

    1.什么是过拟合

    首先用我自己的语言来解释下什么是过拟合:由于模型过于复杂,学习能力过强,而用于训练的数据相对于复杂模型来说比较简单,所有模型会去学习数据中隐含的噪声,导致模型学不到真正数据集的分布,如下图所示,红色线就是由于模型过分的拟合了训练数据集,导致泛化能力过差。而蓝色线才是真正的数据集的分布。


    2.抑制过拟合的策略

    简单浏览了下网络上的各种总结抑制过拟合的策略,大概有如下几点:

    • 数据增强
    • Early stopping
    • 增加噪声
    • 简化网络结构
    • Dropout
    • 贝叶斯方法

    当前问题下的最多赞回答基本解释了上述几个方法的原理

    而我自己平时应用的时候觉得有些策略对过拟合还挺有效,但又没包含在上述几个策略中,所以在这里给大家详细介绍下,希望能够帮助到你们~

    1.mixup

    mixup 论文地址
    mixup其实就是一种数据增强的方式,我之所以还要在这里介绍是因为这并不是一种常规的数据增强,所以在这里推荐一波~
    mixup是一个和数据无关的简单数据增强原则,其以线性插值的方式来构建新的训练样本和标签。最终对标签的处理如下公式所示,这很简单但对于增强策略来说又很不一般。

    \left ( x_{i},y_{i} \right )\left ( x_{j},y_{j} \right )两个数据对是原始数据集中的训练样本对(训练样本和其对应的标签)。其中\lambda是一个服从B分布的参数,\lambda\sim Beta\left ( \alpha ,\alpha \right ) 。Beta分布的概率密度函数如下图所示,其中\alpha \in \left [ 0,+\infty \right ]


    因此
    2.label smooth

    label smooth代码github传送门
    原理介绍
    在多分类训练任务中,输入图片经过神级网络的计算,会得到当前输入图片对应于各个类别的置信度分数,这些分数会被softmax进行归一化处理,最终得到当前输入图片属于每个类别的概率。

    image

    之后在使用交叉熵函数来计算损失值:

    最终在训练网络时,最小化预测概率和标签真实概率的交叉熵,从而得到最优的预测概率分布。在此过程中,为了达到最好的拟合效果,最优的预测概率分布为:

    image

    也就是说,网络会驱使自身往正确标签和错误标签差值大的方向学习,在训练数据不足以表征所以的样本特征的情况下,这就会导致网络过拟合。

    label smoothing原理

    label smoothing的提出就是为了解决上述问题。最早是在Inception v2中被提出,是一种正则化的策略。其通过"软化"传统的one-hot类型标签,使得在计算损失值时能够有效抑制过拟合现象。如下图所示,label smoothing相当于减少真实样本标签的类别在计算损失函数时的权重,最终起到抑制过拟合的效果。

    1.label smoothing将真实概率分布作如下改变:


    其实更新后的分布就相当于往真实分布中加入了噪声,为了便于计算,该噪声服从简单的均匀分布。

    2.与之对应,label smoothing将交叉熵损失函数作如下改变:

    image

    3.与之对应,label smoothing将最优的预测概率分布作如下改变:

    image

    阿尔法可以是任意实数,最终通过抑制正负样本输出差值,使得网络能有更好的泛化能力。

    代码实现

    1.修改caffe.proto文件
    编辑src/caffe/proto/caffe.proto文件,主要是在原有的LossParameter字段上添加了label_smooth字段。

    
    message LossParameter {
      optional int32 ignore_label = 1;
      enum NormalizationMode {
        FULL = 0;
        VALID = 1;
        BATCH_SIZE = 2;
        NONE = 3;
      }
      optional NormalizationMode normalization = 3 [default = VALID];
      optional bool normalize = 2;
      optional float label_smooth = 4;
    }
    
    

    2.导入hpp/cpp/cu文件
    softmax_loss_layer.hpp文件添加到include/caffe/layers/文件夹下。
    softmax_loss_layer.cpp文件添加到src/caffe/layers/文件夹下。
    softmax_loss_layer.cu文件添加到src/caffe/layers/文件夹下。

    4.编译
    返回到caffe的根目录,使用make指令(下面几个都可以,任选一个),即可。

    make
    make -j
    make -j16
    make -j32    // 这里j后面的数字与电脑配置有关系,可以加速编译
    

    5.使用

    layer{
       name:"loss"
       type:"SoftmaxwithLoss"
       bottom:"fc"
       bottom:"label"
       top:"loss"
       loss_param{
          label_smooth:0.1
        }
    }
    
    3.知识蒸馏

    知识蒸馏论文地址

    知识蒸馏代码github传送门

    Hinton的文章《Distilling the Knowledge in a Neural Network》首次提出了知识蒸馏的概念,通过引入教师网络用以诱导学生网络的训练,实现知识迁移。所以其本质上和迁移学习有点像,但实现方式是不一样的。用“蒸馏”这个词来形容这个过程是相当形象的。用下图来解释这个过程。

    教师网络:大规模,参数量大的复杂网络模型。难以应用到设备端的模型。
    学生网络:小规模,参数量小的精简网络模型。可应用到设备端的模型,俗称可落地模型。

    我们可以认为教师网络是一个混合物,网络复杂的结构就是杂质,是我们不需要用到的东西,而网络学到的概率分布就是精华,是我们需要的。如上图所示,对于教师网络的蒸馏过程,我们可以形象的认为是通过温度系数T,将复杂网络结构中的概率分布蒸馏出来,并用该概率分布来指导精简网络进行训练。整个通过温度系数T的蒸馏过程由如下公式实现:


    从上述公式中可以看出,T值越大,概率分布越软(很多网上的博客都这么说)。其实我个人认为这就是在迁移学习的过程中添加了扰动,从而使得精简网络在借鉴学习的时候更有效,泛化能力更强,这其实就是一种抑制过拟合的策略,和其他抑制过拟合策略在原理上是一致的。

    蒸馏后学习策略
    在第一部分中我们介绍了蒸馏的整个过程,那么在蒸馏结束后,精简网络就要开始跟着负责网络的概率分布进行学习了,在这个过程中是使用KL散度来监督这个学习过程的。接下来简单介绍下KL散度的原理。


    上述公式为KL散度的定义式,我们最终的学习目标是学生网络能够学习到教师网络的概率分布,也就是两者的概率分布能够尽可能的相同。而根据KL散度的原理为T_Prob和S_Prob越接近,KL散度值越小。基于KL散度的这个原理,我们才可以利用这个指标来作为损失函数的计算策略。

    代码介绍

    由于知识蒸馏策略是基于SoftmaxLoss的,因此我们利用caffe实现时,只需要在SoftmaxLoss的基础上,添加一个教室网络,温度系数以及KL散度的计算即可。
    (一)在头文件中添加教师网络的定义


    (二)在头文件中添加温度系数的定义


    (三)在C文件中添加KL散度计算策略


    4.迁移学习的特例

    为什么说迁移学习也可以解决过拟合呢,这完全是我的经验之谈了。我们都知道过拟合的网络所学习到的参数都是错误的,完全不能够表征数据集的特征分布。所以当网络过于复杂,数据集又过于简单的情况下,我们可以使用特殊的迁移学习策略。
    常规的迁移学习,就是将已经在别的数据集上训练好的网络参数拷贝到当前简单的数据集的训练。但如果仅仅是这样,并不能够有效的抑制过拟合,因为网络在当前简单数据集上训练,仍然会把参数练坏。所以我就想了一个很直观的策略:既然网络继续训练会把参数练坏,那我直接控制一部分参数,使其不参与更新不久可以抑制参数被破坏了嘛。哎,最后还真的有效~ 至于如何抑制,以及抑制那几层会有效! 嘿嘿,这里卖个关子,而且这其实是实验性的东西光靠说也不是很好阐述,如果感兴趣可以关注我的同名公众号,欢迎来交流!哈哈哈哈~

    相关文章

      网友评论

          本文标题:机器学习中用来防止过拟合的方法有哪些?

          本文链接:https://www.haomeiwen.com/subject/ezxvuqtx.html