美文网首页
用tensorflow做线性回归

用tensorflow做线性回归

作者: 圣_狒司机 | 来源:发表于2018-07-27 11:32 被阅读26次

    简述

    用Tf做线性回归有点小题大做了么,其实关键点只是练习优化器的使用,

    loss = tf.reduce_mean((y-y_data)**2)
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train = optimizer.minimize(loss)
    

    tf的这三句话比自己造轮子至少省了一个月时间。

    代码

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    DATA_NUMBER = 100
    
    #1 构造数据集
    x_data = np.linspace(-5,5,DATA_NUMBER)
    y_data = 0.37*X + +2.7+1.6*np.sin(np.random.random((DATA_NUMBER)))*np.cos(np.random.random((DATA_NUMBER)))
    
    #2 构造graph
    w = tf.Variable(np.random.random(),dtype=np.float32)
    b = tf.Variable(np.random.random(),dtype=np.float32)
    y = w*x_data + b
    
    #3 优化器
    loss = tf.reduce_mean((y-y_data)**2)
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train = optimizer.minimize(loss)
    
    #4 会话
    with tf.Session() as sess:
        init_var = tf.global_variables_initializer()
        sess.run(init_var)
        print(w.eval(),b.eval(),loss.eval())
        for i in range(50000):
            sess.run(train)
            w_value = w.eval()
            b_value = b.eval()
            if i%5000 == 0:print(w_value,b_value,loss.eval())
    
    #5 展示
    plt.scatter(x_data,y_data,c='r')
    plt.plot(x_data,w_value*x_data + b_value)
    

    展示

    线性回归

    相关文章

      网友评论

          本文标题:用tensorflow做线性回归

          本文链接:https://www.haomeiwen.com/subject/dyvwmftx.html