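An autoencoder is a kind of neural network trained with unsupervised learning. The mapping from the input layer to the hidden layer is called encoding, and the mapping from the hidden layer to the output layer is called decoding.
An autoencoder performs lossy compression: training minimizes a loss function so that the output X' approximates the input X itself.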
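Here is a minimal NumPy sketch of binary cross-entropy, which makes "minimizing the loss" concrete. It is the loss this post later passes to compile(). The helper name binary_crossentropy and the eps clipping constant are illustrative choices, not the Keras source. For each pixel, the loss is smallest when the reconstruction equals the target, and that is what pushes X' toward X. For grey values strictly between 0 and 1, that minimum is above zero.

```python
import numpy as np

def binary_crossentropy(x, x_hat, eps=1e-7):
    """Mean binary cross-entropy between targets x and reconstructions x_hat.

    Both arrays hold values in [0, 1]; x_hat is clipped away from 0 and 1
    so that log() stays finite. (Illustrative helper, not the Keras code.)
    """
    x_hat = np.clip(x_hat, eps, 1.0 - eps)
    per_element = -(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat))
    return per_element.mean()

# A toy "image" with four pixel intensities in [0, 1].
x = np.array([0.0, 0.2, 0.9, 1.0])

good = np.array([0.05, 0.25, 0.85, 0.95])  # reconstruction close to x
bad = np.array([0.9, 0.8, 0.1, 0.2])       # reconstruction far from x

print(binary_crossentropy(x, good))  # smaller loss
print(binary_crossentropy(x, bad))   # larger loss
print(binary_crossentropy(x, x))     # smallest: reconstruction equals target
```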
Using Jupyter Notebook:
import keras
from keras.layers import Dense, Input
from keras.datasets import mnist
from keras.models import Model
import numpy as np
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
print(x_train.shape)
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)
print(x_train.shape)
(60000, 28, 28)
(60000, 784)
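Flattening uses NumPy's default row-major (C) order, so reshape(28, 28) undoes it exactly. The later plotting code depends on this. Here is a quick check on random stand-in arrays (not real MNIST data):

```python
import numpy as np

# Stand-in for two 28x28 MNIST images (random values, not real data).
images = np.random.rand(2, 28, 28).astype("float32")

# Flatten each image into a 784-dim row vector, as done for x_train/x_test.
flat = images.reshape(images.shape[0], -1)
print(flat.shape)  # (2, 784)

# Row-major flattening: pixel (r, c) lands at column r * 28 + c.
r, c = 5, 17
assert flat[0, r * 28 + c] == images[0, r, c]

# Reshaping one row back to 28x28 recovers the original image exactly.
restored = flat[1].reshape(28, 28)
assert np.array_equal(restored, images[1])
```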
Add some noise:
#add random noise
x_train_noise = x_train + 0.3 * np.random.normal(loc=0., scale=1., size=x_train.shape)
x_test_noise = x_test + 0.3 * np.random.normal(loc=0, scale=1, size=x_test.shape)
x_train_noise = np.clip(x_train_noise, 0., 1.)
x_test_noise = np.clip(x_test_noise, 0, 1.)
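The two noise steps above can be wrapped in one reusable function. This is a sketch under my own assumptions: the name add_gaussian_noise and the seed parameter are hypothetical. It also uses NumPy's Generator API so that runs are repeatable. Unlike the original lines, which leave the result as float64 because np.random.normal returns float64, it casts back to the input dtype (float32 here).

```python
import numpy as np

def add_gaussian_noise(x, noise_factor=0.3, seed=None):
    """Hypothetical helper: x + noise_factor * N(0, 1), clipped to [0, 1]."""
    rng = np.random.default_rng(seed)
    noisy = x + noise_factor * rng.normal(loc=0.0, scale=1.0, size=x.shape)
    # Clip back into the valid pixel range and keep the input dtype.
    return np.clip(noisy, 0.0, 1.0).astype(x.dtype)

x = np.full((3, 784), 0.5, dtype="float32")  # dummy data, not MNIST
x_noisy = add_gaussian_noise(x, seed=0)
print(x_noisy.min(), x_noisy.max())

# The same seed gives identical noise, which makes experiments repeatable.
assert np.array_equal(x_noisy, add_gaussian_noise(x, seed=0))
```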
Build the model:
#build autoencoder model
input_img = Input(shape=(28*28,))
encoded = Dense(500, activation='relu')(input_img)
decoded = Dense(784, activation='sigmoid')(encoded)
autoencoder = Model(inputs=[input_img], outputs=[decoded])
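Each Dense layer computes activation(x @ kernel + bias). The following rough sketch does the model's 784 → 500 → 784 forward pass in plain NumPy. The weights are hypothetical random values, not trained ones. The helper names relu, sigmoid, encode and decode are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random weights with the same shapes as the two Dense layers.
W_enc = rng.normal(scale=0.05, size=(784, 500)).astype("float32")
b_enc = np.zeros(500, dtype="float32")
W_dec = rng.normal(scale=0.05, size=(500, 784)).astype("float32")
b_dec = np.zeros(784, dtype="float32")

def relu(z):
    return np.maximum(z, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def encode(x):
    # Like Dense(500, activation='relu'): 784 -> 500
    return relu(x @ W_enc + b_enc)

def decode(h):
    # Like Dense(784, activation='sigmoid'): 500 -> 784, outputs in (0, 1)
    return sigmoid(h @ W_dec + b_dec)

batch = rng.random((4, 784)).astype("float32")  # stand-in for 4 flattened images
codes = encode(batch)
recon = decode(codes)
print(codes.shape, recon.shape)  # (4, 500) (4, 784)
```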
Compile the model with the Adam optimizer and binary cross-entropy loss, then print a summary of its layers:
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
autoencoder.summary()
Layer (type)                 Output Shape              Param #
input_2 (InputLayer)         (None, 784)               0
dense_3 (Dense)              (None, 500)               392500
dense_4 (Dense)              (None, 784)               392784
Total params: 785,284
Trainable params: 785,284
Non-trainable params: 0
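The Param # column follows from the layer shapes. A Dense layer with n_in inputs and n_out units has an n_in × n_out weight matrix plus n_out biases. Here is a small arithmetic check (dense_params is a hypothetical helper):

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a Dense layer: n_in * n_out weights + n_out biases."""
    return n_in * n_out + n_out

encoder_params = dense_params(784, 500)
decoder_params = dense_params(500, 784)
print(encoder_params, decoder_params, encoder_params + decoder_params)
# 392500 392784 785284
```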
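Train with the noisy images as inputs and the clean training images as targets, validating on the 10,000 test images. Note that validation_data=(x_test, x_test) feeds the clean test images as inputs, so val_loss measures plain reconstruction rather than denoising. Passing (x_test_noise, x_test) would validate the denoising task itself: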
autoencoder.fit(x_train_noise, x_train, epochs=20, batch_size=128, verbose=1, validation_data=(x_test, x_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 14s 231us/step - loss: 0.1434 - val_loss: 0.2226
···
···
Epoch 20/20
60000/60000 [==============================] - 11s 191us/step - loss: 0.0791 - val_loss: 0.0730
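The call fit(x_train_noise, x_train, ...) repeatedly scores the model's output on noisy inputs against the clean targets. The sketch below shows the same denoising objective with plain full-batch gradient descent in NumPy, on a tiny hypothetical dataset of 16-pixel "images". It is not how Keras trains this model: the run above uses mini-batches of 128 and the Adam optimizer. It only illustrates the noisy-input / clean-target pairing and the gradients of the ReLU + sigmoid + binary cross-entropy setup.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: 256 "images" of 16 pixels built from 4 random patterns.
patterns = rng.random((4, 16))
x_clean = patterns[rng.integers(0, 4, size=256)]
x_noisy = np.clip(x_clean + 0.3 * rng.normal(size=x_clean.shape), 0.0, 1.0)

# Tiny 16 -> 8 -> 16 autoencoder: ReLU encoder, sigmoid decoder.
W1 = rng.normal(scale=0.1, size=(16, 8))
b1 = np.zeros(8)
W2 = rng.normal(scale=0.1, size=(8, 16))
b2 = np.zeros(16)

def forward(x):
    z1 = x @ W1 + b1
    h = np.maximum(z1, 0.0)
    out = 1.0 / (1.0 + np.exp(-(h @ W2 + b2)))
    return z1, h, out

def bce(target, pred, eps=1e-7):
    pred = np.clip(pred, eps, 1 - eps)
    return -(target * np.log(pred) + (1 - target) * np.log(1 - pred)).mean()

lr = 2.0
losses = []
for step in range(500):
    z1, h, out = forward(x_noisy)        # input: noisy images
    losses.append(bce(x_clean, out))     # target: clean images
    # Gradient of mean BCE w.r.t. the sigmoid's input is (out - target) / N.
    d_z2 = (out - x_clean) / x_clean.size
    d_W2 = h.T @ d_z2
    d_b2 = d_z2.sum(axis=0)
    d_z1 = (d_z2 @ W2.T) * (z1 > 0)      # backprop through ReLU
    d_W1 = x_noisy.T @ d_z1
    d_b1 = d_z1.sum(axis=0)
    W1 -= lr * d_W1
    b1 -= lr * d_b1
    W2 -= lr * d_W2
    b2 -= lr * d_b2

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```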
Plot the results:
%matplotlib inline
import matplotlib.pyplot as plt
#decoded test images
decoded_img = autoencoder.predict(x_test_noise)
# show the first 10 test images
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # row 1: noisy input
    ax = plt.subplot(3, n, i + 1)
    plt.imshow(x_test_noise[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # row 2: reconstruction from the noisy input
    ax = plt.subplot(3, n, i + 1 + n)
    plt.imshow(decoded_img[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # row 3: original clean image
    ax = plt.subplot(3, n, i + 1 + 2 * n)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()