美文网首页
使用Tensorflow拟合直线

使用Tensorflow拟合直线

作者: 空白少侠 | 来源:发表于2017-04-15 20:52 被阅读399次

预置一些散点 使用Tensorflow拟合并用matplotlib展示出来

#coding=utf-8
import tensorflow as tf
import numpy
import matplotlib.pyplot as plt
if __name__ == '__main__':
    rng = numpy.random

    # 学习速率 迭代次数 50次迭代输出
    learning_rate = 0.01
    training_epochs = 8000
    display_step = 50

    # 训练数据
    train_X = numpy.asarray(
        [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167, 7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1])
    train_Y = numpy.asarray(
        [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, 2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3])
    n_samples = train_X.shape[0]

    # tf Graph Input
    X = tf.placeholder("float")
    Y = tf.placeholder("float")

    # 创建模型

    # 变量权重和偏置值
    W = tf.Variable(rng.randn(), name="weight")
    b = tf.Variable(rng.randn(), name="bias")

    # 构建线性模型
    activation = tf.add(tf.multiply(X, W), b)

    # 最小平方误差
    cost = tf.reduce_sum(tf.pow(activation - Y, 2)) / (2 * n_samples)  # L2 loss
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  # 随机梯度下降

    # 初始化变量
    init = tf.global_variables_initializer()

    # 启动模型
    with tf.Session() as sess:
        sess.run(init)

        # 训练
        for epoch in range(training_epochs):
            for (x, y) in zip(train_X, train_Y):
                sess.run(optimizer, feed_dict={X: x, Y: y})

            # 每display_step次输出查看
            if epoch % display_step == 0:
                print("Epoch:", '%04d' % (epoch + 1), "cost=","{:.9f}".format(sess.run(cost, feed_dict={X: train_X, Y: train_Y})), "W=", sess.run(W), "b=", sess.run(b))

        print("Optimization Finished!")
        print("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}),"W=", sess.run(W), "b=", sess.run(b))

        # 展示
        plt.plot(train_X, train_Y, 'ro', label='Original data')
        plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
        plt.legend()
        plt.show()

相关文章

网友评论

      本文标题:使用Tensorflow拟合直线

      本文链接:https://www.haomeiwen.com/subject/rjloattx.html