美文网首页
实现一个简单的神经网络

实现一个简单的神经网络

作者: 我要当大佬 | 来源:发表于2017-11-07 10:33 被阅读0次

    import tensorflow as tf

    import numpy as np

    import matplotlib.pyplot as plt

    x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis]

    noise = np.random.normal(0,0.02,x_data.shape)

    y_data = np.square(x_data)+noise

    x = tf.placeholder(tf.float32,[None,1])

    y = tf.placeholder(tf.float32,[None,1])

    #定义神经网络中间层

    Weight_L1 = tf.Variable(tf.random_normal([1,10]))

    baises_L1 = tf.Variable(tf.zeros([1,10]))

    Wx_plus_b_L1 = tf.matmul(x,Weight_L1)+baises_L1

    L1 = tf.nn.tanh(Wx_plus_b_L1)

    #定义输出层

    Weights_L2 = tf.Variable(tf.random_normal([10,1]))

    baises_L2 = tf.Variable(tf.zeros([1,1]))

    Wx_plus_b_L2 = tf.matmul(L1,Weights_L2)+baises_L2

    prediction = tf.nn.tanh(Wx_plus_b_L2)

    #二次代价函数

    loss = tf.reduce_mean(tf.square(y-prediction))

    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    with tf.Session() as sess:

    init = tf.global_variables_initializer()

    sess.run(init)

    for _ in range(2000):

    sess.run(train_step,feed_dict={x:x_data,y:y_data})

    #获得预测值

    prediction_value = sess.run(prediction,feed_dict={x:x_data})

    plt.figure()

    plt.scatter(x_data,y_data)

    plt.plot(x_data,prediction_value,'r-',lw=5)

    plt.show()

    相关文章

      网友评论

          本文标题:实现一个简单的神经网络

          本文链接:https://www.haomeiwen.com/subject/rrvdmxtx.html