美文网首页
TensorFlow(11) 保存与读取模型

TensorFlow(11) 保存与读取模型

作者: 操作系统 | 来源:发表于2017-08-08 20:39 被阅读0次

    保存模型

    import tensorflow as tf
    
    v1 = tf.Variable(tf.random_normal([1,2]), name="v1")
    v2 = tf.Variable(tf.random_normal([2,3]), name="v2")
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        print ("V1:",sess.run(v1))  
        print ("V2:",sess.run(v2))
        saver_path = saver.save(sess, "save/model.ckpt")
        print ("Model saved in file: ", saver_path) 
    

    读取模型

    import tensorflow as tf
    v1 = tf.Variable(tf.random_normal([1,2]), name="v1")
    v2 = tf.Variable(tf.random_normal([2,3]), name="v2")
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        saver.restore(sess, "save/model.ckpt")
        print ("V1:",sess.run(v1))  
        print ("V2:",sess.run(v2))
        print ("Model restored")
    

    相关文章

      网友评论

          本文标题:TensorFlow(11) 保存与读取模型

          本文链接:https://www.haomeiwen.com/subject/wgjxrxtx.html