美文网首页
keras回调函数callback

keras回调函数callback

作者：poteman | 发表于 2019-08-03 17:30 | 被阅读 0 次
    • 使用 EarlyStopping 提前停止训练，用 ModelCheckpoint 保存最佳模型，并在训练结束后加载最佳模型。
    import tensorflow as tf
    print(tf.__version__)
    from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
    from tensorflow.keras.models import load_model
    
    # Callbacks: stop training once validation accuracy has not improved for
    # 3 epochs, and keep only the best-scoring full model (architecture and
    # weights, since save_weights_only=False) on disk at `filepath`.
    # NOTE: 'val_acc' only appears in the training logs when the accuracy
    # metric is logged under the name 'acc' (TF 1.x does this for
    # metrics=['accuracy']; TF 2.x logs 'accuracy' / 'val_accuracy' unless
    # the metric is explicitly named 'acc').
    filepath = 'my_model.h5'
    early_stopping = EarlyStopping(monitor='val_acc', patience=3)
    best_checkpoint = ModelCheckpoint(
        filepath=filepath,
        monitor='val_acc',
        save_best_only=True,
        save_weights_only=False,
    )
    callbacks_list = [early_stopping, best_checkpoint]
    
    # Load Fashion-MNIST and scale pixel intensities from [0, 255] into [0, 1].
    mnist = tf.keras.datasets.fashion_mnist
    (training_images, training_labels), (test_images, test_labels) = mnist.load_data()
    training_images, test_images = training_images / 255.0, test_images / 255.0
    
    # Model: flatten each 28x28 image, one 512-unit ReLU hidden layer,
    # softmax over the 10 classes.
    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
    # Name the accuracy metric 'acc' explicitly so the logged keys are
    # 'acc' / 'val_acc' on every TF version. With metrics=['accuracy'],
    # TF 2.x logs 'val_accuracy' instead: the callbacks monitoring 'val_acc'
    # would then never act, ModelCheckpoint(save_best_only=True) would never
    # write my_model.h5, and the later load_model(filepath) would fail.
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
    # NOTE(review): the test set doubles as validation data here, so the best
    # checkpoint is selected on the same data it is later evaluated on;
    # consider validation_split=0.1 for an unbiased final test score.
    model.fit(training_images, training_labels, epochs=30, callbacks=callbacks_list, validation_data=(test_images, test_labels))
    
    # Reload the best model written by ModelCheckpoint and score it on the
    # test set.
    model = load_model(filepath)
    model.evaluate(test_images, test_labels)

    # Predict class probabilities for the test set, then show the first
    # prediction next to its true label.
    classifications = model.predict(test_images)
    first_prediction, first_label = classifications[0], test_labels[0]
    print(first_prediction)
    print(first_label)
    
    • 自定义回调函数（callback）：训练准确率超过 99% 时自动停止训练。
    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Flatten
    mnist = tf.keras.datasets.mnist

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] into [0, 1], matching the preprocessing
    # of the Fashion-MNIST example; unscaled inputs generally make
    # optimization of the dense layers harder.
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Flatten each 28x28 image, one 512-unit ReLU hidden layer, softmax over
    # the 10 digit classes.
    model = tf.keras.models.Sequential([
        Flatten(input_shape = (28, 28)),
        Dense(512, activation='relu'),
        Dense(10, activation='softmax'),
    ])
    
    class myCallback(tf.keras.callbacks.Callback):
      """Keras callback that stops training once training accuracy exceeds 99%."""

      def on_epoch_end(self, epoch, logs=None):
        """Check the epoch's logged accuracy and request a stop above 0.99.

        Args:
          epoch: index of the epoch that just finished (unused).
          logs: metric values Keras reports for the epoch; may be None.
        """
        # Avoid a shared mutable default argument.
        logs = logs or {}
        # TF 2.x logs metrics=['accuracy'] as 'accuracy'; TF 1.x used 'acc'.
        # Accept either key, and skip the check if neither is present instead
        # of raising TypeError on `None > 0.99`.
        acc = logs.get('accuracy', logs.get('acc'))
        if acc is not None and acc > 0.99:
          print("\nReached 99% accuracy so cancelling training!")
          # Keras checks this flag after each epoch and ends fit() early.
          self.model.stop_training = True
    # Train for up to 10 epochs; myCallback may end training earlier.
    callback = myCallback()
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(x_train, y_train, callbacks=[callback], epochs=10)

    # Score on the held-out test split, then predict class probabilities.
    model.evaluate(x_test, y_test)
    pred = model.predict(x_test)
    

    相关文章

      网友评论

          本文标题:keras回调函数callback

          本文链接:https://www.haomeiwen.com/subject/gvajdctx.html