[Keras] ModelCheckpoint cannot save multi-GPU models

Author: DexterLei | Published 2018-05-18 17:18

    Problem description

    When using callbacks.ModelCheckpoint() together with multi-GPU training, the callback raises:

    TypeError: can't pickle ...(different text at different situation) objects
    

    This error is closely related to the one caused by saving a multi-GPU model the wrong way:

    To save the multi-gpu model, use .save(fname) or .save_weights(fname)
    with the template model (the argument you passed to multi_gpu_model),
    rather than the model returned by multi_gpu_model.

    I covered this in an earlier post: [Keras] Using Keras with multiple GPUs and saving the model. Evidently, ModelCheckpoint by default still calls paralleled_model.save(), which triggers the error. To work around it, we have to define our own callback.

    Solutions

    Method 1

    import keras
    from keras.utils import multi_gpu_model

    original_model = ...
    parallel_model = multi_gpu_model(original_model, gpus=n)

    class MyCbk(keras.callbacks.Callback):

        def __init__(self, model):
            # Keep a reference to the template model, not the parallel wrapper
            self.model_to_save = model

        def on_epoch_end(self, epoch, logs=None):
            self.model_to_save.save('model_at_epoch_%d.h5' % epoch)

    cbk = MyCbk(original_model)
    parallel_model.fit(..., callbacks=[cbk])
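The mechanism behind Method 1 can be traced without Keras at all: the training loop drives one object while the callback holds, and saves, a reference to another. A minimal sketch (all names here are illustrative stand-ins, not Keras API):

```python
# Toy model that records what it was saved as, instead of writing a file.
class SavedFlagModel:
    def __init__(self, name):
        self.name = name
        self.saved_as = None

    def save(self, path):
        self.saved_as = path

# Same shape as MyCbk above: save the object handed in at construction.
class Checkpoint:
    def __init__(self, model_to_save):
        self.model_to_save = model_to_save

    def on_epoch_end(self, epoch):
        self.model_to_save.save('model_at_epoch_%d.h5' % epoch)

template = SavedFlagModel('template')
parallel = SavedFlagModel('parallel')

cbk = Checkpoint(template)
for epoch in range(2):       # stand-in for parallel_model.fit(...)
    cbk.on_epoch_end(epoch)

print(template.saved_as)  # model_at_epoch_1.h5 -- the template was saved
print(parallel.saved_as)  # None -- the parallel wrapper was never touched
```

Because only the template object is ever serialized, the unpicklable internals of the parallel wrapper are never reached.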
    

    Method 2

    from keras.callbacks import ModelCheckpoint

    class ParallelModelCheckpoint(ModelCheckpoint):
        def __init__(self, model, filepath, monitor='val_loss', verbose=0,
                     save_best_only=False, save_weights_only=False,
                     mode='auto', period=1):
            self.single_model = model
            super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose,
                                                          save_best_only, save_weights_only,
                                                          mode, period)

        def set_model(self, model):
            # Ignore the parallel model Keras hands in; always track the template model
            super(ParallelModelCheckpoint, self).set_model(self.single_model)

    # single_model is the template model (the argument passed to multi_gpu_model)
    check_point = ParallelModelCheckpoint(single_model, 'best.hd5')
    

    Method 3

    import numpy as np
    import keras

    class CustomModelCheckpoint(keras.callbacks.Callback):

        def __init__(self, model, path):
            # Store the template model under its own attribute: during fit()
            # Keras calls set_model() and overwrites self.model with the
            # parallel model, so self.model must not be relied on here.
            self.model_to_save = model
            self.path = path
            self.best_loss = np.inf

        def on_epoch_end(self, epoch, logs=None):
            val_loss = logs['val_loss']
            if val_loss < self.best_loss:
                print("\nValidation loss decreased from {} to {}, saving model".format(self.best_loss, val_loss))
                self.model_to_save.save_weights(self.path, overwrite=True)
                self.best_loss = val_loss

    # Fit the parallel model, but checkpoint the template model
    parallel_model.fit(X_train, y_train,
                       batch_size=batch_size*G, epochs=nb_epoch, verbose=0, shuffle=True,
                       validation_data=(X_valid, y_valid),
                       callbacks=[CustomModelCheckpoint(original_model, '/path/to/save/model.h5')])
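The save-on-improvement bookkeeping in Method 3 can be checked in isolation. A Keras-free sketch of just that logic (class and variable names are illustrative):

```python
import math

# Tracks the best validation loss seen so far and records which epochs
# would have triggered a save -- the same comparison Method 3 performs.
class BestLossTracker:
    def __init__(self):
        self.best_loss = math.inf
        self.saved_epochs = []

    def on_epoch_end(self, epoch, val_loss):
        if val_loss < self.best_loss:
            self.saved_epochs.append(epoch)
            self.best_loss = val_loss

tracker = BestLossTracker()
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    tracker.on_epoch_end(epoch, loss)

print(tracker.saved_epochs)  # [0, 1, 3] -- epochs where val_loss improved
print(tracker.best_loss)     # 0.5
```

Note that epoch 2 (loss 0.8) is skipped: the model is only written when the monitored metric actually improves, which is what keeps `best.h5`-style checkpoints meaningful.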
    

Original link: https://www.haomeiwen.com/subject/bwogdftx.html