美文网首页
如何使用Tensorflow保存或者加载模型(三) -- Ker

如何使用Tensorflow保存或者加载模型(三) -- Ker

作者: Alex_Cen | 来源:发表于2019-08-22 09:44 被阅读0次

    1.背景

    Keras是一个非常易于上手,好用的深度学习框架。不仅容易构建模型,而且容易保存模型。目前,Keras已经被纳入到Tensorflow框架中了,因此站长打算在介绍Tensorflow模型保存的时候,可以一并把Keras的模型训练和保存也介绍了。

    1.1 模型构建

    下面可以以一个简单的demo模型来作为说明,构建模型结构。

    # -*- coding: utf-8 -*-
    # @Time    : 2019-08-03 17:46
    # @Author  : AlexCen
    # @Blog    :http://www.alexcen.com/
    from __future__ import absolute_import, division, print_function, unicode_literals
    from tensorflow import keras
    from tensorflow.keras import layers
    
    
    ##1.构建模型##
    inputs = keras.Input(shape=(784, ), name='digits')
    x = layers.Dense(64, activation='relu', name='d1')(inputs)
    x = layers.Dense(64, activation='relu', name='d2')(x)
    outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
    
    model = keras.Model(inputs=inputs, outputs=outputs, name='mlp')
    model.summary()
    

    1.2 数据加载和模型训练

    加载mnist的默认数据,然后进行训练

    # -*- coding: utf-8 -*-
    # @Time    : 2019-08-03 17:46
    # @Author  : AlexCen
    # @Blog    :http://www.alexcen.com/
    
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train.reshape(60000, 784).astype('float32') / 255
    x_test = x_test.reshape(10000, 784).astype('float32') / 255
    model.compile(loss='sparse_categorical_crossentropy',optimizer=keras.optimizers.RMSprop())
    history = model.fit(x_train, y_train, batch_size=64, epochs=1)
    
    

    1.3 模型保存和输出

    将模型输出并保存。一行代码就完成了。

    # -*- coding: utf-8 -*-
    # @Time    : 2019-08-03 17:46
    # @Author  : AlexCen
    # @Blog    :http://www.alexcen.com/
    
    #模型预测
    prediction = model.predict(x_test)
    
    ##模型保存成h5文件
    model.save('tensorflow_model_save_restore/result/keras_model/test_model.h5')
    

    1.4 模型加载和预测

    模型的加载也是非常简单,一行代码即可。

    ##加载模型
    model_new = keras.models.load_model('tensorflow_model_save_restore/result/keras_model/test_model.h5')
    prediction_new = model_new.predict(x_test)
    

    1.5 结果示例

    array([[2.0833343e-06, 5.7367635e-07, 3.0239375e-04, 8.2307472e-04,
            3.1286825e-07, 1.2593745e-05, 4.4974424e-09, 9.9865258e-01,
            1.1882040e-05, 1.9443125e-04],
           [9.6831842e-05, 1.4330833e-05, 9.9635977e-01, 2.7176763e-03,
            7.9741378e-09, 5.3533819e-04, 2.2054941e-04, 1.6272834e-07,
            5.5394452e-05, 7.3264599e-09],
           [7.0375630e-05, 9.7449875e-01, 8.4965276e-03, 2.2081137e-03,
            4.2712223e-04, 1.0506579e-03, 1.0271738e-03, 5.4973103e-03,
            4.8175761e-03, 1.9064705e-03],
           [9.9902606e-01, 7.6533127e-09, 4.4397468e-05, 7.0306480e-05,
            1.6069802e-06, 7.1709284e-05, 2.7319507e-05, 1.3904357e-05,
            3.0369791e-05, 7.1443513e-04],
           [5.1576592e-04, 8.4059438e-06, 1.3240698e-03, 1.4757921e-04,
            8.7131393e-01, 2.3408027e-04, 5.0422893e-04, 3.5719746e-03,
            6.8508921e-04, 1.2169483e-01]], dtype=float32)
    

    2.总结

    本文介绍如何通过keras框架来进行构建模型框架和模型保存,可见,keras是非常容易上手的工具,对于基本调用的使用者来说提供了友好的交互形式,不过,如果需要对模型结果进行特殊的调整的话,可能就不太方便了,建议有这类需求的同学还是去Tensorflow来构建模型。

    相关文章

      网友评论

          本文标题:如何使用Tensorflow保存或者加载模型(三) -- Ker

          本文链接:https://www.haomeiwen.com/subject/fkmgsctx.html