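If several models are imported into one project at the same time, TensorFlow raises an error. This is most likely a graph conflict: all the models end up in the same default graph. The fix is to create a separate tf.Graph for each model.
The name passed to tf.variable_scope must be identical to the scope name used in the saved model, because the Saver restores variables by name.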
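The sketch below shows the name matching involved. It assumes that plain strings stand in for tf.Variable names, so it runs without TensorFlow. `select_scope` and `remap_scope` are hypothetical helpers, and `model_a` is a made-up checkpoint scope.

A TF1 Saver looks each variable up by name, so the prefixes must line up. If they cannot, tf.train.Saver also accepts a dict that maps checkpoint names to variables, and `remap_scope` builds a dict of that shape. In real code the values would be the tf.Variable objects themselves.

```python
# Sketch only: plain strings stand in for tf.Variable names so this runs without TensorFlow.
def select_scope(var_names, scope):
    """Keep names under `scope`; the trailing '/' stops 'scope_a' from matching 'scope_ab'."""
    prefix = scope.rstrip('/') + '/'
    return [n for n in var_names if n.startswith(prefix)]


def remap_scope(var_names, graph_scope, ckpt_scope):
    """Hypothetical helper: build {name_in_checkpoint: name_in_graph} when the scopes differ."""
    g_prefix = graph_scope.rstrip('/') + '/'
    c_prefix = ckpt_scope.rstrip('/') + '/'
    mapping = {}
    for n in select_scope(var_names, graph_scope):
        op_name = n.rsplit(':', 1)[0]  # drop the ':0' output suffix; checkpoints store op names
        mapping[c_prefix + op_name[len(g_prefix):]] = n
    return mapping


names = ['scope_a/dense/kernel:0', 'scope_a/dense/bias:0',
         'scope_ab/dense/kernel:0', 'scope_b/conv/kernel:0']
print(select_scope(names, 'scope_a'))
print(remap_scope(names, 'scope_a', 'model_a'))
```

The trailing '/' in the prefix is the detail that matters. Without it, a filter for scope_a would also pick up variables under scope_ab.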
```python
# TensorFlow 1.x API; under TF 2.x these calls live in tf.compat.v1
import tensorflow as tf
from a_model import model as model1
from b_model import model as model2

graph1 = tf.Graph()
graph2 = tf.Graph()

# Build model A inside graph1; the scope name must match the one in checkpoint 1
with graph1.as_default():
    with tf.variable_scope('scope_a'):
        m_a = model1()
        ...
    # the trailing '/' keeps 'scope_a' from also matching e.g. 'scope_ab'
    a_vars = [var for var in tf.global_variables() if var.name.startswith('scope_a/')]
    saver1 = tf.train.Saver(a_vars)

# Build model B inside graph2 in the same way
with graph2.as_default():
    with tf.variable_scope('scope_b'):
        m_b = model2()
        ...
    b_vars = [var for var in tf.global_variables() if var.name.startswith('scope_b/')]
    saver2 = tf.train.Saver(b_vars)

# Bind each session to its own graph and keep both open so both models stay usable
model1_path = 'model1/checkpoint.ckpt-000'
sess1 = tf.Session(graph=graph1)
saver1.restore(sess1, model1_path)
reader = tf.train.NewCheckpointReader(model1_path)
print(reader.debug_string().decode("utf-8"))

model2_path = 'model2/checkpoint.ckpt-000'
sess2 = tf.Session(graph=graph2)
saver2.restore(sess2, model2_path)
reader = tf.train.NewCheckpointReader(model2_path)
print(reader.debug_string().decode("utf-8"))
...
```
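tf.train.NewCheckpointReader prints the variable names and parameter information stored in the loaded checkpoint. This is also a quick way to check which scope name to pass to tf.variable_scope.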
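Before restoring, it can help to see which top-level scopes a checkpoint actually contains. The sketch below assumes a dict shaped like the one returned by the reader's get_variable_to_shape_map() method, which maps each variable name to a shape list. `scopes_in_checkpoint` and the sample data are hypothetical.

```python
from collections import defaultdict


def scopes_in_checkpoint(shape_map):
    """Group checkpoint variable names by their top-level scope (hypothetical helper)."""
    groups = defaultdict(list)
    for name in sorted(shape_map):
        top = name.split('/', 1)[0]  # names without '/' form their own group
        groups[top].append(name)
    return dict(groups)


# Hypothetical sample; a real map would come from reader.get_variable_to_shape_map()
sample = {
    'scope_a/dense/kernel': [784, 10],
    'scope_a/dense/bias': [10],
    'global_step': [],
}
for scope, var_names in scopes_in_checkpoint(sample).items():
    print(scope, var_names)
```

If the scope printed here differs from the one used in tf.variable_scope, the restore will fail to find the variables by name.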
References:
https://blog.csdn.net/u010122972/article/details/79093479
https://blog.csdn.net/zcc_0015/article/details/86772122
https://stackoverflow.com/questions/53918715/tensorflow-save-one-of-multiple-sessions