Generating Handwritten Digits with a GAN

Author: 上行彩虹人 | Published 2018-11-18 20:56

    Import the packages and the data:

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    import pickle
    
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('./MNIST_data',one_hot=True)
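
    A quick sanity check of the loaded data: each image arrives flattened to a 784-dimensional float vector with values in [0,1], and the labels are one-hot because of one_hot=True above.

    print(mnist.train.images.shape)   #(55000, 784)
    print(mnist.train.labels.shape)   #(55000, 10)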
    

    Placeholders for the GAN's real-image and noise inputs:

    def get_input(real_size,noise_size):
        real_img = tf.placeholder(tf.float32,[None,real_size])
        noise_img = tf.placeholder(tf.float32,[None,noise_size])
        return real_img,noise_img
    

    Generator

    def get_generator(noise_img,n_units,out_dim,reuse=False,alpha=0.01):
        with tf.variable_scope('generator',reuse=reuse):
            hidden1 = tf.layers.dense(noise_img,n_units)
            #leaky ReLU
            hidden1 = tf.maximum(alpha*hidden1,hidden1)
            #dropout (rate is the fraction of units dropped)
            hidden1 = tf.layers.dropout(hidden1,rate=0.2)
            #output layer: out_dim is the number of pixels in an image
            logits = tf.layers.dense(hidden1,out_dim)
            #tanh squashes the outputs to [-1,1], matching the rescaled real images
            outputs = tf.tanh(logits)
            return logits,outputs
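
    A small aside: tf.maximum(alpha*hidden1,hidden1) is a hand-written Leaky ReLU; on TensorFlow 1.4 and later the built-in op computes the same thing, so that line could equivalently be written as:

    #equivalent hidden-layer activation using the built-in Leaky ReLU op (TF >= 1.4)
    hidden1 = tf.nn.leaky_relu(hidden1,alpha=alpha)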
    

    Discriminator

    def get_discriminator(img,n_units,out_dim=1,reuse=False,alpha=0.01):
        with tf.variable_scope('discriminator',reuse=reuse):
            hidden1 = tf.layers.dense(img,n_units)
            #leaky ReLU
            hidden1 = tf.maximum(alpha*hidden1,hidden1)
            #dropout (rate is the fraction of units dropped)
            hidden1 = tf.layers.dropout(hidden1,rate=0.2)

            #a single logit/sigmoid output: the probability that the input is real
            logits = tf.layers.dense(hidden1,out_dim)
            outputs = tf.sigmoid(logits)

            return logits,outputs
    

    Basic hyperparameter settings

    #size of a real (flattened) image: 784 pixels
    img_size = mnist.train.images[0].shape[0]
    #dimension of the noise vector fed to the generator
    noise_size = 100
    #hidden-layer widths of the generator and discriminator
    g_units = 128
    d_units = 128
    #Leaky ReLU slope (not passed to the calls below, which keep their default of 0.01)
    alpha = 0.001
    learning_rate = 0.001
    

    Define the network

    tf.reset_default_graph()
    
    real_img,noise_img = get_input(img_size,noise_size)
    
    g_logits,g_outputs = get_generator(noise_img,n_units=g_units,out_dim=img_size)
    
    d_logits,d_outputs = get_discriminator(real_img,d_units)
    d_logits_fake,d_outputs_fake = get_discriminator(g_outputs,d_units,reuse=True)
    
    
    #discriminator loss on real images (labels are all ones)
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits,
                                                labels=tf.ones_like(d_logits)))

    #discriminator loss on generated images (labels are all zeros)
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.zeros_like(d_logits_fake)))

    d_loss = tf.add(d_loss_real,d_loss_fake)

    #generator loss: make the discriminator label the fakes as ones
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.ones_like(d_logits_fake)))
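    #in equation form these are the standard non-saturating GAN losses, with D the
    #discriminator's sigmoid output and G the generator:
    #  d_loss = -E_x[log D(x)] - E_z[log(1 - D(G(z)))]
    #  g_loss = -E_z[log D(G(z))]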
    
    train_var = tf.trainable_variables()
    
    #collect each network's trainable variables by variable-scope name
    g_var = [var for var in train_var if var.name.startswith('generator')]
    d_var = [var for var in train_var if var.name.startswith('discriminator')]
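    #optional sanity check: both lists must be non-empty, otherwise the
    #variable_scope names above do not match these startswith filters
    assert len(g_var) > 0 and len(d_var) > 0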
    
    #optimizers: var_list restricts each update to its own network's variables
    d_train_op = tf.train.AdamOptimizer(learning_rate).minimize(d_loss,var_list=d_var)
    g_train_op = tf.train.AdamOptimizer(learning_rate).minimize(g_loss,var_list=g_var)
    
    #training settings
    batch_size = 64
    epochs = 200
    n_samples = 25
    
    samples = []
    losses = []
    
    saver = tf.train.Saver(var_list=g_var)
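
    Note that this Saver is built with var_list=g_var, so the checkpoint holds only the generator's weights (which is all the appendix at the end restores). If the discriminator should be checkpointed too, a saver over all variables can be added; a minimal sketch:

    #hypothetical second saver that checkpoints every variable, not just the generator
    saver_all = tf.train.Saver()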
    

    Training and saving

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for e in range(epochs):
            for batch_i in range(mnist.train.num_examples//batch_size):
                batch = mnist.train.next_batch(batch_size)
                batch_imgs = batch[0].reshape(batch_size,784)
                #rescale pixels from [0,1] to [-1,1] to match the generator's tanh output
                batch_imgs = batch_imgs*2-1
    
                batch_noise = np.random.uniform(-1,1,size=(batch_size,noise_size))
    
                _ = sess.run(d_train_op,feed_dict={real_img:batch_imgs,noise_img:batch_noise})
                _ = sess.run(g_train_op,feed_dict={noise_img:batch_noise})
    
            train_loss_d = sess.run(d_loss,feed_dict={real_img:batch_imgs,noise_img:batch_noise})
            train_loss_d_real = sess.run(d_loss_real,feed_dict={real_img:batch_imgs,noise_img:batch_noise})
            train_loss_d_fake = sess.run(d_loss_fake,feed_dict={real_img:batch_imgs,noise_img:batch_noise})
            train_loss_g = sess.run(g_loss,feed_dict={noise_img:batch_noise})
    
        print('Epoch {0}: g_loss={1:.4f} d_loss={2:.4f}'.format(e,train_loss_g,train_loss_d))
    
            losses.append((train_loss_g,train_loss_d,train_loss_d_fake,train_loss_d_real))
    
        #sample from the generator for this epoch; reuse=True shares the trained weights
        #(building this op once before the loop would avoid growing the graph every epoch)
        sample_noise = np.random.uniform(-1,1,size=(n_samples,noise_size))
        gen_samples = sess.run(get_generator(noise_img,g_units,img_size,reuse=True),feed_dict={noise_img:sample_noise})
        samples.append(gen_samples)
    
            saver.save(sess,'./generator.ckpt')
    
    
    with open('train_samples.pkl','wb') as f:
        pickle.dump(samples,f)
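
    The pickled samples can be reloaded later without re-running training:

    #reload the per-epoch generator samples saved above
    with open('train_samples.pkl','rb') as f:
        samples = pickle.load(f)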
    

    Plot the training losses:

    fig,ax = plt.subplots(figsize=(20,7))
    losses = np.array(losses)
    plt.plot(losses.T[0],label='train_loss_g')
    plt.plot(losses.T[1],label='train_loss_d')
    plt.plot(losses.T[2],label='train_loss_d_fake')
    plt.plot(losses.T[3],label='train_loss_d_real')
    plt.title('GAN training losses')
    plt.legend()
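
    When this is run as a plain script rather than in a notebook, the figure also needs an explicit call to appear:

    plt.show()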
    

    Visualize the generated results:

    def view_samples(epoch,samples):
        fig,axes = plt.subplots(figsize=(7,7),nrows = 5,ncols = 5,sharey=True,sharex=True)
        #samples[epoch] is a (logits,outputs) pair; index 1 selects the tanh outputs
        for ax,img in zip(axes.flatten(),samples[epoch][1]):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            im = ax.imshow(img.reshape((28,28)),cmap='Greys_r')
    
        return fig,axes
    
    _ = view_samples(-1,samples)
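
    Because samples stores the generator's output for every epoch, earlier epochs can be inspected as well to watch the digits take shape:

    #compare an early, a middle and the final epoch
    for e in [0,len(samples)//2,len(samples)-1]:
        _ = view_samples(e,samples)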
    

    Result: with only 200 training epochs, the generated digits are still not very convincing.


    (Figure: a grid of generated digit samples after the 200-epoch run.)

    Appendix: reloading the saved model

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        #restore only the generator weights (the saver was built with var_list=g_var);
        #restore() expects the checkpoint prefix, not the .index file
        saver.restore(sess,'./log_google/generator.ckpt')
        print('Load Over')
        for e in range(epochs):
            print(e)
            sample_noise = np.random.uniform(-1, 1, size=(n_samples, noise_size))
            gen_samples = sess.run(get_generator(noise_img,g_units,img_size,reuse=True),feed_dict={noise_img:sample_noise})
            samples.append(gen_samples)
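
    The freshly generated samples can then be displayed with the view_samples helper defined earlier:

    #show the digits produced by the reloaded generator
    _ = view_samples(-1,samples)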
    
