美文网首页我爱编程
使用tensorflow实现简单的线性回归

使用tensorflow实现简单的线性回归

作者: waitfun02 | 来源:发表于2018-03-10 23:55 被阅读0次

import tensorflowas tf

import  numpyas np

import  matplotlib.pyplotas plt

# 简单的y = x * 0.1 + 0.5的分布

num_ponits =1000

venter = []

for iin range(num_ponits):

x = np.random.normal(0.0,0.5)

y = x*0.1  +0.5 + np.random.normal(0.0,0.03)

venter.append([x,y])

x_data = [v[0]for vin venter]

y_data = [v[1]for vin venter]

#plt.scatter(x_data,y_data,c='r')

#plt.show()

#权重要给个随机值,不能为0,这里为1维矩阵,取值在-1~1之间

#w = tf.Variable(tf.random_uniform([1],-1.0,1.0),name="W")

w = tf.Variable(tf.random_uniform([1],-1.0,1.0),name="W")

b = tf.Variable(tf.zeros([1]),name="b")

y = w*x_data + b

#求预测值y和实际值y_data之间的平方的平均值作为损失

loss = tf.reduce_mean(tf.square(y-y_data),name="loss")

#采用梯度下降的方法 来优化参数

optimizer = tf.train.GradientDescentOptimizer(0.5)

#训练的过程就是最小化这个误差值

train = optimizer.minimize(loss,name="loss")

sess = tf.Session()

init = tf.global_variables_initializer()

sess.run(init)

for stepin range(20):

sess.run(train)

print("w=",sess.run(w),"b=",sess.run(b),"loss=",sess.run(loss))

plt.scatter(x_data,y_data,c='r')

plt.plot(x_data,sess.run(w)*x_data+sess.run(b))

plt.show()


对比公式y = x * 0.1 + 0.5,可以看出权重逐渐等于0.1,倾斜值逐渐等于0.5.

相关文章

网友评论

    本文标题:使用tensorflow实现简单的线性回归

    本文链接:https://www.haomeiwen.com/subject/pihifftx.html