Keras Callbacks

作者: 四碗饭儿 | 来源:发表于2017-08-10 15:37 被阅读0次

    在每个training/epoch/batch结束时,如果我们想执行某些任务,例如模型缓存、输出日志、计算当前的auc等等,Keras中的callback就派上用场了。

    Example 记录每个batch的损失函数值

    
    import keras
     
    # 定义callback类
    class MyCallback(keras.callbacks.Callback):
        def on_train_begin(self, logs={}):
            self.losses = []
            return
    
        def on_batch_end(self, batch, logs={}): # batch 为index, logs为当前batch的日志acc, loss...
            self.losses.append(logs.get('loss')) 
            return
    
    # 定义模型model
    ...
    ...
    
    # 调用callback
    cb = MyCallback()
    
    # 训练模型
    model.fit(x_train, y_train, batch_size=32, epochs=10, callbacks=[cb])
    
    # 查看callback内容
    cb.losses
    
    

    如上述例子,我们可以继承keras.callbacks.Callback来定义自己的callback,只需重写其中的6个方法即可

    • on_train_begin
    • on_train_end
    • on_epoch_begin
    • on_epoch_end
    • on_batch_begin
    • on_batch_end

    可在这6个方法中定义自己想要的属性,通过self.model可以访问模型本身,self.params可以访问训练参数。

    可能有用的属性

    • self.validation_data validate数据集
    • self.validation_data[0] 为X
    • self.validation_data[1] 为y

    相关文章

      网友评论

        本文标题:Keras Callbacks

        本文链接:https://www.haomeiwen.com/subject/cuehrxtx.html