TensorFlow 模型持久化

作者: fc20bac3035f | 来源:发表于2017-12-22 12:04 被阅读48次

title: TensorFlow 模型持久化
date: 2017-09-25 14:00:00
tags:

  • tensorflow
    categories: tensorflow

为了让训练结果可以复用,下面介绍如何将训练得到的网络模型持久化。

代码实现

tf.train.Saver

有关[tf.train.Saver]类的官网文档见这里或者GitHub

简单实现

保存代码

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")

运行代码后输出

'Saved_model/model.ckpt'

观察当前文件夹,新生成了Saved_model文件夹,其中包含四个文件:

  • checkpoint:保存了一个目录下所有的模型文件列表。
  • model.ckpt.data-00000-of-00001:保存了TensorFlow当前参数值。
  • model.ckpt.index:保存了TensorFlow当前参数名。
  • model.ckpt.meta:保存了TensorFlow计算图的结构。

加载代码

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

输出如下:

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[ 3.]

该代码首先定义了Tensorflow计算图中的所有运算结构,然后从本地文件中读入变量的值,不需要初始化变量。

加载持久化的图

若我们不希望代码中再次定义所有的结构,则可以加载已经保存了的图结构。代码如下:

import tensorflow as tf

saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

输出如下

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[ 3.]

上述所有代码,默认保存和加载了TensorFlow计算图中定义的全部变量。

保存指定变量

保存代码

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

saver = tf.train.Saver([v1])

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

上述程序会出错,报错信息如下:

tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value v2

读取时对变量重命名

保存代码如下:

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")

利用字典来重命名变量,key为结构图中的变量name,value为本地变量。加载代码如下:

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
saver = tf.train.Saver({"v1": v1, "v2": v2})

保存和加载滑动平均模型

使用变量重命名方式

保存代码如下:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
    print(variables.name)
    
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
    print(variables.name)

saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的时候会将v:0  v/ExponentialMovingAverage:0这两个变量都存下来。
    saver.save(sess, "Saved_model/model2.ckpt")
    print(sess.run([v, ema.average(v)]))

输出如下

v:0
v:0
v/ExponentialMovingAverage:0
[10.0, 0.099999905]

加载代码,因为滑动平均模型的特性,读取变量v的值,实际是要读取变量v的滑动平均值。

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")

# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

输出如下

INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999

使用variables_to_restore

为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了variables_to_restore(Docs,Github)函数来生成tf.train.Saver类所需要的变量重命名字典。

代码如下:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print(sess.run(v))

输出如下:

{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999

PB文件保存

保存

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
           f.write(output_graph_def.SerializeToString())

graph_def = tf.get_default_graph().as_graph_def():导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程。

graph_util.convert_variables_to_constants:将图中的变量和取值转化为常量。

此时只生成了一个文件combined_model.pb

输出

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.

加载代码

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"
   
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))

输出

[array([ 3.], dtype=float32)]

持久化原理和数据格式

TensorFlow保存的文件为Protocol Buffer形式的。下面首页介绍这种格式的文件。

Protocol Buffer

Protocol Buffer是Google开发的处理结构化数据的工具。类似的还有XML、JSON。

比如需要保存以下的一些结构化信息:

name: 张三
id: 12345
email: zhangsan@abc.com

XML保存:

<user>
    <name>张三</name>
    <id>12345</id>
    <email>zhangsan@abc.com</email>
</user>

JSON保存

{
    "name": "张三",
    "id": "12345",
    "email": "zhangsan@abc.com",
}

Protocol Buffer与这两者的区别:

  • XML和JSON格式的数据,序列化后为可读的字符串,该字符串中包含所有信息。
  • Protocol Buffer序列化后为不可读的二进制流,使用Protocol Buffer需先定义数据的格式(schema),还原数据时也需要相应的格式。
  • Protocol Buffer序列化后的数据比XML或JSON小3到10倍,解析时间快20到100倍。

格式schema文件定义如下:

message user{
    optional string name = 1;
    required int32 id = 2;
    repeated string email = 3;
}

.ckpt.meta —— MetaGraphDef

TensorFlow是一个通过图的形式来表述计算的编程系统,TensorFlow程序中的所有计算都会被表达为计算图上的节点。TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。

类型定义如下,详见Github

message MetaGraphDef{
    MetaInfoDef meta_info_def = 1;
    GraphDef graph_def = 2;
    SaverDef saver_def = 3;
    map<string, CollectionDef> collection_def = 4;
    map<string, SignatureDef> signature_def = 5;
    repeated AssetFileDef asset_file_def = 6;
}

以上信息都保存在了model.ckpt.meta文件中,此为二进制文件,无法直接查看。为了方便调试,TensorFlow提供了export_meta_graph函数,支持以Json格式导出Protocol Buffer。代码如下

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result1 = v1 + v2

saver = tf.train.Saver()
saver.export_meta_graph("Saved_model/model.ckpt.meta.json", as_text=True)

查看Json文件

meta_info_def {
  stripped_op_list {
    op {
      name: "Add"
      input_arg {
        name: "x"
        type_attr: "T"
      }
      input_arg {
        name: "y"
        type_attr: "T"
      }
      output_arg {
        name: "z"
        type_attr: "T"
      }
      attr {
        name: "T"
        type: "type"
        allowed_values {
          list {
            type: DT_HALF
            type: DT_FLOAT
            type: DT_DOUBLE
            type: DT_UINT8
            type: DT_INT8
            type: DT_INT16
            type: DT_INT32
            type: DT_INT64
            type: DT_COMPLEX64
            type: DT_COMPLEX128
            type: DT_STRING
          }
        }
      }
    }
    op {
      name: "Assign"
      input_arg {
        name: "ref"
        type_attr: "T"
        is_ref: true
      }
      input_arg {
        name: "value"
        type_attr: "T"
      }
      output_arg {
        name: "output_ref"
        type_attr: "T"
        is_ref: true
      }
      attr {
        name: "T"
        type: "type"
      }
      attr {
        name: "validate_shape"
        type: "bool"
        default_value {
          b: true
        }
      }
      attr {
        name: "use_locking"
        type: "bool"
        default_value {
          b: true
        }
      }
      allows_uninitialized_input: true
    }
    op {
      name: "Const"
      output_arg {
        name: "output"
        type_attr: "dtype"
      }
      attr {
        name: "value"
        type: "tensor"
      }
      attr {
        name: "dtype"
        type: "type"
      }
    }
    op {
      name: "Identity"
      input_arg {
        name: "input"
        type_attr: "T"
      }
      output_arg {
        name: "output"
        type_attr: "T"
      }
      attr {
        name: "T"
        type: "type"
      }
    }
    op {
      name: "NoOp"
    }
    op {
      name: "RestoreV2"
      input_arg {
        name: "prefix"
        type: DT_STRING
      }
      input_arg {
        name: "tensor_names"
        type: DT_STRING
      }
      input_arg {
        name: "shape_and_slices"
        type: DT_STRING
      }
      output_arg {
        name: "tensors"
        type_list_attr: "dtypes"
      }
      attr {
        name: "dtypes"
        type: "list(type)"
        has_minimum: true
        minimum: 1
      }
      is_stateful: true
    }
    op {
      name: "SaveV2"
      input_arg {
        name: "prefix"
        type: DT_STRING
      }
      input_arg {
        name: "tensor_names"
        type: DT_STRING
      }
      input_arg {
        name: "shape_and_slices"
        type: DT_STRING
      }
      input_arg {
        name: "tensors"
        type_list_attr: "dtypes"
      }
      attr {
        name: "dtypes"
        type: "list(type)"
        has_minimum: true
        minimum: 1
      }
      is_stateful: true
    }
    op {
      name: "VariableV2"
      output_arg {
        name: "ref"
        type_attr: "dtype"
        is_ref: true
      }
      attr {
        name: "shape"
        type: "shape"
      }
      attr {
        name: "dtype"
        type: "type"
      }
      attr {
        name: "container"
        type: "string"
        default_value {
          s: ""
        }
      }
      attr {
        name: "shared_name"
        type: "string"
        default_value {
          s: ""
        }
      }
      is_stateful: true
    }
  }
  tensorflow_version: "1.3.0"
  tensorflow_git_version: "v1.3.0-rc2-20-g0787eee"
}
graph_def {
  node {
    name: "Const"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_FLOAT
          tensor_shape {
            dim {
              size: 1
            }
          }
          float_val: 1.0
        }
      }
    }
  }
  node {
    name: "v1"
    op: "VariableV2"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "container"
      value {
        s: ""
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
          dim {
            size: 1
          }
        }
      }
    }
    attr {
      key: "shared_name"
      value {
        s: ""
      }
    }
  }
  node {
    name: "v1/Assign"
    op: "Assign"
    input: "v1"
    input: "Const"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v1"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "v1/read"
    op: "Identity"
    input: "v1"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v1"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
  }
  node {
    name: "Const_1"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_FLOAT
          tensor_shape {
            dim {
              size: 1
            }
          }
          float_val: 2.0
        }
      }
    }
  }
  node {
    name: "v2"
    op: "VariableV2"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "container"
      value {
        s: ""
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
          dim {
            size: 1
          }
        }
      }
    }
    attr {
      key: "shared_name"
      value {
        s: ""
      }
    }
  }
  node {
    name: "v2/Assign"
    op: "Assign"
    input: "v2"
    input: "Const_1"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v2"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "v2/read"
    op: "Identity"
    input: "v2"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v2"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
  }
  node {
    name: "add"
    op: "Add"
    input: "v1/read"
    input: "v2/read"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
  }
  node {
    name: "save/Const"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
          }
          string_val: "model"
        }
      }
    }
  }
  node {
    name: "save/SaveV2/tensor_names"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 2
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 2
            }
          }
          string_val: "v1"
          string_val: "v2"
        }
      }
    }
  }
  node {
    name: "save/SaveV2/shape_and_slices"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 2
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 2
            }
          }
          string_val: ""
          string_val: ""
        }
      }
    }
  }
  node {
    name: "save/SaveV2"
    op: "SaveV2"
    input: "save/Const"
    input: "save/SaveV2/tensor_names"
    input: "save/SaveV2/shape_and_slices"
    input: "v1"
    input: "v2"
    attr {
      key: "dtypes"
      value {
        list {
          type: DT_FLOAT
          type: DT_FLOAT
        }
      }
    }
  }
  node {
    name: "save/control_dependency"
    op: "Identity"
    input: "save/Const"
    input: "^save/SaveV2"
    attr {
      key: "T"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@save/Const"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
          }
        }
      }
    }
  }
  node {
    name: "save/RestoreV2/tensor_names"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: "v1"
        }
      }
    }
  }
  node {
    name: "save/RestoreV2/shape_and_slices"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: ""
        }
      }
    }
  }
  node {
    name: "save/RestoreV2"
    op: "RestoreV2"
    input: "save/Const"
    input: "save/RestoreV2/tensor_names"
    input: "save/RestoreV2/shape_and_slices"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            unknown_rank: true
          }
        }
      }
    }
    attr {
      key: "dtypes"
      value {
        list {
          type: DT_FLOAT
        }
      }
    }
  }
  node {
    name: "save/Assign"
    op: "Assign"
    input: "v1"
    input: "save/RestoreV2"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v1"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "save/RestoreV2_1/tensor_names"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: "v2"
        }
      }
    }
  }
  node {
    name: "save/RestoreV2_1/shape_and_slices"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: ""
        }
      }
    }
  }
  node {
    name: "save/RestoreV2_1"
    op: "RestoreV2"
    input: "save/Const"
    input: "save/RestoreV2_1/tensor_names"
    input: "save/RestoreV2_1/shape_and_slices"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            unknown_rank: true
          }
        }
      }
    }
    attr {
      key: "dtypes"
      value {
        list {
          type: DT_FLOAT
        }
      }
    }
  }
  node {
    name: "save/Assign_1"
    op: "Assign"
    input: "v2"
    input: "save/RestoreV2_1"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v2"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "save/restore_all"
    op: "NoOp"
    input: "^save/Assign"
    input: "^save/Assign_1"
  }
  versions {
    producer: 24
  }
}
saver_def {
  filename_tensor_name: "save/Const:0"
  save_tensor_name: "save/control_dependency:0"
  restore_op_name: "save/restore_all"
  max_to_keep: 5
  keep_checkpoint_every_n_hours: 10000.0
  version: V2
}
collection_def {
  key: "trainable_variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}
collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
    }
  }
}

meta_info_def属性

保存了Tensorflow计算图中的元数据和程序中所有用到的运算方法的信息。

定义如下:

message MetaInfoDef {
    string meta_graph_version = 1;
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
    string tensorflow_version = 5;
    string tensorflow_git_version = 6;

OpList定义见Github

在OpDef中的attr属性中,必须包含name为T的属性,指定了运算输入输出允许的参数类型。

graph_def

GraphDefNodeDef

主要记录计算图上的节点信息。

saver_def

SaverDef

主要记录持久化模型时需要用到的一些参数,比如保存到文件的文件名、保存操作和加载操作的名称以及保存频率、清理历史纪录等。

collection_def

CollectionDef

维护不同的集合,是一个从集合名称到集合内容的映射。

.ckpt

TensorFlow采用tf.train.NewCheckpointReader来读取ckpt文件中的所有变量信息。

import tensorflow as tf

reader = tf.train.NewCheckpointReader("Saved_model/model.ckpt")

all_variables = reader.get_variable_to_shape_map()

for variable_name in all_variables:
    print(variable_name, all_variables[variable_name])

print("Value for variable v1 is ", reader.get_tensor("v1"))

tf.train.NewCheckpointReader读取ckpt文件中的所有变量。
variable_name为变量名称
all_variables[variable_name]为变量维度

输出如下:

v2 [1]
v1 [1]
Value for variable v1 is  [ 1.]

checkpoint

tf.train.Saver类自动生成且维护,记录所有Tensorflow模型文件的文件名。可读

格式如下:

message CheckpointState{
    string model_checkpoint_path = 1;
    repeated string all_model_checkpoint_paths = 2;
}

实例如下:

model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"

相关文章

  • Tensorflow相关

    使用 Tensorflow 的 DataSet 和 Iterator 读取数据 Tensorflow 模型持久化 ...

  • TensorFlow 模型持久化

    title: TensorFlow 模型持久化date: 2017-09-25 14:00:00tags: ten...

  • Tensorflow 模型持久化

    当我们使用 tensorflow 训练神经网络的时候,模型持久化对于我们的训练有很重要的作用。 如果我们的神经网络...

  • [tensorflow](三)MNIST数字识别问题

    20181125 qzd MNIST数据处理 神经网络模型训练及不同模型对比 变量管理 Tensorflow模型持久化

  • chapter 5

    设计神经网络的5种优化方法 tensorflow模型持久化 [图片上传失败...(image-652906-153...

  • Tensorflow MNIST for Android

    本篇博客主要介绍如何使用 tensorflow 通过 CNN 实现 MNIST 手写数字识别问题,并将模型持久化在...

  • iOS数据持久化

    Title: iOS数据持久化 ##数据持久化概念 数据持久化就是将内存中的数据模型转换为存储模型,以及将存储模型...

  • sikit-learn模型持久化

    sikit-learn模型持久化(导出)model persistence模型持久化。 1)使用pickle工具 ...

  • 分布式-缓存

    缓存 Memcached 不可持久化Redis 可持久化 Memcached Memcached数据访问模型添加新...

  • 模型持久化

    模型持久化应该使用joblib或者pickle 所以上面两者不是一致的意思 还需要joblib可以持久化 可以参照...

网友评论

    本文标题:TensorFlow 模型持久化

    本文链接:https://www.haomeiwen.com/subject/jymtgxtx.html