美文网首页机器学习Deep LearningTensorFlow
tensorflow的基本用法(十)——保存神经网络参数和加载神

tensorflow的基本用法(十)——保存神经网络参数和加载神

作者: SnailTyan | 来源:发表于2017-04-20 19:57 被阅读887次

    文章作者:Tyan
    博客:noahsnail.com  |  CSDN  |  简书

    本文主要是使用tensorfl保存神经网络参数和加载神经网络参数。

    #!/usr/bin/env python
    # _*_ coding: utf-8 _*_
    
    import tensorflow as tf
    import numpy as np
    
    
    # 保存神经网络参数
    def save_para():
        # 定义权重参数
        W = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype = tf.float32, name = 'weights')
        # 定义偏置参数
        b = tf.Variable([[1, 2, 3]], dtype = tf.float32, name = 'biases')
        # 参数初始化
        init = tf.global_variables_initializer()
        # 定义保存参数的saver
        saver = tf.train.Saver()
    
        with tf.Session() as sess:
            sess.run(init)
            # 保存session中的数据
            save_path = saver.save(sess, 'my_net/save_net.ckpt')
            # 输出保存路径
            print 'Save to path: ', save_path
    
    # 恢复神经网络参数
    def restore_para():
        # 定义权重参数
        W = tf.Variable(np.arange(6).reshape((2, 3)), dtype = tf.float32, name = 'weights')
        # 定义偏置参数
        b = tf.Variable(np.arange(3).reshape((1, 3)), dtype = tf.float32, name = 'biases')
        # 定义提取参数的saver
        saver = tf.train.Saver()
    
        with tf.Session() as sess:
            # 加载文件中的参数数据,会根据name加载数据并保存到变量W和b中
            save_path = saver.restore(sess, 'my_net/save_net.ckpt')
            # 输出保存路径
            print 'Weights: ', sess.run(W)
            print 'biases:  ', sess.run(b)
    
    
    # save_para()
    restore_para()
    

    执行结果如下:

    # save
    Save to path:  my_net/save_net.ckpt
    
    
    # restore
    Weights:  [[ 1.  2.  3.]
     [ 4.  5.  6.]]
    biases:   [[ 1.  2.  3.]]
    

    相关文章

      网友评论

        本文标题:tensorflow的基本用法(十)——保存神经网络参数和加载神

        本文链接:https://www.haomeiwen.com/subject/dlzlzttx.html