最近在做实验的时候遇到,代码需要运用已经训练好的模型,进行一些数据的计算,并给出计算结果,这部分代码的结构大致如下:
# new tensor define
new_tensor = …
# old model defin
old_model = model()
# load checkpoint
var_to_restore = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=…)
restorer_attrib = tf.train.Saver(var_to_restore)
ckpt = tf.train.get_checkpoint_state(…)
if ckpt and ckpt.model_checkpoint_path:
restorer_attrib.restore(sess, ckpt.model_checkpoint_path)
…
# initial value
sess.run(tf.global_variables_initializer())
# run sess get value
sess.run([…])
但是代码在运行的时候发现,同样一组数据,每次运行代码后,获得的从之前训练好的模型中得到的输出,每次运行结果都不同,检查中发现系列问题:
1.使用的TensorFlow中的tensorflow.contrib.slim.nets中的resnet_v2中,在定义resnet_v2中,测试中需要将is_training参数设置为False
2.测试中,要将网络中使用的tf.nn.dropout的keep_prob为1
在设置完这些内容后,发现依旧还是每次获得不同结果,经过仔细检查发现问题来源在于,在load checkpoint的代码完成之后,代码中运行了下列代码
sess.run(tf.global_variables_initializer())
这样导致,载入的checkpoint的数据失效,所有数据重新进行了初始化,导致模型是个随机的取值,从而每次运行获得的结果不同。
解决的办法也很简单,在load checkpoint的操作之前,先运行初始化的代码,然后再导入checkpoint.
正确的代码结构应该如下:
# new tensor define
new_tensor = …
# old model defin
old_model = model()
# initial value
sess.run(tf.global_variables_initializer())
# load checkpoint
var_to_restore = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=…)
restorer_attrib = tf.train.Saver(var_to_restore)
ckpt = tf.train.get_checkpoint_state(…)
if ckpt and ckpt.model_checkpoint_path:
restorer_attrib.restore(sess, ckpt.model_checkpoint_path)
…
# run sess get value
sess.run([…])
网友评论