TFGAN: A Simple, Lightweight Library for Generative Adversarial Networks

Author: Daisy丶 | Published: 2018-11-19 15:52

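    TFGAN is a lightweight generative adversarial network (GAN) library open-sourced by Google. It provides the infrastructure needed to train GANs easily, well-tested loss functions and evaluation metrics, and easy-to-use examples that show off TFGAN's expressiveness and flexibility. The library ships as part of TensorFlow contrib and can be used directly through tf. This article demonstrates it with a simple unconditional GAN trained on the MNIST dataset.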

    GitHub: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan

    Environment

    • Python 3.6
    • Tensorflow-gpu 1.8.0

    GAN

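    A generative adversarial network (GAN) is a deep learning model and one of the most promising recent approaches to unsupervised learning on complex distributions. The model produces good output through the adversarial interplay of (at least) two modules in the framework: a generative model G and a discriminative model D. The original GAN theory does not require G and D to be neural networks; they only need to be functions able to fit the corresponding generation and discrimination tasks. In practice, however, deep neural networks are generally used for both. A good GAN application needs a sound training procedure; otherwise the freedom of neural network models can lead to poor output.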


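    During training, one side is held fixed while the other side's network weights are updated, and the two alternate. Each side works to optimize its own network, which creates an adversarial competition, until the two reach a dynamic balance (a Nash equilibrium). At that point the generative model G has recovered the distribution of the training data (its samples are indistinguishable from real data), and the discriminative model can no longer tell them apart: its accuracy is 50%, which amounts to guessing.
    With the generator G fixed, optimizing the discriminator D can be understood as follows: when the input comes from real data, D adjusts its weights so that it outputs 1; when the input comes from generated data, D adjusts its weights so that it outputs 0. With the discriminator D fixed, G optimizes its own network to produce samples as close as possible to the real data, so that D outputs a high probability when it judges the generated samples.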

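    The process above can be expressed as the standard GAN minimax objective:

    \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]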
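
    As a rough numeric illustration, the sketch below estimates the value function V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))] from a handful of made-up discriminator probabilities. The function name `gan_value` and the sample values are assumptions chosen for illustration; the point is only that a discriminator that outputs 0.5 everywhere (the equilibrium described above) yields V = -2 log 2.

```python
import math


def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator probabilities on real samples, each in (0, 1).
    d_fake: discriminator probabilities on generated samples, each in (0, 1).
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# A discriminator that separates real from generated data well.
confident = gan_value([0.9, 0.95, 0.99], [0.1, 0.05, 0.01])

# The equilibrium case: D outputs 0.5 for every input.
equilibrium = gan_value([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

print(confident, equilibrium)
```

    D tries to push this value up, while G tries to pull it down toward the equilibrium value.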

    TFGAN

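    Training in TFGAN typically consists of the following steps:

    1. Specify the input to the networks.
    2. Set up the generator and discriminator with GANModel.
    3. Specify the loss with GANLoss.
    4. Create the training ops with GANTrainOps.
    5. Start training the model.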

    Data

    First, read in the MNIST data as the input x:

    import tensorflow as tf
    import tensorflow.contrib.gan as tfgan
    import tensorflow.contrib.layers as layers
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    def provide_data(batch_size, num_threads=1):
        file = "MNIST"
        # Pixel values are loaded in the range [0, 1].
        mnist = input_data.read_data_sets(file, one_hot=True)
    
        train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
        train_labels = mnist.train.labels
    
        # Rescale the pixels from [0, 255] to roughly [-1, 1).
        train_data = (tf.to_float(train_data) - 128.0) / 128.0
    
        # Creates a QueueRunner for the pre-fetching operation.
        input_queue = tf.train.slice_input_producer([train_data, train_labels], shuffle=True)
        images, labels = tf.train.batch(
                input_queue,
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=5 * batch_size)
    
        return images, labels
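
    To make the rescaling in provide_data concrete, here is a small plain-Python check of the same arithmetic on single pixel values. The helper name `rescale_pixel` is an assumption made for this sketch; it only mirrors the `* 255` and `(x - 128.0) / 128.0` steps above.

```python
def rescale_pixel(p):
    """Map an MNIST pixel in [0, 1] to the range produced by provide_data."""
    return (p * 255.0 - 128.0) / 128.0


low = rescale_pixel(0.0)   # darkest pixel
high = rescale_pixel(1.0)  # brightest pixel

# Undo the scaling the same way the evaluation code converts back to bytes.
restored = int(high * 128.0 + 128.0)

print(low, high, restored)
```

    The darkest pixel lands exactly on -1, while the brightest lands just below 1, so the data sits inside the [-1, 1] range that a tanh output can reach.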
    

    Generate random noise as the input z:

        images, _ = provide_data(batch_size, num_threads=2)
        noise = tf.random_normal([batch_size, 64])
    

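    Defining the model

    First we need to define the generator and the discriminator.

    The generator is defined below. It turns a one-dimensional random noise vector into single-channel image data through transposed convolutions (deconvolutions). tanh is used so that the range of the generated data matches that of the input data: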

    def unconditional_generator(noise, weight_decay=2.5e-5, is_training=True):
        """Simple generator to produce MNIST images.
    
        Args:
            noise: A single Tensor representing noise.
            weight_decay: The value of the l2 weight decay.
            is_training: If `True`, batch norm uses batch statistics. If `False`, batch
                norm uses the exponential moving average collected from population 
                statistics.
    
        Returns:
            A generated image in the range [-1, 1].
        """
        with tf.contrib.framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)):
            with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
                            zero_debias_moving_mean=True):
    
                net = layers.fully_connected(noise, 1024)
                net = layers.fully_connected(net, 7 * 7 * 128)
                net = tf.reshape(net, [-1, 7, 7, 128])
                net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
                net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
                # Make sure that generator output is in the same range as `inputs`
                # ie [-1, 1].
                net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)
    
                return net
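
    The spatial sizes inside the generator can be traced by hand. The sketch below assumes the layers use 'SAME' padding (which, to my knowledge, is the default for tf.contrib.layers convolutions), so a stride-2 transposed convolution doubles height and width and a stride-1 convolution keeps them. The helper names are made up for this illustration.

```python
def conv2d_transpose_same(size, stride):
    """Spatial size after a transposed convolution, assuming 'SAME' padding."""
    return size * stride


def generator_shapes():
    """Return (height, width, channels) after each stage of the generator."""
    shapes = []
    h = w = 7
    shapes.append((h, w, 128))  # dense output of 7 * 7 * 128 reshaped to a feature map
    h, w = conv2d_transpose_same(h, 2), conv2d_transpose_same(w, 2)
    shapes.append((h, w, 64))   # first stride-2 transposed convolution
    h, w = conv2d_transpose_same(h, 2), conv2d_transpose_same(w, 2)
    shapes.append((h, w, 32))   # second stride-2 transposed convolution
    shapes.append((h, w, 1))    # final stride-1 convolution keeps the size
    return shapes


print(generator_shapes())
```

    Under that assumption the final feature map is 28 x 28 x 1, the same shape as an MNIST image.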
    

    The discriminator is defined below. It is a fairly simple binary classification network that decides whether its input is generated or real:

    def unconditional_discriminator(img, unused_conditioning, weight_decay=2.5e-5,
                         is_training=True):
        """Discriminator network on MNIST digits.
    
        Args:
            img: Real or generated MNIST digits. Should be in the range [-1, 1].
            unused_conditioning: The TFGAN API can help with conditional GANs, which
                would require extra `condition` information to both the generator and the
                discriminator. Since this example is not conditional, we do not use this
                argument.
            weight_decay: The L2 weight decay.
            is_training: If `True`, batch norm uses batch statistics. If `False`, batch
                norm uses the exponential moving average collected from population 
                statistics.
    
        Returns:
            Logits for the probability that the image is real.
        """
        with tf.contrib.framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=tf.nn.relu, normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
    
            net = layers.conv2d(img, 64, [4, 4], stride=2)
            net = layers.conv2d(net, 128, [4, 4], stride=2)
            net = layers.flatten(net)
    
            with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
                net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)
    
            return layers.linear(net, 1)
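
    The discriminator runs the same arithmetic in reverse. Assuming 'SAME' padding again, each stride-2 convolution halves the spatial size (rounding up), which gives the number of features that reach the 1024-unit fully connected layer. The helper name `conv2d_same` is hypothetical.

```python
import math


def conv2d_same(size, stride):
    """Spatial size after a convolution, assuming 'SAME' padding."""
    return math.ceil(size / stride)


size = 28                    # MNIST images are 28 x 28
size = conv2d_same(size, 2)  # after the 64-filter convolution
size = conv2d_same(size, 2)  # after the 128-filter convolution

# layers.flatten collapses height, width and channels into one axis.
flat_features = size * size * 128

print(size, flat_features)
```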
    

    Then define a GAN model with TFGAN:

        with tf.name_scope('model'):
            # Build the generator and discriminator.
            gan_model = tfgan.gan_model(
                generator_fn=unconditional_generator,  # you define 
                discriminator_fn=unconditional_discriminator,  # you define
                real_data=images,
                generator_inputs=noise)
    

    Setting up the loss functions

    Configure the model with the loss functions that ship with TFGAN:

        with tf.name_scope('loss'):
            # Build the GAN loss.
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
                discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
                gradient_penalty_weight=1.0,
                add_summaries=True)
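
    For intuition about what the Wasserstein losses compute, here is a plain-Python sketch of the usual formulation: the generator loss is the negated mean critic score on generated data, and the discriminator loss is the mean score on generated data minus the mean score on real data. These are stand-in functions written for this article, not the TFGAN implementations, and the gradient penalty term is left out.

```python
def mean(values):
    return sum(values) / len(values)


def wgan_generator_loss(d_gen_scores):
    """Generator wants the critic to score generated samples highly."""
    return -mean(d_gen_scores)


def wgan_discriminator_loss(d_real_scores, d_gen_scores):
    """Critic wants real scores above generated scores."""
    return mean(d_gen_scores) - mean(d_real_scores)


# Made-up critic outputs (unbounded scores, not probabilities).
d_real = [2.0, 1.0, 3.0]
d_gen = [-1.0, 0.0, -2.0]

g_loss = wgan_generator_loss(d_gen)
d_loss = wgan_discriminator_loss(d_real, d_gen)

print(g_loss, d_loss)
```

    With these scores the critic separates the two sets well, so its loss is negative while the generator's loss is positive.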
    
    

    You can also define custom loss functions. A loss function receives a gan_model and computes the loss from the gan_model's outputs:

    def silly_custom_generator_loss(gan_model, add_summaries=False):
        return tf.reduce_mean(gan_model.discriminator_gen_outputs)
    
    def silly_custom_discriminator_loss(gan_model, add_summaries=False):
        return (tf.reduce_mean(gan_model.discriminator_gen_outputs) -
                tf.reduce_mean(gan_model.discriminator_real_outputs))
    
    custom_gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=silly_custom_generator_loss,
        discriminator_loss_fn=silly_custom_discriminator_loss)
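
    The same calling convention can be tried without TensorFlow. The sketch below uses a hypothetical namedtuple, `ToyGANModel`, that exposes only the two output fields used above, and a custom generator loss of my own choosing (the non-saturating logistic loss, -mean(log sigmoid(logit))). It is meant to show the shape of a custom loss function, not to reproduce anything in TFGAN.

```python
import math
from collections import namedtuple

# Hypothetical stand-in exposing only the two fields a loss function reads here.
ToyGANModel = namedtuple(
    "ToyGANModel", ["discriminator_real_outputs", "discriminator_gen_outputs"])


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def nonsaturating_generator_loss(gan_model, add_summaries=False):
    """-mean(log(sigmoid(logit))) over the scores of generated samples."""
    logits = gan_model.discriminator_gen_outputs
    return sum(softplus(-logit) for logit in logits) / len(logits)


undecided = ToyGANModel([0.0, 0.0], [0.0, 0.0])  # D is unsure about everything
fooled = ToyGANModel([0.0, 0.0], [3.0, 3.0])     # D rates generated data as real

loss_undecided = nonsaturating_generator_loss(undecided)
loss_fooled = nonsaturating_generator_loss(fooled)

print(loss_undecided, loss_fooled)
```

    The loss falls as the discriminator's logits for generated samples rise, which is the behavior a generator loss needs.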
    

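    Configuring the training ops

    Next, use TFGAN to configure the training ops, specifying the model, the loss, the optimizers, the learning rates, and other parameters. As for the check_for_unused_update_ops parameter: because of the batch_norm layers, setting it to True makes the check on update ops fail as inconsistent, so it needs to be set to False.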

        with tf.name_scope('train'):
            # Create the train ops, which calculate gradients and apply updates to weights.
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
                discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
                check_for_unused_update_ops=False,
                summarize_gradients=True,
                aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
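
    The second positional argument passed to AdamOptimizer above (0.5) is Adam's beta1, the decay rate of the running gradient average. The sketch below implements the textbook Adam update in plain Python to show one effect of lowering it: the step direction follows a change in the gradient sign sooner. The function name and gradient sequence are assumptions for illustration, and TensorFlow's implementation may arrange the arithmetic slightly differently.

```python
import math


def adam_steps(grads, lr, beta1, beta2=0.999, eps=1e-8):
    """Return the parameter decrement Adam would apply for each gradient."""
    m = v = 0.0
    steps = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1.0 - beta1) * g      # running mean of gradients
        v = beta2 * v + (1.0 - beta2) * g * g  # running mean of squared gradients
        m_hat = m / (1.0 - beta1 ** t)         # bias correction
        v_hat = v / (1.0 - beta2 ** t)
        steps.append(lr * m_hat / (math.sqrt(v_hat) + eps))
    return steps


# Three gradients in one direction, then a sudden reversal.
grads = [1.0, 1.0, 1.0, -1.0]

fast = adam_steps(grads, lr=1e-4, beta1=0.5)
slow = adam_steps(grads, lr=1e-4, beta1=0.9)

print(fast[-1], slow[-1])
```

    With beta1=0.5 the last step already points in the new direction, while with beta1=0.9 it still follows the old one. The first step is about the size of the learning rate in both cases, which is one reason the learning rate has such a direct effect on training.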
    

    Training the model

    Once the training ops and the output directory are passed to TFGAN, starting the training process is straightforward:

        # Run the train ops in the alternating training scheme.
        tfgan.gan_train(
            train_ops,
            hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
            logdir=train_log_dir,
            save_summaries_steps=10)
    

    Visual evaluation

    def test(eval_dir, checkpoint_dir):
        tf.reset_default_graph()
    
        if not tf.gfile.Exists(eval_dir):
            tf.gfile.MakeDirs(eval_dir)
    
        random_inputs = tf.random_normal([100, 64])
    
        with tf.variable_scope('Generator'):
            images = unconditional_generator(random_inputs, is_training=False)
    
        reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
        uint8_images = float_image_to_uint8(reshaped_images)
    
        image_write_ops = tf.write_file(
              '%s/%s' % (eval_dir, 'unconditional_gan.png'),
              tf.image.encode_png(uint8_images[0]))
    
        tf.contrib.training.evaluate_repeatedly(
                checkpoint_dir,
                eval_ops=image_write_ops,
                hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
                max_number_of_evaluations=1)
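
    Two pieces of arithmetic in the evaluation code are easy to check by hand: the size of the tiled image and the conversion from [-1, 1) floats to bytes. The sketch below assumes that image_reshaper lays the 100 images out row by row in a grid with num_cols columns, and it emulates the cast to uint8 with 8-bit wraparound, mirroring the note in the float_image_to_uint8 docstring in the full listing that `1` maps to `0`. Both helper names are made up for this sketch.

```python
def grid_shape(num_images, num_cols, height=28, width=28):
    """Pixel size of a grid of equally sized images (assumes an exact fit)."""
    num_rows = num_images // num_cols
    return (num_rows * height, num_cols * width)


def float_to_uint8(value):
    """Scale a float in [-1, 1) to [0, 255]; wraparound at 256 is an assumption."""
    return int(value * 128.0 + 128.0) & 0xFF


sheet = grid_shape(100, 10)
darkest = float_to_uint8(-1.0)
middle = float_to_uint8(0.0)
brightest = float_to_uint8(0.9921875)
overflow = float_to_uint8(1.0)

print(sheet, darkest, middle, brightest, overflow)
```

    The overflow case is why the docstring asks for values in [-1, 1) rather than [-1, 1].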
    

    The results generated at different epochs are shown below:

    [Figures: generated samples at Epoch 400, 1000, 2000, 3000, 4000, and 5000, followed by the loss curves]

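    PS:
    In my experiments the generator's learning rate had a large effect on the generated results. I first followed the official documentation and used a learning rate of 1e-3. Blurry digit images appeared at around Epoch 200, but with further training the generated images turned entirely black, and the loss curves showed that the whole training process was very unstable. Only after switching to a learning rate of 1e-4 did I get the results shown here.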

    The complete unconditional GAN code is shown below; you can also refer to the official tutorial.

    import tensorflow as tf
    import tensorflow.contrib.gan as tfgan
    import tensorflow.contrib.layers as layers
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    def float_image_to_uint8(image):
        """Convert float image in [-1, 1) to [0, 255] uint8.
        Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.
        Args:
            image: An image tensor. Values should be in [-1, 1).
        Returns:
            Input image cast to uint8 and with integer values in [0, 255].
        """
        image = (image * 128.0) + 128.0
    
        return tf.cast(image, tf.uint8)
    
    
    def provide_data(batch_size, num_threads=1):
        file = "MNIST"
        # Pixel values are loaded in the range [0, 1].
        mnist = input_data.read_data_sets(file, one_hot=True)
    
        train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
        train_labels = mnist.train.labels
    
        # Rescale the pixels from [0, 255] to roughly [-1, 1).
        train_data = (tf.to_float(train_data) - 128.0) / 128.0
    
        # Creates a QueueRunner for the pre-fetching operation.
        input_queue = tf.train.slice_input_producer([train_data, train_labels], shuffle=True)
        images, labels = tf.train.batch(
                input_queue,
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=5 * batch_size)
    
        return images, labels
    
    
    def unconditional_generator(noise, weight_decay=2.5e-5, is_training=True):
        """Simple generator to produce MNIST images.
    
        Args:
            noise: A single Tensor representing noise.
            weight_decay: The value of the l2 weight decay.
            is_training: If `True`, batch norm uses batch statistics. If `False`, batch
                norm uses the exponential moving average collected from population 
                statistics.
    
        Returns:
            A generated image in the range [-1, 1].
        """
        with tf.contrib.framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)):
            with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
                            zero_debias_moving_mean=True):
    
                net = layers.fully_connected(noise, 1024)
                net = layers.fully_connected(net, 7 * 7 * 128)
                net = tf.reshape(net, [-1, 7, 7, 128])
                net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
                net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
                # Make sure that generator output is in the same range as `inputs`
                # ie [-1, 1].
                net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)
    
                return net
    
    
    def unconditional_discriminator(img, unused_conditioning, weight_decay=2.5e-5,
                         is_training=True):
        """Discriminator network on MNIST digits.
    
        Args:
            img: Real or generated MNIST digits. Should be in the range [-1, 1].
            unused_conditioning: The TFGAN API can help with conditional GANs, which
                would require extra `condition` information to both the generator and the
                discriminator. Since this example is not conditional, we do not use this
                argument.
            weight_decay: The L2 weight decay.
            is_training: If `True`, batch norm uses batch statistics. If `False`, batch
                norm uses the exponential moving average collected from population 
                statistics.
    
        Returns:
            Logits for the probability that the image is real.
        """
        with tf.contrib.framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=tf.nn.relu, normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
    
            net = layers.conv2d(img, 64, [4, 4], stride=2)
            net = layers.conv2d(net, 128, [4, 4], stride=2)
            net = layers.flatten(net)
    
            with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
                net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)
    
            return layers.linear(net, 1)
    
    
    def train(batch_size, max_steps, gen_lr, dis_lr, train_log_dir):
        tf.reset_default_graph()
    
        if not tf.gfile.Exists(train_log_dir):
            tf.gfile.MakeDirs(train_log_dir)
    
        # Set up the input.
        images, _ = provide_data(batch_size)
        noise = tf.random_normal([batch_size, 64])
    
        with tf.name_scope('model'):
            # Build the generator and discriminator.
            gan_model = tfgan.gan_model(
                generator_fn=unconditional_generator,  # you define 
                discriminator_fn=unconditional_discriminator,  # you define
                real_data=images,
                generator_inputs=noise)
    
        with tf.name_scope('loss'):
            # Build the GAN loss.
            gan_loss = tfgan.gan_loss(
                gan_model,
                generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
                discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
                gradient_penalty_weight=1.0,
                add_summaries=True)
    
        with tf.name_scope('train'):
            # Create the train ops, which calculate gradients and apply updates to weights.
            train_ops = tfgan.gan_train_ops(
                gan_model,
                gan_loss,
                generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
                discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
                check_for_unused_update_ops=False,
                summarize_gradients=True,
                aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    
        # Run the train ops in the alternating training scheme.
        tfgan.gan_train(
            train_ops,
            hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
            logdir=train_log_dir,
            save_summaries_steps=10)
    
    
    def test(eval_dir, checkpoint_dir):
        tf.reset_default_graph()
    
        if not tf.gfile.Exists(eval_dir):
            tf.gfile.MakeDirs(eval_dir)
    
        random_inputs = tf.random_normal([100, 64])
    
        with tf.variable_scope('Generator'):
            images = unconditional_generator(random_inputs, is_training=False)
    
        reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
        uint8_images = float_image_to_uint8(reshaped_images)
    
        image_write_ops = tf.write_file(
              '%s/%s' % (eval_dir, 'unconditional_gan.png'),
              tf.image.encode_png(uint8_images[0]))
    
        tf.contrib.training.evaluate_repeatedly(
                checkpoint_dir,
                eval_ops=image_write_ops,
                hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
                max_number_of_evaluations=1)
    
    
    if __name__ == '__main__':
        train(16, 5000, 1e-4, 1e-4, 'logs/')
        test('eval/', 'logs/')
    
    

