先上代码
from __future__import print_function
import tensorflowas tf
import numpyas np
import matplotlib.pyplotas plt
#添加神经网络层的方法,如果有激活函数那么就是返回隐藏层,反之返回输出层.
def add_layer(inputs,in_size,out_size,activtion_function=None):
Weights = tf.Variable(tf.random_normal([in_size,out_size])) ## 权重
biases=tf.Variable(tf.zeros([1,out_size])+0.1) ## 偏置
y=tf.matmul(inputs,Weights)+biases
if activtion_functionis None:
outputs = y
else:
outputs=activtion_function(y)
return outputs
#构造神经网络的输入的数据
x_data=np.linspace(-1,1,300)[:,np.newaxis]
noise=np.random.normal(0,0.05,x_data.shape)
y_data=np.square(x_data)-0.5+noise
#为神经网络输入定义 占位符 ,通过placeholder来定义。shape 参数可选,但是定义后tf可以自动捕捉# 数据维度不一致的错误
xs=tf.placeholder(tf.float32,[None,1])
ys=tf.placeholder(tf.float32,[None,1])
#添加隐藏层
l1=add_layer(xs,1,10,activtion_function=tf.nn.relu)
#添加输出层
prediction=add_layer(l1,10,1,activtion_function=None)
#损失函数
loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),reduction_indices=[1]))
#优化策略,最小花误差函数
train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init = tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
fig=plt.figure()
ax=fig.add_subplot(1,1,1)
ax.scatter(x_data,y_data)
plt.ion()
plt.show()
for i in range(1000):
sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
if i %10 ==0:
# to see the step improvement
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value=sess.run(prediction,feed_dict={xs:x_data})
#plot the prediction
lines=ax.plot(x_data,prediction_value,'r-',lw=5)
plt.pause(0.1)
sess.run(tf.global_variables_initializer())
这个方法在我们训练神经网络时都需要加上,官方的解释是初始化模型参数。参考:https://blog.csdn.net/u012436149/article/details/78291545。具体来说就是初始化Variable变量。
train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)
梯度下降算法来最小化误差
添加隐藏层函数:
def add_layer(inputs,in_size,out_size,activtion_function=None):
Weights = tf.Variable(tf.random_normal([in_size,out_size])) ## 权重
biases=tf.Variable(tf.zeros([1,out_size])+0.1) ## 偏置
y=tf.matmul(inputs,Weights)+biases
if activtion_functionis None:
outputs = y
else:
outputs=activtion_function(y)
return outputs
网友评论