Saving and Restoring Models in TensorFlow


Author: 耀話你知 | Published 2018-05-20 15:42

    Saving the model:

    import tensorflow as tf

    Prepare to feed input, i.e. feed_dict and placeholders

    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1 = tf.Variable(2.0, name="bias")
    feed_dict = {w1: 4, w2: 8}

    Define a test operation that we will restore

    w3 = tf.add(w1, w2)
    w4 = tf.multiply(w3, b1, name="op_to_restore")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Create a saver object which will save all the variables
        saver = tf.train.Saver()
        # Run the operation by feeding input
        print(sess.run(w4, feed_dict))
        # Prints 24.0, i.e. (w1 + w2) * b1 = (4 + 8) * 2.0
        # Now, save the graph; global_step records which training step this checkpoint is from
        saver.save(sess, './my_test_model', global_step=1020)
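
    After this runs, the checkpoint is written as several files sharing the my_test_model-1020 prefix. On a typical TF1 install (the default V2 checkpoint format) the directory should contain roughly the listing below; the exact shard suffixes are an assumption about the defaults:

    my_test_model-1020.meta                   # the graph structure (MetaGraphDef)
    my_test_model-1020.index                  # index mapping variable names to data shards
    my_test_model-1020.data-00000-of-00001    # the actual variable values
    checkpoint                                # text file recording the latest checkpoint prefix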

    It must be stressed that the name="w1", name="w2", name="bias", and name="op_to_restore" arguments above must not be omitted: they are the key to restoring the model. Do the other TF ops not need a name, then? Do they default to the Python variable name, e.g. w4? No: an op you don't name gets an auto-generated name based on its type (e.g. Add, Mul, Add_1, ...), not the Python variable name, which is why it is worth naming anything you intend to look up after restoring.
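
    A minimal sketch of the default naming, runnable on its own in TF1 (the printed names are the type-based defaults TF generates; the exact numeric suffixes for repeated ops are an assumption):

    import tensorflow as tf

    a = tf.placeholder("float", name="a")
    b = tf.placeholder("float", name="b")
    unnamed = tf.add(a, b)               # no name= given
    named = tf.add(a, b, name="my_sum")  # explicit name
    print(unnamed.name)  # "Add:0"    -- type-based default, not "unnamed:0"
    print(named.name)    # "my_sum:0" -- explicit name, stable across edits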

    Restoring and using the model

    import tensorflow as tf
    sess = tf.Session()

    First, let's load the meta graph and restore the weights

    saver = tf.train.import_meta_graph('my_test_model-1020.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))

    The argument to import_meta_graph is the full name of the .meta file, i.e. everything up to and including the checkpoint prefix, here my_test_model-1020 (it is a file prefix, not a directory).

    tf.train.latest_checkpoint('./') automatically finds the most recently saved checkpoint prefix in that directory (it reads the "checkpoint" bookkeeping file written by saver.save).
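
    As a quick sanity check (a sketch assuming the save above ran in the current directory):

    import tensorflow as tf

    ckpt = tf.train.latest_checkpoint('./')
    print(ckpt)  # e.g. "./my_test_model-1020"; None if no checkpoint is found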

    Access saved Variables directly

    print(sess.run('bias:0'))  # the ':0' suffix is the output index: a tensor name is op_name:output_index, and most ops have a single output at index 0. Does this mean w1 was restored too? Only as a placeholder in the graph, with no value; see the experiment below.

    print(sess.run('w1:0')) raises an error complaining that the placeholder has no value, because w1 is just a placeholder, not a variable. So how do you recover the trained weights of a network? Variable values are saved in the checkpoint and restored by saver.restore(); the print(sess.run('bias:0')) above is the proof.

    This will print 2.0, which is the value of bias that we saved
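
    To confirm that the restored placeholders are usable as soon as they are fed, you can feed them by tensor name without touching any Python objects (a minimal sketch, reusing the session and checkpoint from above):

    # Feed the restored placeholders directly by their tensor names
    print(sess.run('op_to_restore:0', feed_dict={'w1:0': 4.0, 'w2:0': 8.0}))
    # Prints 24.0 again: both the graph and the bias value were restored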

    Now, let's access the saved placeholders and build a feed_dict to feed new data.

    graph = tf.get_default_graph()  # restore already rebuilt the graph; we take the graph object here so we can call its get_*_by_name methods below
    w1 = graph.get_tensor_by_name("w1:0")  # re-acquire the placeholder: bind the graph's w1 to a local variable (any name works, e.g. ww1) for use in feed_dict
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict = {w1: 13.0, w2: 17.0}

    Now, access the op that you want to run.

    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
    print(sess.run(op_to_restore,feed_dict))

    This will print 60.0, which is (13.0 + 17.0) * 2.0

    w3 = graph.get_tensor_by_name("w3:0") raises: "The name 'w3:0' refers to a Tensor which does not exist. The operation, 'w3', does not exist in the graph." That is because w3 was only a Python variable name; the add op was never given name='w3', so its auto-generated graph name is 'Add', and graph.get_tensor_by_name("Add:0") would retrieve it.

    Summary: get_tensor_by_name is only needed for tensors you interact with locally, i.e. ones you feed or fetch. Tensors like w3 that you never touch directly are restored in the graph anyway and need no get.
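
    When you are not sure what a restored op is called, you can enumerate every op name in the graph (a small sketch using the graph object from above; the listed names are indicative):

    # List all op names in the restored graph to discover what to fetch
    for op in graph.get_operations():
        print(op.name)  # e.g. w1, w2, bias, Add, op_to_restore, ...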

    Two common approaches found online, compared.

    Method 1: redefine the network, then restore values by name

    Saving: define the variables, then save with saver.save()
    import tensorflow as tf

    W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='w')
    b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')
    init = tf.global_variables_initializer()  # initialize_all_variables() is deprecated
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        save_path = saver.save(sess, "save/model.ckpt")

    Loading: define the same variables (same names and shapes), then load with saver.restore()
    import tensorflow as tf

    W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='w')
    b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "save/model.ckpt")

    With this method you must redefine the model structure before loading, and values are matched to variables by name. Often we would rather just read a file and use the model directly, without redefining it, which is what the second method provides.
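
    saver.restore() matches checkpoint entries to variables by their graph names (the name='w', name='b' above), not by Python variable names. If the names differ between saving and loading, tf.train.Saver accepts an explicit mapping; a minimal sketch, where 'w' is the name stored in the checkpoint and W2 is a differently named variable in the new graph:

    # Map the checkpoint entry 'w' onto a variable with a different graph name
    W2 = tf.Variable(tf.truncated_normal(shape=(2, 3)), name='w_new')
    saver = tf.train.Saver({'w': W2})
    with tf.Session() as sess:
        saver.restore(sess, "save/model.ckpt")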

    Method 2: restoring without redefining the network structure

    Define the model

    # in_dim, h1_dim, out_dim are assumed to be defined elsewhere
    input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
    input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

    w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
    b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
    w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
    b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    hidden1 = tf.nn.relu(tf.matmul(input_x, w1) + b1)  # was self.input_x: there is no class here
    hidden1_drop = tf.nn.dropout(hidden1, keep_prob)   # likewise self.keep_prob -> keep_prob

    Define the prediction target

    y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)

    Create the saver

    saver = tf.train.Saver()  # defaults to saving all variables, in this case w1, b1, w2, b2

    Suppose we need to save y so that it can be used at prediction time

    tf.add_to_collection('pred_network', y)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())  # initialize before training
    for step in range(1000000):  # xrange is Python 2; train_op is assumed to be defined elsewhere
        sess.run(train_op)
        if step % 1000 == 0:
            # Save a checkpoint; a meta graph is also exported by default,
            # named 'my-model-{global_step}.meta'
            saver.save(sess, 'my-model', global_step=step)
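
    Note that tf.train.Saver deletes older checkpoints automatically (max_to_keep defaults to 5), so a loop like the one above keeps only the five most recent my-model-* files; raise the limit if you need more history:

    saver = tf.train.Saver(max_to_keep=10)  # keep the 10 most recent checkpoints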

    with tf.Session() as sess:
        # The path must match the prefix used when saving, e.g. 'my-model-<step>' for the loop above
        # new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
        # new_saver.restore(sess, 'my-save-dir/my-model-10000')
        new_saver = tf.train.import_meta_graph('mmodel.ckpt-25.meta')
        new_saver.restore(sess, 'mmodel.ckpt-25')

    tf.get_collection() returns a list; here we only need its first element.

    y = tf.get_collection('pred_network')[0]

    graph = tf.get_default_graph()

    # Since y depends on placeholders, sess.run(y) still needs the actual samples to be predicted, plus the corresponding parameters, to fill those placeholders; these are fetched from the graph via get_operation_by_name.

    input_x = graph.get_operation_by_name('input_x').outputs[0]
    keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]
    

    # Use y for prediction

    sess.run(y, feed_dict={input_x:...., keep_prob:1.0})
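
    For a concrete call, the fed values just need to match the placeholder shapes. A sketch with a dummy batch, where new_samples is a hypothetical numpy array and in_dim == 4 is an assumption:

    import numpy as np

    new_samples = np.zeros((1, 4), dtype=np.float32)  # hypothetical batch of one sample, assuming in_dim == 4
    pred = sess.run(y, feed_dict={input_x: new_samples, keep_prob: 1.0})
    print(pred)  # softmax probabilities, shape (1, out_dim)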
