美文网首页
keras训练早停法EarlyStopping

keras训练早停法EarlyStopping

作者: carebon | 来源:发表于2021-09-25 10:26 被阅读0次

EarlyStopping的使用与技巧
一般是在model.fit函数中调用callbacks,fit函数中有一个参数为callbacks。注意这里需要输入的是list类型的数据,所以通常情况只用EarlyStopping的话也要是[EarlyStopping()]

当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization performance,即可以很好地拟合数据)。但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合:当网络在训练集上表现越来越好,错误率越来越低的时候,实际上在某一刻,它在测试集的表现已经开始变差。
from:https://www.datalearner.com/blog/1051537860479157

常用的防止过拟合的方法是对模型加正则项,如L1、L2,dropout,但深度神经网络希望通过加深网络层次减少优化的参数,同时可以得到更好的优化结果,Early stopping的使用可以通过在模型训练整个过程中截取保存结果最优的参数模型,防止过拟合。

迭代次数增多后,达到一定程度后产生过拟合。从图中可以看出,训练集精度一直在提升,但是test set的精度在上升后下降。若是在early stopping的位置保存模型,则不必反复训练模型,即可找到最优解。

monitor: 监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。但是因为笔者用的是5折交叉验证,没有单设验证集,所以只能用’acc’了。
min_delta:增大或减小的阈值,只有大于这个部分才算作improvement。这个值的大小取决于monitor,也反映了你的容忍程度。例如笔者的monitor是’acc’,同时其变化范围在70%-90%之间,所以对于小于0.01%的变化不关心。加上观察到训练过程中存在抖动的情况(即先下降后上升),所以适当增大容忍程度,最终设为0.003%。
patience:能够容忍多少个epoch内都没有improvement。这个设置其实是在抖动和真正的准确率下降之间做tradeoff。如果patience设的大,那么最终得到的准确率要略低于模型可以达到的最高准确率。如果patience设的小,那么模型很可能在前期抖动,还在全图搜索的阶段就停止了,准确率一般很差。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,比其稍大些设置patience。在learning rate变化的情况下,建议要略小于最大的抖动epoch number。笔者在引入EarlyStopping之前就已经得到可以接受的结果了,EarlyStopping算是锦上添花,所以patience设的比较高,设为抖动epoch number的最大值。
mode: 就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。笔者的monitor是’acc’,所以mode=’max’。
min_delta和patience都和“避免模型停止在抖动过程中”有关系,所以调节的时候需要互相协调。通常情况下,min_delta降低,那么patience可以适当减少;min_delta增加,那么patience需要适当延长;反之亦然。

代码示例

class RocAucMetricCallback(keras.callbacks.Callback):
    def __init__(self, predict_batch_size=1024):
        super(RocAucMetricCallback, self).__init__()
        self.predict_batch_size = predict_batch_size
 
    def on_batch_begin(self, batch, logs={}):
        pass
 
    def on_batch_end(self, batch, logs={}):
        pass
 
    def on_train_begin(self, logs={}):
        if not ('val_roc_auc' in self.params['metrics']):
            self.params['metrics'].append('val_roc_auc')
 
    def on_train_end(self, logs={}):
        pass
 
    def on_epoch_begin(self, epoch, logs={}):
        pass
 
    def on_epoch_end(self, epoch, logs={}):
        logs['roc_auc'] = float('-inf')
        if (self.validation_data):
            logs['roc_auc'] = roc_auc_score(self.validation_data[1],
                                            self.model.predict(self.validation_data[0],
                                                               batch_size=self.predict_batch_size))
            print('ROC_AUC - epoch:%d - score:%.6f' % (epoch + 1, logs['roc_auc']))
    my_callbacks = [
        RocAucMetricCallback(),  # include it before EarlyStopping!
        EarlyStopping(monitor='roc_auc', patience=20, verbose=2, mode='max')
    ]
 
    mlp.fit(X_train_pre, y_train_pre,
            batch_size=512,
            epochs=500,
            class_weight="auto",
            callbacks=my_callbacks,
            validation_data=(X_train_pre_val, y_train_pre_val))

————————————————
版权声明:本文为CSDN博主「积极向上的墨鱼仔」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/zwqjoy/article/details/86677030
————————————————
版权声明:本文为CSDN博主「不会飞的鹰08」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_41449637/article/details/90201206
EarlyStopping的参数:

相关文章

网友评论

      本文标题:keras训练早停法EarlyStopping

      本文链接:https://www.haomeiwen.com/subject/tpfxnltx.html