使用GAN生成手写数字

作者: 上行彩虹人 | 来源:发表于2018-11-18 20:56 被阅读16次

导入包和数据:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pickle

from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set from ./MNIST_data, requesting one-hot labels.
mnist = input_data.read_data_sets('./MNIST_data',one_hot=True)

GAN的real和noise图像的输入:

def get_input(real_size,noise_size):
    """Create the two feed placeholders used by the GAN graph.

    Args:
        real_size: number of columns in one flattened real-image row.
        noise_size: number of columns in one generator-noise row.

    Returns:
        A ``(real_img, noise_img)`` pair of float32 placeholders shaped
        ``[None, real_size]`` and ``[None, noise_size]``.
    """
    shapes = ([None, real_size], [None, noise_size])
    # Built in this order so the real-image placeholder is created first.
    real_img, noise_img = [tf.placeholder(tf.float32, shape) for shape in shapes]
    return real_img, noise_img

Generator

def get_generator(noise_img,u_units,out_dim,reuse=False,alpha=0.01,
                  training=False,drop_rate=0.2):
    """Build the generator: dense -> leaky ReLU -> dropout -> dense -> tanh.

    Args:
        noise_img: noise tensor fed to the first dense layer.
        u_units: number of units in the hidden layer.
        out_dim: number of units in the output layer (one per generated pixel).
        reuse: passed to ``tf.variable_scope('generator')``; set True to share
            the weights of an earlier call.
        alpha: negative-side slope of the leaky ReLU.
        training: forwarded to ``tf.layers.dropout``. Dropout is only applied
            when this is True; the default False keeps the layer an identity,
            which is what the original code (no ``training`` argument) built.
        drop_rate: fraction of hidden units dropped when ``training`` is True.
            ``tf.layers.dropout``'s ``rate`` is the fraction *dropped*, so the
            previous hard-coded 0.8 would have discarded 80% of the units.

    Returns:
        ``(logits, outputs)`` where ``outputs = tanh(logits)``.
    """
    with tf.variable_scope('generator',reuse=reuse):
        hidden1 = tf.layers.dense(noise_img,u_units)
        # Leaky ReLU written out by hand: max(alpha * x, x).
        hidden1 = tf.maximum(alpha*hidden1,hidden1)
        hidden1 = tf.layers.dropout(hidden1,rate=drop_rate,training=training)
        # Output layer: one unit per pixel of the generated image.
        logits = tf.layers.dense(hidden1,out_dim)
        outputs = tf.tanh(logits)
        return logits,outputs

Discriminator

def get_discriminator(img,n_units,out_dim=1,reuse=False,alpha=0.01,
                      training=False,drop_rate=0.2):
    """Build the discriminator: dense -> leaky ReLU -> dropout -> dense -> sigmoid.

    Args:
        img: image tensor to classify (real batch or generator output).
        n_units: number of units in the hidden layer.
        out_dim: number of output units (1 by default).
        reuse: passed to ``tf.variable_scope('discriminator')``; set True to
            share the weights of an earlier call.
        alpha: negative-side slope of the leaky ReLU.
        training: forwarded to ``tf.layers.dropout``. Dropout is only applied
            when this is True; the default False keeps the layer an identity,
            which is what the original code (no ``training`` argument) built.
        drop_rate: fraction of hidden units dropped when ``training`` is True.
            ``tf.layers.dropout``'s ``rate`` is the fraction *dropped*, so the
            previous hard-coded 0.8 would have discarded 80% of the units.

    Returns:
        ``(logits, outputs)`` where ``outputs = sigmoid(logits)``.
    """
    with tf.variable_scope('discriminator',reuse=reuse):
        hidden1 = tf.layers.dense(img,n_units)
        # Leaky ReLU written out by hand: max(alpha * x, x).
        hidden1 = tf.maximum(alpha*hidden1,hidden1)
        hidden1 = tf.layers.dropout(hidden1,rate=drop_rate,training=training)

        logits = tf.layers.dense(hidden1,out_dim)
        outputs = tf.sigmoid(logits)

        return logits,outputs

基本参数配置

# Size of one real image row: first dimension of a single training image.
img_size = mnist.train.images.shape[1]
# Length of the noise vector fed to the generator.
noise_size = 100
# Hidden-layer widths of the generator and the discriminator.
g_units = d_units = 128
# NOTE(review): the graph-building calls in this file do not pass `alpha`,
# so the builders fall back to their own default slope - confirm intent.
alpha = 0.001
learning_rate = 0.001

定义网络

# Start from an empty default graph before building the networks.
tf.reset_default_graph()

real_img, noise_img = get_input(img_size, noise_size)

# Generator maps noise to an image-sized output.
# NOTE(review): the module-level `alpha` is not forwarded to either builder,
# so both use their default slope - confirm that is intended.
g_logits, g_outputs = get_generator(noise_img, u_units=g_units, out_dim=img_size)

# One discriminator, applied twice with shared weights: real batch, then fakes.
d_logits, d_outputs = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)

# Discriminator loss: real images are labelled 1, generated images 0.
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits,
        labels=tf.ones_like(d_logits),
    )
)
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake,
        labels=tf.zeros_like(d_logits_fake),
    )
)
d_loss = tf.add(d_loss_real, d_loss_fake)

# Generator loss: the generator wants its fakes to be labelled 1.
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake,
        labels=tf.ones_like(d_logits_fake),
    )
)

# Split the trainable variables by the variable-scope prefix of their names.
train_var = tf.trainable_variables()
g_var = [v for v in train_var if v.name.startswith('generator')]
d_var = [v for v in train_var if v.name.startswith('discriminator')]

# Each optimizer only updates its own network's variables.
d_optimizer = tf.train.AdamOptimizer(learning_rate)
d_train_op = d_optimizer.minimize(d_loss, var_list=d_var)
g_optimizer = tf.train.AdamOptimizer(learning_rate)
g_train_op = g_optimizer.minimize(g_loss, var_list=g_var)

# Training settings.
batch_size = 64
epochs = 200
n_samples = 25

# Per-epoch records filled in by the training loop.
samples = []
losses = []

# Only the generator's variables are checkpointed.
saver = tf.train.Saver(var_list=g_var)

训练与保存

# Train the GAN. After every epoch: record the losses on the last batch,
# generate `n_samples` images, and checkpoint the generator.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Loop-invariant, so compute it once.
    batches_per_epoch = mnist.train.num_examples//batch_size
    for e in range(epochs):
        for batch_i in range(batches_per_epoch):
            batch = mnist.train.next_batch(batch_size)
            # Use img_size (the width of the real_img placeholder) rather than
            # a hard-coded 784.
            batch_imgs = batch[0].reshape(batch_size,img_size)
            # x*2-1 maps [0, 1] onto [-1, 1], the range of the generator's tanh.
            batch_imgs = batch_imgs*2-1

            batch_noise = np.random.uniform(-1,1,size=(batch_size,noise_size))

            # One discriminator step, then one generator step on the same noise.
            sess.run(d_train_op,feed_dict={real_img:batch_imgs,noise_img:batch_noise})
            sess.run(g_train_op,feed_dict={noise_img:batch_noise})

        # Evaluate all four losses on the last batch in a single run call
        # instead of four separate forward passes.
        train_loss_d,train_loss_d_real,train_loss_d_fake,train_loss_g = sess.run(
            [d_loss,d_loss_real,d_loss_fake,g_loss],
            feed_dict={real_img:batch_imgs,noise_img:batch_noise})

        print('epoch{0}: loss_G{1} loss_d{2}'.format(e,train_loss_g,train_loss_d))

        losses.append((train_loss_g,train_loss_d,train_loss_d_fake,train_loss_d_real))

        sample_noise = np.random.uniform(-1,1,size=(n_samples,noise_size))
        # Fetch the generator tensors that already exist in the graph. Calling
        # get_generator(..., reuse=True) here would add a new copy of the
        # generator ops to the graph on every epoch.
        gen_samples = sess.run((g_logits,g_outputs),feed_dict={noise_img:sample_noise})
        samples.append(gen_samples)

        saver.save(sess,'./generator.ckpt')


# Persist the per-epoch (logits, outputs) samples for later inspection.
with open('train_samples.pkl','wb') as f:
    pickle.dump(samples,f)

显示训练过程:

# Plot the per-epoch losses recorded during training.
fig,ax = plt.subplots(figsize=(20,7))
losses = np.array(losses)
# Columns follow the order of the tuples appended in the training loop:
# (g, d, d_fake, d_real). The d_real curve previously re-plotted column 2.
loss_labels = ('train_loss_g','train_loss_d','train_loss_d_fake','train_loss_d_real')
for column,label in enumerate(loss_labels):
    plt.plot(losses.T[column],label=label)
plt.title('GAN')
plt.legend()

可视化结果:

def view_samples(epoch,samples):
    """Draw the generated images of one epoch on a 5x5 grid.

    Args:
        epoch: index into ``samples`` (negative indices count from the end).
        samples: sequence whose items hold the generated images at position 1;
            each image is reshaped to 28x28 for display.

    Returns:
        The ``(fig, axes)`` pair created by ``plt.subplots``.
    """
    fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(7, 7), sharex=True, sharey=True)
    generated = samples[epoch][1]
    # zip stops at the shorter of the 25 panels and the generated images.
    for panel, pixels in zip(axes.flatten(), generated):
        for axis in (panel.xaxis, panel.yaxis):
            axis.set_visible(False)
        panel.imshow(pixels.reshape((28, 28)), cmap='Greys_r')

    return fig, axes

# Display the samples stored at the last index of `samples`.
_ = view_samples(-1,samples)

结果:只训练了200个epoch,生成效果不是很理想。


image.png

附:保存模型的重新载入

# Reload the saved generator weights and draw fresh samples from them.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Saver.restore takes the checkpoint prefix that was passed to save(),
    # not one of the files written next to it (such as the ".index" file).
    # This is the same path the training loop saves to.
    saver.restore(sess,'./generator.ckpt')
    print('Load Over')
    for e in range(epochs):
        print(e)
        sample_noise = np.random.uniform(-1, 1, size=(n_samples, noise_size))
        # Fetch the existing generator tensors. Calling
        # get_generator(..., reuse=True) here would add a new copy of the
        # generator ops to the graph on every iteration.
        gen_samples = sess.run((g_logits,g_outputs),feed_dict={noise_img:sample_noise})
        samples.append(gen_samples)

相关文章

网友评论

    本文标题:使用GAN生成手写数字

    本文链接:https://www.haomeiwen.com/subject/iffnfqtx.html