【网络设计】
采用全连接网络:
3层编码,784->256->128
3层解码,128->256->784
输入:mnist手写图片
输出:由网络还原出来的图片
目标:还原度越高越好
因此我们可以总结出,最简单的Auto-encoder和decoder其实就是特殊结构的全连接神经网络
【代码展示】
#定义数据
mnist = input_data.read_data_sets('./mnist', one_hot=True)
n_input=784
n_hidden_1=256
n_hidden_2=128
#定义批个数和学习速率,这些决定了学习成果
batch_size=100
lr=0.001
training_epoches=200
display_epoches=10
total_batch=mnist.count()/batch_size
#输入,一个batch的图片
tf_x=tf.placeholder(tf.float32,shape=[None,28*28])
examples_to_show=7
#定义网络参数
weights={
'encoder_w1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),
'encoder_w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),
'decoder_w1':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])),
'decoder_w2': tf.Variable(tf.random_normal([n_hidden_1,n_input]))
}
biases={
'encoder_b1':tf.Variable(tf.random_normal([n_hidden_1])),
'encoder_b2':tf.Variable(tf.random_normal([n_hidden_2])),
'decoder_b1':tf.Variable(tf.random_normal([n_hidden_2])),
'decoder_b2': tf.Variable(tf.random_normal([n_hidden_1,n_input]))
}
#定义网络的运算和连接方式
def encoder(x):
layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_w1']),biases['encoder_b1']))
layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_w2']),biases['encoder_b2']))
return layer_2
def decoder(x):
layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_w1']),biases['decoder_b1']))
layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['decoder_w2']),biases['decoder_b2']))
return layer_2
encoder_op=encoder(tf_x)
decoder_op=decoder(encoder_op)
y_pred=decoder_op
y_true=tf_x
#定义学习方式
cost=tf.reduce_mean(tf.pow(y_true-y_pred,2))
optimizer=tf.train.AdamOptimizer(lr).minimize(cost)
init=tf.initialize_all_variables()
#训练
with tf.Session()as sess:
sess.run(init)
total_batch
for i in range(training_epoches):
for j in range(total_batch):
batch_x, batch_y = mnist.train.nextbatch(batch_size)
_,c=sess.run([cost,optimizer],feed_dict={tf_x:batch_x})
if(j%display_epoches==0):
print("Epoch:%04d"%(j+1),"cost=","{:,%.9f}".format(c))
print("Optimize Finished!")
encode_decode=sess.run(y_pred,feed_dict={tf_x:mnist.test.images[:examples_to_show]})
f,a=plt.subplots(2,10,figsize=(10,2))
for i in range(examples_to_show):
a[0][i].imshow(np.reshape(mnist.test.images[i],(28,28)))
a[1][i].imshow(np.reshape(encode_decode[i],(28,28)))
plt.show()
【注意】
1、采用AdamOptimizer,效果最好
2、解码和编码网络架构是对称的
3、learningRate(lr)是个很重要的参数
网友评论