tensorflow框架下
使用saver=tf.train.Saver()保存模型会输出以下四种文件
checkpoint 文本文件,记录了模型文件的路径信息列表
.ckpt.meta 保存了模型的计算图结构信息(模型的网络结构)
.ckpt.data-00000-of-00001 网络权重信息
.ckpt.index 保存了模型中的变量参数(权重)信息
模型加载方式
(1)
def restore_model_ckpt(ckpt_file_path):
sess =tf.Session()
saver =tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 指定目录就可以恢复所有变量信息
(2)
saver = tf.train.import_meta_graph(path_to_ckpt_meta)
saver.restore(sess, path_to_ckpt_data)
.pb文件是谷歌推荐的保存模型的方式
将模型参数固化到图文件中,里面保存了图结构+数据,合并了一些基础计算和删除了反向传播相关计算得到的protobuf协议文件,加载模型时只需要这一个文件就好
keras框架下
.h5 保存的模型参数或者模型
.json .yaml 保存的模型结构
.hdf5 保存的模型参数
keras(tensorflow backend)中可以通过如下方式加载模型
(1)
loaded_model = model_from_json(open('model_architecture-1.json').read())
loaded_model.load_weights('saved_models/weights-improvement-19-0.98100.hdf5', by_name=True)
#loaded_model.load_weights('my_model_weights.h5', by_name=True)
(2)
model = load_model('my_model.h5')
网友评论