美文网首页
(一)Tensorflow的ckpt转pb文件及Netron工具

(一)Tensorflow的ckpt转pb文件及Netron工具

作者: 神经网络爱好者 | 来源:发表于2019-10-31 15:25 被阅读0次

目录

1、读取ckpt文件
2、获取网络中所有的层的相关信息
3、ckpt转换为pb
4、pb文件的读取与测试
5、Netron: pb文件网络结构可视化及各层权重文件下载

1、读取ckpt文件

  tensorflow训练过程中一般产生的是ckpt文件,分为三个部分。以mobilenet为例,其model文件夹如下:

mobilenet.ckpt.data-00000-of-00001
mobilenet.ckpt.index
mobilenet.ckpt.meta
前两个文件是权重信息,第三个是结构信息;
三个文件可以通过公共名字mobilenet.ckpt表示

import tensorflow as tf
slim = tf.contrib.slim
# 提供了build_model、read_image函数
from predict import * 

# 自动分配物理设备,使得GPU上训练的结果能够在cpu等设备上使用
config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
model_path = './model/mobilenet.ckpt'
image_path = '3508smile.jpg'

sess = tf.Session(config=config)

# 模型定义
def create_model():
   image_input = tf.placeholder(tf.float32, [None, 128, 128, 3], name='inputs')
   logits = build_model(image_input) # 
   logits = tf.nn.softmax(logits)
   return image_input,logits

# 模型预测
def inference(image_path, input, output):
   image = read_image(image_path)
   image = sess.run(image)  # 将tensor转换为ndarray,feed_dict不接受tensor
   result = sess.run(output, feed_dict={input:image})
   print(result)

input, output = create_model()
print(input, output)
# 输出
# Tensor("inputs:0", shape=(?, 128, 128, 3), dtype=float32) 
# Tensor("Softmax:0", shape=(?, 4), dtype=float32)

# 模型初始化与权重加载
init = tf.global_variables_initializer()
new_saver = tf.train.Saver(tf.global_variables())
sess.run(init)
print('Load the pretrained model')
new_saver.restore(sess, model_path)

# 测试单张图片
inference(image_path,input,output)
sess.close()

  这里就完成了ckpt模型的读取工作,其中输出了output的名字,后面转化为pb文件时需要。而且,由于我是在双GPU上训练的模型,部署在CPU上,有些ops无法识别,因此配置了config,使得能够正常运行。

2、获取网络中所有的层的相关信息

# 获取原网络中的所有节点,名称与形状
graph = tf.get_default_graph()
op = graph.get_operations()
for i, m in enumerate(op):
   try:
      print(m.values()[0])  
   except:
      break

输出挑选几个解释如下:

Tensor("inputs:0", shape=(?, 128, 128, 3), dtype=float32)
输入tensor的名称、形状、类型

Tensor("MobilenetV1/Conv2d_0/weights:0", shape=(3, 3, 3, 32), dtype=float32_ref)
形如MobilenetV1/Conv2d_0/。。。,表示卷积核的相关信息

Tensor("MobilenetV1/MobilenetV1/Conv2d_0/Conv2D:0", shape=(?, 64, 64, 32), dtype=float32)
形如MobilenetV1/MobilenetV1/Conv2d_0。。。,表示特征图的相关信息

3、ckpt转换为pb

# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['Softmax'])
with tf.gfile.FastGFile('./model.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())

  谷歌推荐的保存模型的方式是保存模型为 PB 文件,实现创建模型与使用模型的解耦, 使得前向推导 inference的代码统一,模型的大小也会大大减小(38MB -> 10MB)。

4、pb文件的读取与测试

# 读取pb模型并测试
with tf.gfile.FastGFile('model.pb', 'rb') as f:
   graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   sess.graph.as_default()
   tf.import_graph_def(graph_def, name='')  # 导入计算图

init = tf.global_variables_initializer()
sess.run(init)

# 获取模型的input与output
input = sess.graph.get_tensor_by_name('inputs:0')
output = sess.graph.get_tensor_by_name('Softmax:0')
# 测试
inference(image_path,input,output)
sess.close()

5、Netron: pb文件网络结构可视化及各层权重文件下载

  通过网页工具https://lutzroeder.github.io/netron/,可以可视化大多数框架的网络结构。

mobilenet.png
其中在INPUT的filter下可以下载各层的权重为npz格式。 filter.png

相关文章

网友评论

      本文标题:(一)Tensorflow的ckpt转pb文件及Netron工具

      本文链接:https://www.haomeiwen.com/subject/tdizvctx.html