1.背景
Keras是一个非常易于上手,好用的深度学习框架。不仅容易构建模型,而且容易保存模型。目前,Keras已经被纳入到Tensorflow框架中了,因此站长打算在介绍Tensorflow模型保存的时候,可以一并把Keras的模型训练和保存也介绍了。
1.1 模型构建
下面可以以一个简单的demo模型来作为说明,构建模型结构。
# -*- coding: utf-8 -*-
# @Time : 2019-08-03 17:46
# @Author : AlexCen
# @Blog :http://www.alexcen.com/
from __future__ import absolute_import, division, print_function, unicode_literals
from tensorflow import keras
from tensorflow.keras import layers
##1.构建模型##
inputs = keras.Input(shape=(784, ), name='digits')
x = layers.Dense(64, activation='relu', name='d1')(inputs)
x = layers.Dense(64, activation='relu', name='d2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs, name='mlp')
model.summary()
1.2 数据加载和模型训练
加载mnist的默认数据,然后进行训练
# -*- coding: utf-8 -*-
# @Time : 2019-08-03 17:46
# @Author : AlexCen
# @Blog :http://www.alexcen.com/
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
model.compile(loss='sparse_categorical_crossentropy',optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train, batch_size=64, epochs=1)
1.3 模型保存和输出
将模型输出并保存。一行代码就完成了。
# -*- coding: utf-8 -*-
# @Time : 2019-08-03 17:46
# @Author : AlexCen
# @Blog :http://www.alexcen.com/
#模型预测
prediction = model.predict(x_test)
##模型保存成h5文件
model.save('tensorflow_model_save_restore/result/keras_model/test_model.h5')
1.4 模型加载和预测
模型的加载也是非常简单,一行代码即可。
##加载模型
model_new = keras.models.load_model('tensorflow_model_save_restore/result/keras_model/test_model.h5')
prediction_new = model_new.predict(x_test)
1.5 结果示例
array([[2.0833343e-06, 5.7367635e-07, 3.0239375e-04, 8.2307472e-04,
3.1286825e-07, 1.2593745e-05, 4.4974424e-09, 9.9865258e-01,
1.1882040e-05, 1.9443125e-04],
[9.6831842e-05, 1.4330833e-05, 9.9635977e-01, 2.7176763e-03,
7.9741378e-09, 5.3533819e-04, 2.2054941e-04, 1.6272834e-07,
5.5394452e-05, 7.3264599e-09],
[7.0375630e-05, 9.7449875e-01, 8.4965276e-03, 2.2081137e-03,
4.2712223e-04, 1.0506579e-03, 1.0271738e-03, 5.4973103e-03,
4.8175761e-03, 1.9064705e-03],
[9.9902606e-01, 7.6533127e-09, 4.4397468e-05, 7.0306480e-05,
1.6069802e-06, 7.1709284e-05, 2.7319507e-05, 1.3904357e-05,
3.0369791e-05, 7.1443513e-04],
[5.1576592e-04, 8.4059438e-06, 1.3240698e-03, 1.4757921e-04,
8.7131393e-01, 2.3408027e-04, 5.0422893e-04, 3.5719746e-03,
6.8508921e-04, 1.2169483e-01]], dtype=float32)
2.总结
本文介绍如何通过keras
框架来进行构建模型框架和模型保存,可见,keras
是非常容易上手的工具,对于基本调用的使用者来说提供了友好的交互形式,不过,如果需要对模型结果进行特殊的调整的话,可能就不太方便了,建议有这类需求的同学还是去Tensorflow来构建模型。
网友评论