美文网首页
Tensorflow Error笔记3

Tensorflow Error笔记3

作者: BookThief | 来源:发表于2017-07-11 21:52 被阅读0次

    愿天堂没有Tensorflow! 阿门。

    NotFoundError (see above for traceback): Key local3/weights not found in checkpoint

    这是一个困扰我好久的问题,在我们保存一个训练好的模型,然后找了一些测试数据来调用该模型测试模型的效果时,出现了上述错误,local3/weights可能会随机变化(比如conv1/weights)。下面调用模型的代码是Tensorflow官网上的。

    with tf.Session() as sess:
                     tf.get_variable_scope().reuse_variables()             
                     print("Reading checkpoints...")
                     ckpt = tf.train.get_checkpoint_state(logs_train_dir)
                     if ckpt and ckpt.model_checkpoint_path:
                         global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                         saver.restore(sess, ckpt.model_checkpoint_path)
                         print('Loading success, global_step is %s' % global_step)
                     else:
                         print('No checkpoint file found')
    

    看起来无懈可击,这个错误无从下手。再仔细读一下这个Error,有没有一种checkpoint模型保存的参数名字和实际网络模型参数的名字不一样的感觉?(哈哈,反正我有)。看一下自己的checkpoint和网络参数名字:

    checkpoint 参数名 参数名
    此时我们会产生这样一个大胆的想法(小姐姐,我想...):难道checkpoint里的参数名字和我们网络的参数名字不一样吗??
    可是如何去验证这样一个大胆的想法呢? 如何去看checkpoint里的参数名呢? 如何讨得小姐姐的芳心呢?(哦哦,跑题了QAQ)我们可以使用下面的代码:
    import os
    model_dir = '/home/mml/siamese_net/logs/train/'
    from tensorflow.python import pywrap_tensorflow
    #checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
    checkpoint_path = os.path.join(model_dir, "model.ckpt-9999")
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key)) 
    

    运行完上述代码后,发现水落石出:

    参数名和数值 参数名和数值

    果然,checkpoint参数名和网络的参数名是不一样的,当然会导致无法在checkpoint里找到local5,因为checkpoint里只有siamese/local5,所以只要修改统一参数名,即可顺利消除错误。

    相关文章

      网友评论

          本文标题:Tensorflow Error笔记3

          本文链接:https://www.haomeiwen.com/subject/uexvhxtx.html