Import packages and data:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pickle
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./MNIST_data',one_hot=True)
Placeholders for the GAN's real-image and noise inputs:
def get_input(real_size, noise_size):
    real_img = tf.placeholder(tf.float32, [None, real_size])
    noise_img = tf.placeholder(tf.float32, [None, noise_size])
    return real_img, noise_img
Generator
def get_generator(noise_img, u_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope('generator', reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, u_units)
        # leaky ReLU
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # note: rate is the fraction of units dropped
        hidden1 = tf.layers.dropout(hidden1, rate=0.8)
        # output layer size equals the flattened image pixel count
        logits = tf.layers.dense(hidden1, out_dim)
        # tanh squashes outputs to [-1, 1], matching the rescaled real images
        outputs = tf.tanh(logits)
        return logits, outputs
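The tf.maximum(alpha*h, h) line is a hand-written leaky ReLU. Assuming TensorFlow 1.4 or later, the built-in tf.nn.leaky_relu computes the same activation; a standalone sketch for comparison:
x = tf.constant([-2.0, 0.0, 3.0])
manual = tf.maximum(0.01 * x, x)          # hand-rolled leaky ReLU
builtin = tf.nn.leaky_relu(x, alpha=0.01)  # built-in equivalent (TF >= 1.4)
with tf.Session() as s:
    print(s.run([manual, builtin]))  # both evaluate to [-0.02, 0.0, 3.0]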
Discriminator
def get_discriminator(img, n_units, out_dim=1, reuse=False, alpha=0.01):
    with tf.variable_scope('discriminator', reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)
        # leaky ReLU
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden1 = tf.layers.dropout(hidden1, rate=0.8)
        # single logit: score for whether the input image is real
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.sigmoid(logits)
        return logits, outputs
Basic parameter configuration
# size of a flattened real image (784 for MNIST)
img_size = mnist.train.images[0].shape[0]
# size of the noise vector fed to the generator
noise_size = 100
g_units = 128
d_units = 128
# leaky ReLU slope
alpha = 0.001
learning_rate = 0.001
Define the network
tf.reset_default_graph()
real_img, noise_img = get_input(img_size, noise_size)
g_logits, g_outputs = get_generator(noise_img, u_units=g_units, out_dim=img_size, alpha=alpha)
# the discriminator is built twice: once on real images, once (reuse=True) on generated ones
d_logits, d_outputs = get_discriminator(real_img, d_units, alpha=alpha)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True, alpha=alpha)
# discriminator loss on real images: push D(x) toward 1
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits,
                                            labels=tf.ones_like(d_logits)))
# discriminator loss on generated images: push D(G(z)) toward 0
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                            labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)
# non-saturating generator loss: push D(G(z)) toward 1
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                            labels=tf.ones_like(d_logits_fake)))
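The label choice follows from the sigmoid cross-entropy formula: with label 1 the loss is -log(sigmoid(logit)), with label 0 it is -log(1 - sigmoid(logit)). A quick standalone NumPy check (np is imported above):
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logit = 2.0  # a logit the discriminator might emit for a convincing image
print(-np.log(sigmoid(logit)))        # label 1 ("real"): loss is small, ~0.13
print(-np.log(1.0 - sigmoid(logit)))  # label 0 ("fake"): same logit penalized, ~2.13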
train_var = tf.trainable_variables()
# split variables by scope prefix so each optimizer only updates its own network
g_var = [var for var in train_var if var.name.startswith('generator')]
d_var = [var for var in train_var if var.name.startswith('discriminator')]
# optimizers: var_list restricts each update to one network's weights
d_train_op = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_var)
g_train_op = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_var)
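Because minimize() with var_list is what keeps the two players' updates separate, a quick optional sanity check on the scope prefixes can catch mistakes early:
# variable names should carry their variable_scope prefix,
# e.g. 'generator/dense/kernel:0' and 'discriminator/dense/kernel:0'
for v in g_var + d_var:
    print(v.name)
assert all(v.name.startswith('generator') for v in g_var)
assert all(v.name.startswith('discriminator') for v in d_var)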
# training settings
batch_size = 64
epochs = 200
n_samples = 25
samples = []
losses = []
# checkpoint only the generator's variables
saver = tf.train.Saver(var_list=g_var)
Training and saving
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for batch_i in range(mnist.train.num_examples // batch_size):
            batch = mnist.train.next_batch(batch_size)
            batch_imgs = batch[0].reshape(batch_size, 784)
            # rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output
            batch_imgs = batch_imgs * 2 - 1
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
            # alternate one discriminator step and one generator step
            _ = sess.run(d_train_op, feed_dict={real_img: batch_imgs, noise_img: batch_noise})
            _ = sess.run(g_train_op, feed_dict={noise_img: batch_noise})
        # record losses once per epoch, on the last batch
        train_loss_d = sess.run(d_loss, feed_dict={real_img: batch_imgs, noise_img: batch_noise})
        train_loss_d_real = sess.run(d_loss_real, feed_dict={real_img: batch_imgs, noise_img: batch_noise})
        train_loss_d_fake = sess.run(d_loss_fake, feed_dict={real_img: batch_imgs, noise_img: batch_noise})
        train_loss_g = sess.run(g_loss, feed_dict={noise_img: batch_noise})
        print('epoch {0}: loss_g={1:.4f} loss_d={2:.4f}'.format(e, train_loss_g, train_loss_d))
        losses.append((train_loss_g, train_loss_d, train_loss_d_fake, train_loss_d_real))
        # draw sample images from the (reused) generator at the end of each epoch
        sample_noise = np.random.uniform(-1, 1, size=(n_samples, noise_size))
        gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True, alpha=alpha),
                               feed_dict={noise_img: sample_noise})
        samples.append(gen_samples)
        saver.save(sess, './generator.ckpt')
with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples, f)
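The pickled sample list can be reloaded later without retraining, e.g.:
# reload the per-epoch samples saved above (pickle is imported at the top)
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)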
Plot the training losses:
fig, ax = plt.subplots(figsize=(20, 7))
losses = np.array(losses)
# tuple order in losses: (g, d, d_fake, d_real)
plt.plot(losses.T[0], label='train_loss_g')
plt.plot(losses.T[1], label='train_loss_d')
plt.plot(losses.T[2], label='train_loss_d_fake')
plt.plot(losses.T[3], label='train_loss_d_real')
plt.title('GAN')
plt.legend()
Visualize the generated samples:
def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    # samples[epoch] is a (logits, outputs) pair; index 1 selects the tanh outputs
    for ax, img in zip(axes.flatten(), samples[epoch][1]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
    return fig, axes
_ = view_samples(-1, samples)  # show samples from the last epoch
Result: with only 200 training epochs, the generated digits are still far from ideal.
[Figure: generator samples after 200 epochs]
Appendix: reloading the saved model
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # restore from the checkpoint prefix, not the .index file itself
    saver.restore(sess, './log_google/generator.ckpt')
    print('Load Over')
    for e in range(epochs):
        print(e)
        sample_noise = np.random.uniform(-1, 1, size=(n_samples, noise_size))
        gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True, alpha=alpha),
                               feed_dict={noise_img: sample_noise})
        samples.append(gen_samples)
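The restored draws can then be rendered with the same view_samples helper defined above:
_ = view_samples(-1, samples)  # show the latest batch from the restored generator
plt.show()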