Goal
Given a dataset of (x, y) pairs generated by y = 3*x + 2, build a linear regression model h(x) = w*x + b and recover w = 3 and b = 2.
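Every point lies exactly on y = 3*x + 2 because no noise is added. The closed-form ordinary least squares solution should therefore return w = 3 and b = 2 up to floating-point rounding, which gives a reference value for the iterative result. Below is a minimal sketch in plain Python, assuming inputs drawn uniformly from [0, 10) as in the implementation. The helper name `fit_least_squares` and the fixed seed are illustrative choices, not part of the original post:

```python
import random

def fit_least_squares(xs, ys):
    """Closed-form ordinary least squares for h(x) = w*x + b (hypothetical helper)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # w = cov(x, y) / var(x); the 1/n factors cancel, so plain sums suffice
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    w = sxy / sxx
    # the fitted line passes through (mean_x, mean_y)
    b = mean_y - w * mean_x
    return w, b

random.seed(0)  # assumption: fixed seed only so the sketch is reproducible
xs = [10 * random.random() for _ in range(1000)]
ys = [3 * x + 2 for x in xs]
w, b = fit_least_squares(xs, ys)
print(w, b)
```

Gradient descent is still worth studying here because the same loss-and-update recipe carries over to models that have no closed-form solution.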
Implementation
# -*- coding: utf-8 -*-
# Written against the TensorFlow 1.x graph API (tf.placeholder, tf.Session).
# Under TensorFlow 2.x these are available as tf.compat.v1.* with eager execution disabled.
import tensorflow as tf
import numpy as np

# generate training set, y = 3x + 2
training_set_size = 1000
true_w = 3
true_b = 2
# x is drawn uniformly from [0, 10); no seed is set, so every run uses different data
training_set_x = 10 * np.random.random_sample(training_set_size)
training_set_y = true_w * training_set_x + true_b

# step 1: model h(x) = w*x + b, with w and b both initialized to 0
w = tf.Variable(0, dtype=tf.float64, name="w")
b = tf.Variable(0, dtype=tf.float64, name="b")
x = tf.placeholder(tf.float64, name="x")
y = tf.placeholder(tf.float64, name="y")
hypothesis_y = w * x + b

# step 2: loss function, the mean squared error over the training set
squared_deltas = tf.square(hypothesis_y - y)
loss = tf.reduce_mean(squared_deltas)

# step 3: gradient descent; minimize() computes the gradients and applies the updates
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(loss)

# initialization
init = tf.global_variables_initializer()
training_set = {x: training_set_x, y: training_set_y}

# run
with tf.Session() as sess:
    sess.run(init)
    print("starting:", "loss = ", sess.run(loss, training_set))
    # 999 training steps (i = 1..999), logging every 100 steps
    for i in range(1, 1000):
        sess.run(train, training_set)
        if i % 100 == 0:
            print("training:", "W = ", sess.run(w), "b = ", sess.run(b), "loss = ", sess.run(loss, training_set))
    print("result:", "W = ", sess.run(w), "b = ", sess.run(b), "loss = ", sess.run(loss, training_set))
    print("expect:", "W = ", true_w, "b = ", true_b)
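In the TensorFlow graph above, the optimizer derives the gradients of the loss automatically. To show what steps 2 and 3 compute, here is a hand-written equivalent in plain Python without TensorFlow. It assumes the same data distribution, zero initialization, learning rate 0.01 and 999 steps. The function name `gradient_descent` and the fixed seed are hypothetical additions, and the gradients are derived by hand from the mean squared error:

```python
import random

def gradient_descent(xs, ys, learning_rate=0.01, steps=999):
    """Batch gradient descent on L = mean((w*x + b - y)**2) (hypothetical helper)."""
    n = len(xs)
    w, b = 0.0, 0.0  # same zero initialization as the TensorFlow variables
    for _ in range(steps):
        # residuals r_i = h(x_i) - y_i, computed once per step for both gradients
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        # dL/dw = (2/n) * sum(r_i * x_i),  dL/db = (2/n) * sum(r_i)
        grad_w = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / n
        grad_b = 2.0 * sum(residuals) / n
        # both parameters are updated from the same old values, as minimize() does
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    loss = sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n
    return w, b, loss

random.seed(0)  # assumption: fixed seed for reproducibility; the original does not seed
xs = [10 * random.random() for _ in range(1000)]
ys = [3 * x + 2 for x in xs]
w, b, loss = gradient_descent(xs, ys)
print("W =", w, "b =", b, "loss =", loss)
```

The random sample differs from the original run, so the printed numbers will not match the original log exactly. As in that log, w should end very close to 3, while b should end a little further from 2.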
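Output
Sample output from one run (the inputs are random and unseeded, so the exact numbers differ between runs):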
starting: loss = 370.048244221
training: W = 3.13735366368 b = 1.07931468385 loss = 0.216286492577
training: W = 3.08327147643 b = 1.44182904521 loss = 0.0794952961743
training: W = 3.05048382839 b = 1.66160553524 loss = 0.0292182005384
training: W = 3.03060612155 b = 1.7948463409 loss = 0.010739040972
training: W = 3.01855514342 b = 1.87562437266 loss = 0.00394709458057
training: W = 3.01124916618 b = 1.92459653538 loss = 0.00145073993744
training: W = 3.00681987398 b = 1.95428620061 loss = 0.000533214070029
training: W = 3.0041345892 b = 1.97228573693 loss = 0.000195980848902
training: W = 3.00250661932 b = 1.983198063 loss = 7.203203234e-05
result: W = 3.0015272771 b = 1.98976262036 loss = 2.67414271852e-05
expect: W = 3 b = 2
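In the final log line, w is within about 0.002 of 3, while b is still about 0.01 short of 2. The following rough explanation is my own analysis, not part of the original post. For the mean squared error loss, the Hessian with respect to (w, b) is H = 2 [[E[x²], E[x]], [E[x], 1]]. With x uniform on [0, 10), its two eigenvalues differ by more than a factor of 100. Gradient descent shrinks the error along each eigen-direction by a factor of (1 - learning_rate * λ) per step, so the small-eigenvalue direction, which is dominated by b, converges slowly. The sketch below works through the arithmetic, assuming the population moments E[x] = 5 and E[x²] = 100/3 rather than the moments of the actual sample:

```python
import math

# Assumption: x ~ Uniform[0, 10), so E[x] = 5 and E[x^2] = 100/3 (population values).
mean_x = 5.0
mean_x2 = 100.0 / 3.0
learning_rate = 0.01
steps = 999

# Hessian of the MSE loss in (w, b): H = 2 * [[E[x^2], E[x]], [E[x], 1]]
h11, h12, h22 = 2 * mean_x2, 2 * mean_x, 2.0
trace = h11 + h22
det = h11 * h22 - h12 * h12
# eigenvalues of a symmetric 2x2 matrix from its trace and determinant
disc = math.sqrt(trace * trace - 4 * det)
lam_max = (trace + disc) / 2
lam_min = (trace - disc) / 2

# gradient descent on a quadratic is stable only if learning_rate < 2 / lam_max
print("lam_max =", lam_max, "stability limit =", 2 / lam_max)
# fraction of the initial error left along the slow eigen-direction
remaining = (1 - learning_rate * lam_min) ** steps
print("lam_min =", lam_min, "error factor after", steps, "steps =", remaining)
```

These numbers give λ_max ≈ 68.2, so a learning rate of 0.01 is below the stability limit 2/λ_max ≈ 0.029, and λ_min ≈ 0.49. After 999 steps, roughly 0.75% of the initial error remains along the slow direction. That is in line with the small gap still visible in b.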