美文网首页
22-模型的保存与加载

22-模型的保存与加载

作者: jxvl假装 | 来源:发表于2019-10-05 11:17 被阅读0次

    api

    api

    模型的保存是保存为checkpoint文件。summary保存的文件为events文件

    import tensorflow as tf
    def myregression():
        """
        自实现一个线性回归预测
        :return: None
        """
        with tf.variable_scope("variable"):
            #准备数据
            x = tf.random_normal([100, 1], mean=1.75, stddev=0.5, name="x_data")
            y_true = tf.matmul(x, [[0.7]]) + 0.8 #矩阵相乘必须是2维的
        with tf.variable_scope("model"):
            #建立线回归模型
            weight = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0, name="weight"))
            bias = tf.Variable(0.0, name="bias")
            y_predict = tf.matmul(x, weight) + bias
        with tf.variable_scope("loss"):
            #建立损失函数,均方误差
            loss = tf.reduce_mean(tf.square(y_predict-y_true)) #reduce_mean是计算平均值
        with tf.variable_scope("optimizer"):
            #梯度下降优化损失
            train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss) #梯度下降去进行优化,即最小化损失,所以后面加了minimize
    
        #1. 收集变量:一般在会话之前
        tf.summary.scalar("losses", loss)
        tf.summary.histogram("weights", weight)
    
        #2. 合并变量,写入事件文件
        #定义合并变量的op
        merged = tf.summary.merge_all()
    
        #定义一个初始化变量的op
        init_op = tf.global_variables_initializer()
    
        #定义一个保存模型的实例
        saver = tf.train.Saver()
    
        #通过会话运行程序
        with tf.Session() as sess:
            #初始化变量
            sess.run(init_op)
            #打印随机初始化的权重和偏置值
            print("随机初始化的参数权重为:\n", weight.eval(), "\n偏置为:\n", bias.eval())
            #运行优化
            #循环训练优化
            filewriter = tf.summary.FileWriter("./", graph=sess.graph)
            for i in range(1000):
                sess.run(train_op)
                print("优化",i,"次优化过后的参数权重为:", weight.eval(), " 偏置为:", bias.eval())
                #运行合并的tensor
                summary = sess.run(merged)
                #把每次的值写入文件
                filewriter.add_summary(summary, i)
                if i % 100 == 0:
                    saver.save(sess, "./model") #model是文件名
        return None
    import os
    def restoremodel():
        """
        加载模型
        :return:None
        """
        with tf.variable_scope("variable"):
            x = tf.random_normal([100, 1], mean=1.75, stddev=0.5, name="x_data")
            y_true = tf.matmul(x, [[0.7]]) + 0.8  # 矩阵相乘必须是2维的
        with tf.variable_scope("model"):
            weight = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0, name="weight"))
            bias = tf.Variable(0.0, name="bias")
            y_predict = tf.matmul(x, weight) + bias
        with tf.variable_scope("loss"):
            loss = tf.reduce_mean(tf.square(y_predict - y_true))  # reduce_mean是计算平均值
        with tf.variable_scope("optimizer"):
            train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(
                loss)  # 梯度下降去进行优化,即最小化损失,所以后面加了minimize
        # init_op = tf.global_variables_initializer()
        """注意:在恢复模型的时候,就不能再初始化所有变量"""
        saver = tf. train.Saver()
        with tf.Session() as sess:
            if os.path.exists("./checkpoint"):
                #加载模型,覆盖之前的参数
                saver.restore(sess, "./model")  #文件名即可,不需要加后缀。这里的文件名即为model
            # sess.run(init_op)
            for i in range(500):
                # sess.run(train_op)
                print("优化",i,"次优化过后的参数权重为:", weight.eval(), " 偏置为:", bias.eval())
    
        return None
    
    if __name__ == "__main__":
        # myregression()
        restoremodel()
    
    第一次训练后的数据 重新加载后的数据

    相关文章

      网友评论

          本文标题:22-模型的保存与加载

          本文链接:https://www.haomeiwen.com/subject/xyduuctx.html