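Over the past couple of days I have been watching DeepMind's videos, which talk about implementing things by hand. That reminded me of the time I wrote a deep learning framework by hand in C++. Below is a hand-written walkthrough based on TensorFlow and Python that implements a simple Linear Regression model.
Initializing the data
Use numpy to randomly generate points scattered around a straight line, and plot them with matplotlib, as shown below: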
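import numpy as np  # linear algebra
import tensorflow as tf  # TensorFlow 1.x API (placeholders and sessions)
import matplotlib.pyplot  # importing matplotlib alone does not load pyplot

num_samples, w, b = 20, 0.5, 2.
xs = np.asarray(range(num_samples))
# points on the line y = w * x + b, with standard normal noise added
ys = np.asarray([x * w + b + np.random.normal() for x in range(num_samples)])
matplotlib.pyplot.plot(xs, ys, 'ro')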
The result is shown in the figure below:
Initial data
Building the initial linear model
Build a simple w * x + b model, implemented as follows:
class Linear(object):
    def __init__(self):
        # scalar parameters, both initialized to zero
        self.w = tf.get_variable('w', dtype=tf.float32, shape=[], initializer=tf.zeros_initializer())
        self.b = tf.get_variable('b', dtype=tf.float32, shape=[], initializer=tf.zeros_initializer())

    def __call__(self, x):
        return self.w * x + self.b

# input values
xtf = tf.placeholder(tf.float32, [num_samples], 'xs')
# target outputs
ytf = tf.placeholder(tf.float32, [num_samples], 'ys')
# forward
model = Linear()
model_output = model(xtf)
Computing the least-squares solution by hand and updating the parameters
# backward
cov = tf.reduce_sum((xtf - tf.reduce_mean(xtf)) * (ytf - tf.reduce_mean(ytf)))
var = tf.reduce_sum(tf.square(xtf - tf.reduce_mean(xtf)))
w_hat = cov / var
b_hat = tf.reduce_mean(ytf) - w_hat * tf.reduce_mean(xtf)
# update
solve_w = model.w.assign(w_hat)
solve_b = model.b.assign(b_hat)
with tf.train.MonitoredSession() as sess:
    sess.run([solve_w, solve_b], feed_dict={xtf: xs, ytf: ys})
    preds = sess.run(model_output, feed_dict={xtf: xs, ytf: ys})
matplotlib.pyplot.plot(xs, preds)
matplotlib.pyplot.plot(xs, ys, 'ro')
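The result is shown in the figure below:
Result of computing w and b from the covariance
Computing the loss with MSE and updating
Using the MSE function and a learning rate, approach the solution step by step: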
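The covariance formula above can also be checked outside TensorFlow. The following is a minimal NumPy sketch, not part of the original post: the helper name fit_closed_form and the fixed random seed are assumptions made for illustration. It computes the same w_hat and b_hat and compares them with np.polyfit, which fits a degree-1 polynomial by least squares.

```python
import numpy as np

# Assumption: a fixed seed, only so that the sketch is reproducible.
rng = np.random.default_rng(0)
num_samples, w, b = 20, 0.5, 2.0
xs = np.arange(num_samples, dtype=np.float64)
ys = w * xs + b + rng.normal(size=num_samples)


def fit_closed_form(x, y):
    """Hypothetical helper: least squares for one feature, w = cov(x, y) / var(x)."""
    x_mean, y_mean = x.mean(), y.mean()
    cov = np.sum((x - x_mean) * (y - y_mean))
    var = np.sum((x - x_mean) ** 2)
    w_hat = cov / var
    b_hat = y_mean - w_hat * x_mean
    return w_hat, b_hat


w_hat, b_hat = fit_closed_form(xs, ys)
# np.polyfit with deg=1 returns the coefficients as [slope, intercept].
slope, intercept = np.polyfit(xs, ys, deg=1)
print(w_hat, b_hat)
print(slope, intercept)
```

Both lines should print the same pair of numbers up to floating-point error, since both solve the same least-squares problem.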
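Written out, the update that this section performs looks roughly as follows. The symbols n (the number of samples, num_samples in the code) and \eta (the learning rate, lr in the code) are notation introduced here for illustration, and the formula assumes the loss is the plain mean of the squared errors, which is what tf.losses.mean_squared_error computes with its default settings.

```latex
\begin{aligned}
L(w, b) &= \frac{1}{n} \sum_{i=1}^{n} (w x_i + b - y_i)^2 \\
\frac{\partial L}{\partial w} &= \frac{2}{n} \sum_{i=1}^{n} (w x_i + b - y_i)\, x_i \\
\frac{\partial L}{\partial b} &= \frac{2}{n} \sum_{i=1}^{n} (w x_i + b - y_i) \\
w &\leftarrow w - \eta \, \frac{\partial L}{\partial w}, \qquad
b \leftarrow b - \eta \, \frac{\partial L}{\partial b}
\end{aligned}
```

In the code, tf.gradients produces the two partial derivatives automatically, so only the last line is written by hand.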
# backward
loss = tf.losses.mean_squared_error(ytf, model_output)
grads = tf.gradients(loss, [model.w, model.b])
# update
lr = 0.001
update_w = tf.assign(model.w, model.w - lr * grads[0])
update_b = tf.assign(model.b, model.b - lr * grads[1])
update = tf.group(update_w, update_b)
# training
matplotlib.pyplot.plot(xs, ys, 'ro')
feed_dict = {xtf: xs, ytf: ys}
with tf.train.MonitoredSession() as sess:
    for i in range(500):
        sess.run(update, feed_dict=feed_dict)
        if i in [1, 5, 25, 125, 499]:
            preds = sess.run(model_output, feed_dict=feed_dict)
            matplotlib.pyplot.plot(xs, preds, label=str(i))
matplotlib.pyplot.legend()
The result is shown in the figure below:
The fit being approached through iterative updates
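To make the "by hand" part fully explicit, the same training loop can be sketched in plain NumPy with the MSE gradients written out instead of obtained from tf.gradients. This is an illustrative sketch, not the code from the post: the function name mse_gradient_descent is an assumption, and the data here is noise-free so that the exact answer is known.

```python
import numpy as np


def mse_gradient_descent(x, y, lr=0.001, steps=500):
    """Hypothetical helper: fit y ~ w * x + b by gradient descent on the MSE."""
    w, b = 0.0, 0.0  # zero initialization, as in the Linear model
    n = len(x)
    for _ in range(steps):
        err = w * x + b - y                 # residuals of the current fit
        grad_w = 2.0 / n * np.sum(err * x)  # dL/dw for L = mean(err ** 2)
        grad_b = 2.0 / n * np.sum(err)      # dL/db
        # both gradients are computed before either parameter is changed
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


xs = np.arange(20, dtype=np.float64)
ys = 0.5 * xs + 2.0  # assumption: noise-free data, so the true line is known

w_short, b_short = mse_gradient_descent(xs, ys)                  # lr and steps from the post
w_long, b_long = mse_gradient_descent(xs, ys, lr=0.005, steps=20000)
print(w_short, b_short)
print(w_long, b_long)
```

With the learning rate and step count used in the post, the intercept is still noticeably short of 2 after 500 steps, while the longer run gets very close to w = 0.5 and b = 2. Because the x values are not centered, the intercept converges much more slowly than the slope, which is consistent with the lines in the figure moving toward the data gradually.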