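When training models with TensorFlow, saving the model is essential; otherwise you would have to retrain it every time you want to make predictions. Passing different arguments to the saver gives different saving behaviors.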
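TensorFlow uses tf.train.Saver() to save and restore models. Its two most commonly used parameters are:
- var_list: the variables to save;
- max_to_keep: the maximum number of recent checkpoints to keep. The default is 5 (max_to_keep=5), i.e. only the 5 most recent checkpoints are kept. If you save a checkpoint every epoch and want to keep all of them, set max_to_keep to None or 0.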
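A minimal pure-Python sketch of the rotation that max_to_keep controls, assuming a simplified model in which each save appends a checkpoint and the oldest one is deleted once more than max_to_keep exist (the real Saver also has keep_checkpoint_every_n_hours, which is ignored here). `simulate_max_to_keep` is a hypothetical helper, not a TensorFlow API.

```python
def simulate_max_to_keep(saved_steps, max_to_keep=5):
    """Return the checkpoint steps still on disk after saving `saved_steps` in order.

    Simplified model of tf.train.Saver's rotation: with a positive max_to_keep,
    the oldest checkpoint is deleted once more than max_to_keep exist;
    with None or 0, nothing is deleted.
    """
    kept = []
    for step in saved_steps:
        kept.append(step)
        if max_to_keep and len(kept) > max_to_keep:
            kept.pop(0)  # delete the oldest checkpoint
    return kept


# Save one checkpoint per epoch for 8 epochs.
print(simulate_max_to_keep(range(8)))                    # [3, 4, 5, 6, 7]
print(simulate_max_to_keep(range(8), max_to_keep=None))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

With the default of 5, saving once per epoch for 8 epochs leaves only epochs 3 to 7 on disk; with None (or 0) every checkpoint survives.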
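Usually var_list is not specified, which means all of the model's variables are saved, i.e. var_list=tf.global_variables(). Sometimes, however, we only want to save certain variables, for example just the trainable variables and the batch normalization parameters: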
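import tensorflow as tf

# Collect the batch norm statistics (they are not trainable, so tf.trainable_variables() misses them);
# see https://stackoverflow.com/questions/45800871/tensorflow-save-restore-batch-norm
bn_vars = [var for var in tf.global_variables() if ('moving_variance' in var.name) or ('moving_mean' in var.name)]
# Collect the trainable variables
train_var = tf.trainable_variables()
# Pass the variable list and the number of checkpoints to keep. Since optimizer slots (e.g. .../Momentum)
# and EMA shadow variables are no longer saved, the checkpoint typically shrinks to about 1/3 of its original size
saver = tf.train.Saver(train_var + bn_vars, max_to_keep=1000)
# Finally, decide when to save the model
with tf.Session() as sess:
    ...  # initialize and train; `epochs` is the current epoch counter
    if epochs % 10 == 0:
        # save_dir is a path prefix such as "save/model.ckpt"; global_step is appended as "-<step>"
        saver.save(sess, save_dir, global_step=epochs)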
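To see which names the `train_var + bn_vars` selection above keeps, here is a name-only sketch in plain Python. The variable table is hypothetical: the conv1_1 names are borrowed from the checkpoint dump later in this post, the BatchNorm names are made up for illustration, and real `var.name` values also carry a `:0` suffix that is omitted here.

```python
# Hypothetical variable table: name -> trainable flag.
ALL_VARS = {
    "global_step": False,
    "vgg_16/conv1/conv1_1/biases": True,
    "vgg_16/conv1/conv1_1/biases/ExponentialMovingAverage": False,
    "vgg_16/conv1/conv1_1/biases/Momentum": False,
    "vgg_16/conv1/BatchNorm/beta": True,              # made-up BN name
    "vgg_16/conv1/BatchNorm/moving_mean": False,      # made-up BN name
    "vgg_16/conv1/BatchNorm/moving_variance": False,  # made-up BN name
}


def select_vars(all_vars):
    """Mimic `train_var + bn_vars` from the TensorFlow snippet, using names only."""
    bn_vars = [n for n in all_vars if "moving_variance" in n or "moving_mean" in n]
    train_var = [n for n, trainable in all_vars.items() if trainable]
    return train_var + bn_vars


for name in select_vars(ALL_VARS):
    print(name)
```

Note that global_step is not trainable either, so it is also left out of such a checkpoint; when resuming training from it, global_step would have to be initialized separately.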
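Finally, four files are typically generated in the target directory (per a Zhihu answer):
- checkpoint: the list of saved checkpoint files, used to quickly locate the most recent one;
- vggfcn_mom1to2_22w.ckpt-10000.data-00000-of-00001: the data file, holding the values of all variables, i.e. the network weights;
- vggfcn_mom1to2_22w.ckpt-10000.index: the index for the data file. Its core content is a table of entries keyed by tensor name, with a BundleEntry as each value; a BundleEntry mainly records the tensor's dtype, shape, offset into the data file, checksum and similar metadata. The index file consists of data blocks, an index block, a footer, etc. Building it mainly involves the BundleWriter, TableBuilder and BlockBuilder classes; besides serializing the BundleEntry values, this covers encoding and optimizing the tensor names (e.g. discarding repeated prefixes) and Snappy compression of the data blocks;
- vggfcn_mom1to2_22w.ckpt-10000.meta: a serialized binary MetaGraphDef that stores the network structure, including the graph_def, saver_def, etc.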
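The idea of discarding repeated prefixes in tensor names can be illustrated with LevelDB-style prefix compression: for sorted keys, each entry stores only the length of the prefix it shares with the previous key plus the remaining suffix. This is a simplified sketch of the idea, assuming no restart points and no binary encoding; it does not reproduce the actual byte layout of a TensorFlow .index file.

```python
def prefix_compress(keys):
    """Encode sorted keys as (shared_prefix_len, suffix) pairs relative to the previous key."""
    out, prev = [], ""
    for key in sorted(keys):
        shared = 0
        limit = min(len(prev), len(key))
        while shared < limit and prev[shared] == key[shared]:
            shared += 1
        out.append((shared, key[shared:]))
        prev = key
    return out


def prefix_decompress(entries):
    """Rebuild the full keys from (shared_prefix_len, suffix) pairs."""
    keys, prev = [], ""
    for shared, suffix in entries:
        key = prev[:shared] + suffix
        keys.append(key)
        prev = key
    return keys


names = [
    "vgg_16/conv1/conv1_1/biases",
    "vgg_16/conv1/conv1_1/biases/Momentum",
    "vgg_16/conv1/conv1_1/biases/ExponentialMovingAverage",
]
encoded = prefix_compress(names)
print(encoded)
print(prefix_decompress(encoded) == sorted(names))  # True
```

Because variables in the same scope share long name prefixes, most entries shrink to a short suffix.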
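Restoring the weights requires the .data and .index files; the .meta file is additionally needed when the graph is rebuilt from it with tf.train.import_meta_graph (as in restore_ckpt below) instead of by re-running the model-building code.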
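The file names listed above follow a fixed pattern. A small sketch that derives them from the save prefix, assuming a single data shard (the default for a non-sharded Saver) and Saver.save's default write_meta_graph=True; `checkpoint_files` is a hypothetical helper, not a TensorFlow API.

```python
import os


def checkpoint_files(prefix, global_step=None, num_shards=1):
    """List the files a TF1 Saver.save(sess, prefix, global_step) call is expected to write."""
    # With global_step, the step is appended to the prefix as "-<step>".
    path = prefix if global_step is None else "{}-{}".format(prefix, global_step)
    data = ["{}.data-{:05d}-of-{:05d}".format(path, i, num_shards) for i in range(num_shards)]
    return {
        "data": data,
        "index": path + ".index",
        "meta": path + ".meta",
        # the state file sits in the same directory as the prefix
        "state": os.path.join(os.path.dirname(prefix), "checkpoint"),
    }


print(checkpoint_files("vggfcn_mom1to2_22w.ckpt", 10000))
```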
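Restoring is fairly simple (slim can also be used to restore; it is not covered in detail here):
# The model graph must be built first so the Saver knows which variables to restore
saver = tf.train.Saver()
with tf.Session() as sess:
    # Pass the exact checkpoint prefix (including any "-<global_step>" suffix),
    # or look it up with tf.train.latest_checkpoint("save")
    saver.restore(sess, "save/model.ckpt")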
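To find the prefix for saver.restore, tf.train.latest_checkpoint(checkpoint_dir) reads the `checkpoint` file, a small text-format protobuf with a model_checkpoint_path field and one all_model_checkpoint_paths line per retained checkpoint. The stdlib sketch below parses that field itself, assuming plain quoted paths without escape sequences; `latest_checkpoint_prefix` is a hypothetical stand-in, not the TensorFlow implementation.

```python
import os
import re
import tempfile

# Matches lines such as: model_checkpoint_path: "model.ckpt-10000"
_LINE = re.compile(r'^(\w+):\s*"(.*)"\s*$')


def latest_checkpoint_prefix(checkpoint_dir):
    """Return the prefix recorded as model_checkpoint_path, or None if there is none."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            m = _LINE.match(line.strip())
            if m and m.group(1) == "model_checkpoint_path":
                path = m.group(2)
                # relative paths are interpreted relative to the checkpoint directory
                return path if os.path.isabs(path) else os.path.join(checkpoint_dir, path)
    return None


# Usage: write a sample state file into a temporary directory and read it back.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "model.ckpt-10000"\n'
                'all_model_checkpoint_paths: "model.ckpt-9000"\n'
                'all_model_checkpoint_paths: "model.ckpt-10000"\n')
    latest = latest_checkpoint_prefix(tmp)
    print(latest)  # this prefix is what would be passed to saver.restore(sess, latest)
```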
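Finally, here is how to inspect the parameters stored in a ckpt file:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# A checkpoint prefix, without the .index/.data extensions
checkpoint_path = "vggfcn_mom1to2_model.ckpt-335000"
# With all_tensors=True every tensor is printed and tensor_name is ignored
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='', all_tensors=True)
# Output: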
tensor_name: global_step
335001.0
tensor_name: tower_0/softmax_cross_entropy_loss/value/avg
0.20133862
tensor_name: tower_0/total_loss/avg
0.23763658
tensor_name: vgg_16/conv1/conv1_1/biases
[ 0.02797224 0.07390981 -0.00385985 0.00785956 -0.01232299 -0.01137739
0.0253244 -0.07548576 0.00300218 -0.01330253 0.00169791 -0.0064972
-0.04955265 0.23835294 0.00348639 -0.00088801 -0.00086359 0.04471491
0.02329434 -0.08720596 -0.00614685 0.04078082 0.00396819 0.03650536
-0.00419048 0.02792025 -0.01622533 0.00038355 -0.02241365 0.04451815
0.14487104 0.09795988 0.02575901 0.01830461 -0.06216403 0.14097025
-0.00921478 -0.0047136 0.03897408 0.05741399 0.00417633 0.03451509
-0.01117235 0.0383427 -0.00806538 -0.00271318 -0.03132046 -0.06175949
0.01164324 -0.00922603 0.00940489 -0.08562332 0.0114077 0.00084499
0.0093553 -0.01873754 -0.01944772 0.02312844 -0.08838678 -0.01441999
0.10067106 0.02811741 -0.00628062 0.1013862 ]
tensor_name: vgg_16/conv1/conv1_1/biases/ExponentialMovingAverage
[ 0.02784752 0.07357726 -0.00385987 0.00785797 -0.01232291 -0.01145722
0.02526607 -0.07584356 0.00298003 -0.01330372 0.0016973 -0.00649742
-0.0496439 0.23820773 0.00345324 -0.00088819 -0.00087755 0.04459094
0.02319813 -0.08743148 -0.00614733 0.04070283 0.00393652 0.03630615
-0.00418991 0.02786526 -0.016225 0.00033586 -0.02241365 0.04432626
0.14477569 0.09788124 0.02568062 0.01826548 -0.06237179 0.14037122
-0.00921435 -0.0047203 0.03881586 0.0567292 0.00412849 0.03450596
-0.01117271 0.0383288 -0.00806614 -0.00271349 -0.03151926 -0.06203282
0.01159994 -0.00922479 0.00936983 -0.08602356 0.01136672 0.00076801
0.0092632 -0.01873754 -0.01944782 0.02309002 -0.08878716 -0.01442014
0.10042226 0.02782144 -0.00628062 0.10099234]
tensor_name: vgg_16/conv1/conv1_1/biases/Momentum
[ 1.1582047e-03 -3.2689439e-03 0.0000000e+00 -3.7641104e-05
0.0000000e+00 -2.8069332e-04 -3.4572899e-05 -8.5439772e-04
...
Saving and restoring models in .pb format, and converting .ckpt files to .pb
# coding: utf-8
import tensorflow as tf
from tensorflow.python.framework import graph_util
import os
import numpy as np
import cv2
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # make only GPU 0 visible to TensorFlow
def restore_pb(img_data, pb_file_path=None):
    '''Restore a frozen model from a .pb file and run inference.
    '''
    if not pb_file_path:
        pb_file_path = "test/inceptionV3_0.pb"
    with tf.Graph().as_default():
        # Read the serialized, frozen GraphDef
        graph_def = tf.GraphDef()
        with tf.gfile.FastGFile(pb_file_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # Import the graph; its nodes get the default "import/" name prefix
        return_elems = ['tower_0/InceptionV3/predictions/Softmax:0']
        predictions = tf.import_graph_def(graph_def, return_elements=return_elems)
        with tf.Session() as sess:
            # print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
            # for node in sess.graph.as_graph_def().node:
            #     print(node.name)
            # A frozen graph contains only constants, so no variable initialization is needed
            image = sess.graph.get_tensor_by_name('import/x:0')
            # label = sess.graph.get_tensor_by_name('import/y:0')
            predictions = sess.run(predictions, feed_dict={image: img_data})
            print('pb predictions:', predictions[0])
def restore_ckpt(img_data):
    '''Restore a model from .ckpt (graph from .meta, weights from .data/.index) and run inference.
    '''
    ckpt_meta = 'test/inception_mom1to2_rand5.ckpt-0.meta'
    ckpt = 'test/inception_mom1to2_rand5.ckpt-0'
    graph = tf.Graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    # Entering the session also makes `graph` the default graph
    with tf.Session(graph=graph, config=config) as sess:
        # Rebuild the network structure from the .meta file
        saver = tf.train.import_meta_graph(ckpt_meta)
        # Load the variable values from the .data/.index files
        saver.restore(sess, ckpt)
        # print(tf.contrib.graph_editor.get_tensors(graph))
        image = graph.get_tensor_by_name('x:0')
        preds = graph.get_tensor_by_name('tower_0/InceptionV3/predictions/Softmax:0')
        predictions = sess.run(preds, feed_dict={image: img_data})
        print('ckpt predictions:', predictions)
def ckpt2pb():
    '''Convert a .ckpt checkpoint into a frozen .pb file.
    '''
    # Paths and output node names (node names, without the ":0" tensor suffix)
    output_nodes = ['tower_0/InceptionV3/predictions/Softmax']
    ckpt_meta = 'test/inception_mom1to2_rand5.ckpt-0.meta'
    ckpt = 'test/inception_mom1to2_rand5.ckpt-0'
    pb_file_path = 'test/inceptionV3_ckpt2pb.pb'
    graph = tf.Graph()
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(graph=graph, config=config) as sess:
        # Restore the model from the checkpoint
        saver = tf.train.import_meta_graph(ckpt_meta)
        saver.restore(sess, ckpt)
        # Freeze the graph: variables become constants holding their current values,
        # and only the nodes needed to compute output_nodes are kept
        graph_def = graph.as_graph_def()
        constant_graph = graph_util.convert_variables_to_constants(sess, graph_def, output_nodes)
        # Save the frozen graph into the .pb file
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
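# Preprocess the test image (this must match the preprocessing used during training)
img_data = cv2.imread('tumor_009_19327_175114_0.jpeg')
img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
img_data = img_data.astype(np.float32) / 256.0
img_data = np.reshape(img_data, [-1, 256, 256, 3])  # add a batch dimension
restore_pb(img_data)
restore_ckpt(img_data)
ckpt2pb()
# The converted .pb should give the same predictions as the .ckpt
restore_pb(img_data, pb_file_path='test/inceptionV3_ckpt2pb.pb')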
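graph_util.convert_variables_to_constants, used in ckpt2pb above, conceptually does two things: it keeps only the nodes that the output nodes depend on, and it replaces each variable with a Const node holding the value read from the session. The toy sketch below replays those two steps on a plain dict; the graph format, the node names and the `freeze` helper are all hypothetical and far simpler than a real GraphDef.

```python
# Hypothetical toy graph: node name -> {"op": ..., "inputs": [...]} (not a real GraphDef).
GRAPH = {
    "x": {"op": "Placeholder", "inputs": []},
    "y": {"op": "Placeholder", "inputs": []},
    "w": {"op": "VariableV2", "inputs": []},
    "w/read": {"op": "Identity", "inputs": ["w"]},
    "logits": {"op": "MatMul", "inputs": ["x", "w/read"]},
    "Softmax": {"op": "Softmax", "inputs": ["logits"]},
    "loss": {"op": "Mean", "inputs": ["Softmax", "y"]},    # training-only branch
    "save/restore_w": {"op": "Assign", "inputs": ["w"]},  # saver-only node
}


def freeze(graph, variable_values, output_nodes):
    """Keep only nodes reachable from output_nodes and turn variables into constants."""
    # 1. Walk backwards from the outputs to find every node they depend on.
    needed, stack = set(), list(output_nodes)
    while stack:
        name = stack.pop()
        if name in needed:
            continue
        needed.add(name)
        stack.extend(graph[name]["inputs"])
    # 2. Copy the needed nodes, replacing each variable with a Const holding its value.
    frozen = {}
    for name in needed:
        node = graph[name]
        if node["op"] == "VariableV2":
            frozen[name] = {"op": "Const", "inputs": [], "value": variable_values[name]}
        else:
            frozen[name] = dict(node)
    return frozen


frozen = freeze(GRAPH, {"w": [[0.5]]}, ["Softmax"])
for name in sorted(frozen):
    print(name, frozen[name]["op"])
```

Training-only nodes such as the loss and the saver's restore ops fall away, and since the result holds no variables, a frozen .pb needs no variable initialization when it is loaded.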