美文网首页我爱编程
利用 TensorFlow 和 MNIST 数据集演示 GAN

利用 TensorFlow 和 MNIST 数据集演示 GAN

作者: 拓季 | 来源:发表于2018-03-08 15:16 被阅读0次

    自打关注深度学习这个领域就不时的看到和 Generative Adversarial Network, GAN 相关的东西,也一直非常好奇这个被 LeCun 称为深度学习近年来最大的突破的东西到底是什么样子的。正好在 Udacity 的课堂里遇到了,在完成了通过 GAN 来完成人脸生成的项目后,在这里做一个总结,加深一下对于 GAN 这个网络的理解。为了便于本地试验,这里展示的是利用 MNIST 数据集来训练一个简单的 GAN 来生成手写数字的过程。注意文中代码和示例图片来自 Udacity 深度学习纳米学位课程,版权归 Udacity 所有。

    深度神经网络最令人诟病一点就在于其决策过程的不可解释性,你无从知道网络中的单元提取了哪些特征来完成了一项分类或识别任务。比如在图片识别任务中,即便你可以提取隐藏层的 feature map 来可视化出来相应层的情况,其图像在人类看来是抽象而诡异甚至有些惊悚的。这一点其实在我看来是十分正常的,也不应该像很多媒体的解读方式那样过分的夸大,事实上,人脑的加工过程有谁可以可视化出来呢?只不过我们对于人类行为的可预测性是有把握的,所以不像对于新生技术那样容易催生恐惧。

    而 GAN 最为聪明之处在于既然人类无法理解网络内部的生成过程,索性不用人脑和人类对于图像的理解方式去理解中间过程,而是用另一个类似结构的神经网络,二者的相互理解过程也就是对抗 Adversarial 的过程。其实现的大致思路是:

    • 作为生成器的一个典型代表,GAN 的一个典型应用是通过模型来生成类似已有数据集的图片来实现数据扩增,因此可以首先建立一个通过多层神经网络实现的生成器,其主要作用是通过对于符合一定分布规律的原始数据进行处理,进而得到一个符合另一特定分布情况的结果图像。这里要求这个网络至少包含一个隐藏层,否则网络就不具有足够的学习和泛化能力,这个网络在 GAN 中被称为生成器 Generator。例如在下面的示例图片中,生成器的输入是符合某个分布特征的随机数字:在后续的代码示例中采用的是 (-1, 1) 之间的均匀分布

    • 在获得了生成器之后,还要建立一个类似结构的可以完成图像识别任务的分类器,其特殊之处在于这个网络的输出层只对输入是来自原始数据集还是由生成器网络生成的结果做一个真假判断,这个网络在 GAN 中称为识别器 Discriminator

    High level overview of GAN with MNIST

    在看到代码之前我一直以为 GAN 的实现会比较复杂,但真正看到代码之后就像看到 E = mc2 一样,发现其是如此的简洁,优雅,直观,不得不佩服 Ian Goodfellow 强大的思路。闲话到此为止,网络架构和实现代码如下:

    Network Architecture
    %matplotlib inline
    import pickle as pkl
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    # load data
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('MNIST_data')
    
    # define the model input for both Generator and Discirminator
    def model_inputs(real_dim, z_dim):
        inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real') 
        inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
        
        return inputs_real, inputs_z
    
    # define the Generator
    def generator(z, out_dim, n_units=128, reuse=False, alpha=0.01):
        with tf.variable_scope('generator', reuse=reuse):
            # Hidden layer
            h1 = tf.layers.dense(z, n_units, activation=None)
            # Leaky ReLU
            h1 = tf.maximum(alpha * h1, h1)
            
            # Logits and tanh output
            logits = tf.layers.dense(h1, out_dim, activation=None)
            out = tf.tanh(logits)
            
            return out
    
    # define the Discriminator
    def discriminator(x, n_units=128, reuse=False, alpha=0.01):
        with tf.variable_scope('discriminator', reuse=reuse):
            # Hidden layer
            h1 = tf.layers.dense(x, n_units, activation=None)
            # Leaky ReLU
            h1 = tf.maximum(alpha * h1, h1)
            
            logits = tf.layers.dense(h1, 1, activation=None)
            out = tf.sigmoid(logits)
            
            return out, logits
    

    这里之所以要定义这个 variable_scope 是由于在后续的训练中,需要分别更新生成器和判别器的参数,为了提取参数而特别设置的。另外值得注意的是,激活函数需要采用 Leaky ReLU 来保证梯度可以从判别器传回到生成器。

    # build the network
    tf.reset_default_graph()
    # Create our input placeholders
    input_real, input_z = model_inputs(input_size, z_size)
    
    # Build the model
    g_model = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
    # g_model is the generator output
    
    d_model_real, d_logits_real = discriminator(input_real, n_units=d_hidden_size, alpha=alpha)
    d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, n_units=d_hidden_size, alpha=alpha)
    
    # Calculate losses
    d_loss_real = tf.reduce_mean(
                      tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, 
                                                              labels=tf.ones_like(d_logits_real) * (1 - smooth)))
    d_loss_fake = tf.reduce_mean(
                      tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, 
                                                              labels=tf.zeros_like(d_logits_real)))
    d_loss = d_loss_real + d_loss_fake
    
    g_loss = tf.reduce_mean(
                 tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                         labels=tf.ones_like(d_logits_fake)))
    

    在这里新引入的一个操作是 label smoothing,其目的在于适度的放低要求以促进收敛。而针对损失函数这部分,由于希望判别器将真实数据识别为 1, 而将生成器生成的数据识别为 0,因此需要分别计算这两部分的损失函数。

    # Optimizers
    learning_rate = 0.002
    
    # Get the trainable_variables, split into G and D parts
    t_vars = tf.trainable_variables()
    g_vars = [var for var in t_vars if var.name.startswith('generator')]
    d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
    
    d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
    

    这一段代码非常重要,正式因为选择了间歇性的训练才使得网络的对抗得以实现。

    # Size of input image to discriminator
    input_size = 784
    # Size of latent vector to generator
    z_size = 100
    # Sizes of hidden layers in generator and discriminator
    g_hidden_size = 128
    d_hidden_size = 128
    # Leak factor for leaky ReLU
    alpha = 0.01
    # Smoothing 
    smooth = 0.1
    

    下面代码部分为比较常见的训练代码结构:

    batch_size = 100
    epochs = 100
    samples = []
    losses = []
    # Only save generator variables
    saver = tf.train.Saver(var_list=g_vars)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for e in range(epochs):
            for ii in range(mnist.train.num_examples//batch_size):
                batch = mnist.train.next_batch(batch_size)
                
                # Get images, reshape and rescale to pass to D
                batch_images = batch[0].reshape((batch_size, 784))
                batch_images = batch_images*2 - 1
                
                # Sample random noise for G
                batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
                
                # Run optimizers
                _ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
                _ = sess.run(g_train_opt, feed_dict={input_z: batch_z})
            
            # At the end of each epoch, get the losses and print them out
            train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
            train_loss_g = g_loss.eval({input_z: batch_z})
                
            print("Epoch {}/{}...".format(e+1, epochs),
                  "Discriminator Loss: {:.4f}...".format(train_loss_d),
                  "Generator Loss: {:.4f}".format(train_loss_g))    
            # Save losses to view after training
            losses.append((train_loss_d, train_loss_g))
            
            # Sample from generator as we're training for viewing afterwards
            sample_z = np.random.uniform(-1, 1, size=(16, z_size))
            gen_samples = sess.run(
                           generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                           feed_dict={input_z: sample_z})
            samples.append(gen_samples)
            saver.save(sess, './checkpoints/generator.ckpt')
    
    # Save training generator samples
    with open('train_samples.pkl', 'wb') as f:
        pkl.dump(samples, f)
    

    为了监控训练,可以提取训练过程中的参数来识别训练结果。实际上在学习过程中可以发现 GAN 的训练对于超参数的选择十分敏感,并且在后续的 DCGAN 学习中,作者们甚至通过调整 Adam 中的指数加权平均参数 beta1 来实现较好的训练效果。Ian Goodfellow 在 Andrew Ng 的访谈里也提到自己现在 40% 的时间话在研究如何 Stablize GAN,当时没理解是什么意思,直到自己训练了 DCGAN 之后才知道原来 GAN 的训练对于超参数是如此的敏感。

    def view_samples(epoch, samples):
        fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
        for ax, img in zip(axes.flatten(), samples[epoch]):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        
        return fig, axes
    
    rows, cols = 10, 6
    fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)
    
    for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
        for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
            ax.imshow(img.reshape((28,28)), cmap='Greys_r')
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
    
    Generated result as the training goes

    参考阅读

    1. Tips and tricks to make GANs work

    2. Generative Adversarial Networks for beginners

    相关文章

      网友评论

        本文标题:利用 TensorFlow 和 MNIST 数据集演示 GAN

        本文链接:https://www.haomeiwen.com/subject/axgefftx.html