
Learning the CycleGAN Code

Author: 杨逸凡 | Published 2018-01-31 20:01

    https://github.com/leehomyc/cyclegan-1
    https://junyanz.github.io/CycleGAN/

    cyclegan_datasets.py

    DATASET_TO_SIZES = {
        'horse2zebra_train': 1334,
        'horse2zebra_test': 140
    }
    
    """The image types of each dataset. Currently only supports .jpg or .png"""
    DATASET_TO_IMAGETYPE = {
        'horse2zebra_train': '.jpg',
        'horse2zebra_test': '.jpg',
    }
    
    """The path to the output csv file."""
    PATH_TO_CSV = {
        'horse2zebra_train': './input/horse2zebra/horse2zebra_train.csv',
        'horse2zebra_test': './input/horse2zebra/horse2zebra_test.csv',
    }
    

    The data lives in ./input/horse2zebra, which contains four directories: trainA, trainB, testA, testB.

    create_cyclegan_dataset.py

    """Create datasets for training and testing."""
    import csv
    import os
    import random
    
    import click
    
    import cyclegan_datasets
    
    
    def create_list(foldername, fulldir=True, suffix=".jpg"):
        """
    
        :param foldername: The full path of the folder.
        :param fulldir: Whether to return the full path or not.
        :param suffix: Filter by suffix.
    
        :return: The list of filenames in the folder with given suffix.
    
        """
        file_list_tmp = os.listdir(foldername)
        file_list = []
        if fulldir:
            for item in file_list_tmp:
                if item.endswith(suffix):
                    file_list.append(os.path.join(foldername, item))
        else:
            for item in file_list_tmp:
                if item.endswith(suffix):
                    file_list.append(item)
        return file_list
    
    
    @click.command()
    @click.option('--image_path_a',
                  type=click.STRING,
                  default='./input/horse2zebra/trainA',
                  help='The path to the images from domain_a.')
    @click.option('--image_path_b',
                  type=click.STRING,
                  default='./input/horse2zebra/trainB',
                  help='The path to the images from domain_b.')
    @click.option('--dataset_name',
                  type=click.STRING,
                  default='horse2zebra_train',
                  help='The name of the dataset in cyclegan_datasets.')
    @click.option('--do_shuffle',
                  type=click.BOOL,
                  default=False,
                  help='Whether to shuffle images when creating the dataset.')
    def create_dataset(image_path_a, image_path_b,
                       dataset_name, do_shuffle):
        list_a = create_list(image_path_a, True,
                             cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
        list_b = create_list(image_path_b, True,
                             cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
    
        output_path = cyclegan_datasets.PATH_TO_CSV[dataset_name]
    
        num_rows = cyclegan_datasets.DATASET_TO_SIZES[dataset_name]
        all_data_tuples = []
        for i in range(num_rows):
            all_data_tuples.append((
                list_a[i % len(list_a)],
                list_b[i % len(list_b)]
            ))
        if do_shuffle is True:
            random.shuffle(all_data_tuples)
        with open(output_path, 'w') as csv_file:
            csv_writer = csv.writer(csv_file)
            for data_tuple in all_data_tuples:
                csv_writer.writerow(list(data_tuple))
    

    @click.command() and @click.option play the same role here as argparse.ArgumentParser() would: they declare the command-line options.
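
    For comparison, a rough argparse equivalent of the click decorators above (a sketch, not code from the repository):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path_a', type=str,
                        default='./input/horse2zebra/trainA',
                        help='The path to the images from domain_a.')
    parser.add_argument('--image_path_b', type=str,
                        default='./input/horse2zebra/trainB',
                        help='The path to the images from domain_b.')
    parser.add_argument('--dataset_name', type=str, default='horse2zebra_train')
    parser.add_argument('--do_shuffle', action='store_true')
    args = parser.parse_args()
    # create_dataset(args.image_path_a, args.image_path_b,
    #                args.dataset_name, args.do_shuffle)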

    layers.py

    1. lrelu
    import tensorflow as tf


    def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    
        with tf.variable_scope(name):
            if alt_relu_impl:
                f1 = 0.5 * (1 + leak)
                f2 = 0.5 * (1 - leak)
                return f1 * x + f2 * abs(x)
            else:
                return tf.maximum(x, leak * x)
    

    The two branches are equivalent (for x >= 0 both return x; for x < 0 both return leak * x), but the first one uses less memory.
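
    A quick NumPy sanity check of that equivalence (not part of the original code):

    import numpy as np

    def lrelu_max(x, leak=0.2):
        return np.maximum(x, leak * x)

    def lrelu_alt(x, leak=0.2):
        f1 = 0.5 * (1 + leak)   # coefficient of x
        f2 = 0.5 * (1 - leak)   # coefficient of |x|
        return f1 * x + f2 * np.abs(x)

    x = np.linspace(-3, 3, 7)
    assert np.allclose(lrelu_max(x), lrelu_alt(x))  # identical outputs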

    2. instance normalization
    def instance_norm(x):
    
        with tf.variable_scope("instance_norm"):
            epsilon = 1e-5
            mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
            scale = tf.get_variable('scale', [x.get_shape()[-1]],
                                    initializer=tf.truncated_normal_initializer(
                                        mean=1.0, stddev=0.02
            ))
            offset = tf.get_variable(
                'offset', [x.get_shape()[-1]],
                initializer=tf.constant_initializer(0.0)
            )
            out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset
    
            return out
    

    Instance normalization computes its statistics from a single image. In GAN and style-transfer tasks, IN empirically outperforms BN; the usual explanation is that in these generative settings each sample's style should stay independent rather than be coupled to the other samples in the batch.
    For the axis, scale, and offset arguments, see the Batch Normalization part of the previous post.
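
    The only difference from batch norm is which axes the statistics are reduced over; a NumPy sketch (not repository code):

    import numpy as np

    x = np.random.randn(4, 256, 256, 3)              # an NHWC batch

    # Batch norm: statistics over batch + spatial axes, shared across the batch.
    bn_mean = x.mean(axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, 3)

    # Instance norm: statistics over spatial axes only, one set per image,
    # matching tf.nn.moments(x, [1, 2], keep_dims=True) above.
    in_mean = x.mean(axis=(1, 2), keepdims=True)     # shape (4, 1, 1, 3)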

    3. Convolution layer
    def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,
                       padding="VALID", name="conv2d", do_norm=True, do_relu=True,
                       relufactor=0):
        with tf.variable_scope(name):
    
            conv = tf.contrib.layers.conv2d(
                inputconv, o_d, f_w, s_w, padding,
                activation_fn=None,
                weights_initializer=tf.truncated_normal_initializer(
                    stddev=stddev
                ),
                biases_initializer=tf.constant_initializer(0.0)
            )
            if do_norm:
                conv = instance_norm(conv)
    
            if do_relu:
                if(relufactor == 0):
                    conv = tf.nn.relu(conv, "relu")
                else:
                    conv = lrelu(conv, relufactor, "lrelu")
    
            return conv
    

    tf.truncated_normal_initializer: samples falling more than two standard deviations from the mean are discarded and redrawn.
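
    The same rejection rule in NumPy (an illustrative sketch; TensorFlow does this internally):

    import numpy as np

    def truncated_normal(shape, mean=1.0, stddev=0.02):
        out = np.random.normal(mean, stddev, shape)
        bad = np.abs(out - mean) > 2 * stddev        # more than 2 sigma away
        while bad.any():                             # redraw until none remain
            out[bad] = np.random.normal(mean, stddev, bad.sum())
            bad = np.abs(out - mean) > 2 * stddev
        return out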

    4. Deconvolution (transposed convolution) layer
    def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                         stddev=0.02, padding="VALID", name="deconv2d",
                         do_norm=True, do_relu=True, relufactor=0):
        with tf.variable_scope(name):
    
            conv = tf.contrib.layers.conv2d_transpose(
                inputconv, o_d, [f_h, f_w],
                [s_h, s_w], padding,
                activation_fn=None,
                weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                biases_initializer=tf.constant_initializer(0.0)
            )
    
            if do_norm:
                conv = instance_norm(conv)
                # conv = tf.contrib.layers.batch_norm(conv, decay=0.9,
                # updates_collections=None, epsilon=1e-5, scale=True,
                # scope="batch_norm")
    
            if do_relu:
                if(relufactor == 0):
                    conv = tf.nn.relu(conv, "relu")
                else:
                    conv = lrelu(conv, relufactor, "lrelu")
    
            return conv
    

    losses.py

    """Contains losses used for performing image-to-image domain adaptation."""
    import tensorflow as tf
    
    # L(G, F)
    def cycle_consistency_loss(real_images, generated_images):
        """Compute the cycle consistency loss.
    
        The cycle consistency loss is defined as the sum of the L1 distances
        between the real images from each domain and their generated (fake)
        counterparts.
    
        This definition is derived from Equation 2 in:
            Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
            Networks.
            Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros.
    
    
        Args:
            real_images: A batch of images from domain X, a `Tensor` of shape
                [batch_size, height, width, channels].
            generated_images: A batch of generated images made to look like they
                came from domain X, a `Tensor` of shape
                [batch_size, height, width, channels].
    
        Returns:
            The cycle consistency loss.
        """
        return tf.reduce_mean(tf.abs(real_images - generated_images))
    
    
    def lsgan_loss_generator(prob_fake_is_real):
        """Computes the LS-GAN loss as minimized by the generator.
    
        Rather than compute the negative loglikelihood, a least-squares loss is
        used to optimize the discriminators as per Equation 2 in:
            Least Squares Generative Adversarial Networks
            Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
            Stephen Paul Smolley.
            https://arxiv.org/pdf/1611.04076.pdf
    
        Args:
            prob_fake_is_real: The discriminator's estimate that generated images
                made to look like real images are real.
    
        Returns:
            The total LS-GAN loss.
        """
        return tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 1))
    
    
    def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
        """Computes the LS-GAN loss as minimized by the discriminator.
    
        Rather than compute the negative loglikelihood, a least-squares loss is
        used to optimize the discriminators as per Equation 2 in:
            Least Squares Generative Adversarial Networks
            Xudong Mao, Qing Li, Haoran Xie, Raymond Y.K. Lau, Zhen Wang, and
            Stephen Paul Smolley.
            https://arxiv.org/pdf/1611.04076.pdf
    
        Args:
            prob_real_is_real: The discriminator's estimate that images actually
                drawn from the real domain are in fact real.
            prob_fake_is_real: The discriminator's estimate that generated images
                made to look like real images are real.
    
        Returns:
            The total LS-GAN loss.
        """
        return (tf.reduce_mean(tf.squared_difference(prob_real_is_real, 1)) +
                tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 0))) * 0.5
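
    In the papers' notation, cycle_consistency_loss implements the cycle term (CycleGAN Eq. 2) and the two lsgan_* functions implement the least-squares terms (LSGAN Eq. 2; the code drops the constant 1/2 on the generator loss, which does not change the minimizer):

    \mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim p_{data}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
                            + \mathbb{E}_{y \sim p_{data}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]

    \mathcal{L}_{LSGAN}(D) = \tfrac{1}{2}\,\mathbb{E}_{x}\big[(D(x) - 1)^2\big]
                           + \tfrac{1}{2}\,\mathbb{E}_{\hat{x}}\big[D(\hat{x})^2\big]

    \mathcal{L}_{LSGAN}(G) = \mathbb{E}_{\hat{x}}\big[(D(\hat{x}) - 1)^2\big]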
    

    model.py

    1. ResNet block
    def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
        """build a single block of resnet.
    
        :param inputres: inputres
        :param dim: dim
        :param name: name
        :param padding: for tensorflow version use REFLECT; for pytorch version use
         CONSTANT
        :return: a single block of resnet.
        """
        with tf.variable_scope(name):
            out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]],
                             padding)
            out_res = layers.general_conv2d(
                out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
            out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
            out_res = layers.general_conv2d(
                out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)
    
            return tf.nn.relu(out_res + inputres)
    

    ResNet block (figure omitted): two 3 × 3 convolutions whose output is added back to the block input before the final ReLU.

    Reflection Padding:

    t = tf.constant([[1, 2, 3], [4, 5, 6]])
    paddings = tf.constant([[1, 1,], [2, 2]])
    tf.pad(t, paddings, "REFLECT")  # [[6, 5, 4, 5, 6, 5, 4],
                                    #  [3, 2, 1, 2, 3, 2, 1],
                                    #  [6, 5, 4, 5, 6, 5, 4],
                                    #  [3, 2, 1, 2, 3, 2, 1]]
    

    mode="REFLECT"是映射填充,上下(1维)填充顺序和paddings是相反的,左右(零维)顺序补齐。

    2. Generator
    def build_generator_resnet_9blocks_tf(inputgen, name="generator", skip=False):
        with tf.variable_scope(name):
            f = 7
            ks = 3
            padding = "REFLECT"
    
            pad_input = tf.pad(inputgen, [[0, 0], [ks, ks], [ks, ks], [0, 0]],
                               padding)
            o_c1 = layers.general_conv2d(
                pad_input, ngf, f, f, 1, 1, 0.02, name="c1")
            o_c2 = layers.general_conv2d(
                o_c1, ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")
            o_c3 = layers.general_conv2d(
                o_c2, ngf * 4, ks, ks, 2, 2, 0.02, "SAME", "c3")
    
            o_r1 = build_resnet_block(o_c3, ngf * 4, "r1", padding)
            o_r2 = build_resnet_block(o_r1, ngf * 4, "r2", padding)
            o_r3 = build_resnet_block(o_r2, ngf * 4, "r3", padding)
            o_r4 = build_resnet_block(o_r3, ngf * 4, "r4", padding)
            o_r5 = build_resnet_block(o_r4, ngf * 4, "r5", padding)
            o_r6 = build_resnet_block(o_r5, ngf * 4, "r6", padding)
            o_r7 = build_resnet_block(o_r6, ngf * 4, "r7", padding)
            o_r8 = build_resnet_block(o_r7, ngf * 4, "r8", padding)
            o_r9 = build_resnet_block(o_r8, ngf * 4, "r9", padding)
    
            o_c4 = layers.general_deconv2d(
                o_r9, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02,
                "SAME", "c4")
            o_c5 = layers.general_deconv2d(
                o_c4, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02,
                "SAME", "c5")
            o_c6 = layers.general_conv2d(o_c5, IMG_CHANNELS, f, f, 1, 1,
                                         0.02, "SAME", "c6",
                                         do_norm=False, do_relu=False)
    
            if skip is True:
                out_gen = tf.nn.tanh(inputgen + o_c6, "t1")
            else:
                out_gen = tf.nn.tanh(o_c6, "t1")
    
            return out_gen
    

    "We use 6 blocks for 128 × 128 training images, and 9 blocks for 256 × 256 or higher-resolution training images.

    Let c7s1-k denote a 7 × 7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1.
    dk denotes a 3 × 3 Convolution-InstanceNorm-ReLU layer with k filters, and stride 2.
    Reflection padding was used to reduce artifacts.
    Rk denotes a residual block that contains two 3 × 3 convolutional layers with the same number of filters on both layers.
    uk denotes a 3 × 3 fractional-strided-Convolution-InstanceNorm-ReLU layer with k filters, and stride 1/2.

    The network with 6 blocks consists of:
    c7s1-32,d64,d128,R128,R128,R128,R128,R128,R128,u64,u32,c7s1-3
    The network with 9 blocks consists of:
    c7s1-32,d64,d128,R128,R128,R128,R128,R128,R128,R128,R128,R128,u64,u32,c7s1-3"
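
    Matching that description to the code, a shape walk for a 256 × 256 input (assuming ngf = 32, per the c7s1-32 spec, and BATCH_SIZE = 1):

    # input               1 x 256 x 256 x 3
    # pad + c1 (c7s1-32)  1 x 256 x 256 x 32   (reflect-pad by 3, 7x7 VALID conv)
    # c2 (d64)            1 x 128 x 128 x 64   (3x3, stride 2, SAME)
    # c3 (d128)           1 x  64 x  64 x 128
    # r1 .. r9 (R128)     1 x  64 x  64 x 128  (residual blocks keep the shape)
    # c4 (u64)            1 x 128 x 128 x 64   (transposed conv, stride 2)
    # c5 (u32)            1 x 256 x 256 x 32
    # c6 (c7s1-3)         1 x 256 x 256 x 3    (tanh squashes to [-1, 1])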

    3. Discriminator
    def discriminator_tf(inputdisc, name="discriminator"):
        with tf.variable_scope(name):
            f = 4
    
            o_c1 = layers.general_conv2d(inputdisc, ndf, f, f, 2, 2,
                                         0.02, "SAME", "c1", do_norm=False,
                                         relufactor=0.2)
            o_c2 = layers.general_conv2d(o_c1, ndf * 2, f, f, 2, 2,
                                         0.02, "SAME", "c2", relufactor=0.2)
            o_c3 = layers.general_conv2d(o_c2, ndf * 4, f, f, 2, 2,
                                         0.02, "SAME", "c3", relufactor=0.2)
            o_c4 = layers.general_conv2d(o_c3, ndf * 8, f, f, 1, 1,
                                         0.02, "SAME", "c4", relufactor=0.2)
            o_c5 = layers.general_conv2d(
                o_c4, 1, f, f, 1, 1, 0.02,
                "SAME", "c5", do_norm=False, do_relu=False
            )
    
            return o_c5
    

    "For discriminator networks, we use 70 × 70 PatchGAN [21]. Let Ck denote a 4 × 4 Convolution-InstanceNorm-LeakyReLU layer with k filters and stride 2. After the last layer, we apply a convolution to produce a 1 dimensional output. We do not use InstanceNorm for the first C64 layer. We use leaky ReLUs with slope 0:2. The discriminator architecture is:
    C64-C128-C256-C512"

    4. PatchGAN
    def patch_discriminator(inputdisc, name="discriminator"):
        with tf.variable_scope(name):
            f = 4
    
            patch_input = tf.random_crop(inputdisc, [1, 70, 70, 3])
            # note: do_norm must be the boolean False; the string "False" is truthy
            o_c1 = layers.general_conv2d(patch_input, ndf, f, f, 2, 2,
                                         0.02, "SAME", "c1", do_norm=False,
                                         relufactor=0.2)
            o_c2 = layers.general_conv2d(o_c1, ndf * 2, f, f, 2, 2,
                                         0.02, "SAME", "c2", relufactor=0.2)
            o_c3 = layers.general_conv2d(o_c2, ndf * 4, f, f, 2, 2,
                                         0.02, "SAME", "c3", relufactor=0.2)
            o_c4 = layers.general_conv2d(o_c3, ndf * 8, f, f, 2, 2,
                                         0.02, "SAME", "c4", relufactor=0.2)
            o_c5 = layers.general_conv2d(
                o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False,
                do_relu=False)
    
            return o_c5
    
    5. Outputs
    def get_outputs(inputs, network="tensorflow", skip=False):
        images_a = inputs['images_a']
        images_b = inputs['images_b']
    
        fake_pool_a = inputs['fake_pool_a']
        fake_pool_b = inputs['fake_pool_b']
    
        with tf.variable_scope("Model") as scope:
    
            if network == "pytorch":
                current_discriminator = discriminator
                current_generator = build_generator_resnet_9blocks
            elif network == "tensorflow":
                current_discriminator = discriminator_tf
                current_generator = build_generator_resnet_9blocks_tf
            else:
                raise ValueError(
                    'network must be either pytorch or tensorflow'
                )
    
            prob_real_a_is_real = current_discriminator(images_a, "d_A")
            prob_real_b_is_real = current_discriminator(images_b, "d_B")
    
            fake_images_b = current_generator(images_a, name="g_A", skip=skip)
            fake_images_a = current_generator(images_b, name="g_B", skip=skip)
    
            scope.reuse_variables()
    
            prob_fake_a_is_real = current_discriminator(fake_images_a, "d_A")
            prob_fake_b_is_real = current_discriminator(fake_images_b, "d_B")
    
            cycle_images_a = current_generator(fake_images_b, "g_B", skip=skip)
            cycle_images_b = current_generator(fake_images_a, "g_A", skip=skip)
    
            scope.reuse_variables()
    
            prob_fake_pool_a_is_real = current_discriminator(fake_pool_a, "d_A")
            prob_fake_pool_b_is_real = current_discriminator(fake_pool_b, "d_B")
    
        return {
            'prob_real_a_is_real': prob_real_a_is_real,
            'prob_real_b_is_real': prob_real_b_is_real,
            'prob_fake_a_is_real': prob_fake_a_is_real,
            'prob_fake_b_is_real': prob_fake_b_is_real,
            'prob_fake_pool_a_is_real': prob_fake_pool_a_is_real,
            'prob_fake_pool_b_is_real': prob_fake_pool_b_is_real,
            'cycle_images_a': cycle_images_a,
            'cycle_images_b': cycle_images_b,
            'fake_images_a': fake_images_a,
            'fake_images_b': fake_images_b,
        }
    

    A: the real-horse set: images_a (A) -> fake_images_b (fB) -> cycle_images_a
    B: the real-zebra set: images_b (B) -> fake_images_a (fA) -> cycle_images_b
    fA: the fake-horse pool: fake_pool_a
    fB: the fake-zebra pool: fake_pool_b

    data_loader.py

    1. load sample
    import tensorflow as tf
    
    import cyclegan_datasets
    import model
    
    
    def _load_samples(csv_name, image_type):
        filename_queue = tf.train.string_input_producer(
            [csv_name])
    
        reader = tf.TextLineReader()
        _, csv_filename = reader.read(filename_queue)
    
        record_defaults = [tf.constant([], dtype=tf.string),
                           tf.constant([], dtype=tf.string)]
    
        filename_i, filename_j = tf.decode_csv(
            csv_filename, record_defaults=record_defaults)
    
        file_contents_i = tf.read_file(filename_i)
        file_contents_j = tf.read_file(filename_j)
        if image_type == '.jpg':
            image_decoded_A = tf.image.decode_jpeg(
                file_contents_i, channels=model.IMG_CHANNELS)
            image_decoded_B = tf.image.decode_jpeg(
                file_contents_j, channels=model.IMG_CHANNELS)
        elif image_type == '.png':
            image_decoded_A = tf.image.decode_png(
                file_contents_i, channels=model.IMG_CHANNELS, dtype=tf.uint8)
            image_decoded_B = tf.image.decode_png(
                file_contents_j, channels=model.IMG_CHANNELS, dtype=tf.uint8)
    
        return image_decoded_A, image_decoded_B
    

    The load-sample procedure is similar to pix2pix's, except that the reader here is a TextLineReader() (each CSV row holds one pair of file names). Reading a CSV this way is the standard pipeline; see:

    https://www.jianshu.com/p/d063804fb272
    http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html

    Before calling run or eval to execute a read, tf.train.start_queue_runners must be called to populate the filename queue; otherwise the read op blocks until the filename queue has entries.

    # Pipeline
    q = tf.train.string_input_producer([csv1, csv2, ...])
    # string_input_producer opens several files at once; it creates the Queue
    # explicitly and the matching QueueRunner implicitly
    reader = tf.TextLineReader()
    content = reader.read(q)
    record_defaults = [[], [], ...]
    tf.decode_csv(content, record_defaults=record_defaults)
    ...
    coord = tf.train.Coordinator()
    # create the coordinator
    threads = tf.train.start_queue_runners(coord=coord)
    # start all the queue threads in the graph
    
    2. load data
    def load_data(dataset_name, image_size_before_crop,
                  do_shuffle=True, do_flipping=False):
        """
    
        :param dataset_name: The name of the dataset.
        :param image_size_before_crop: Resize to this size before random cropping.
        :param do_shuffle: Shuffle switch.
        :param do_flipping: Flip switch.
        :return:
        """
        if dataset_name not in cyclegan_datasets.DATASET_TO_SIZES:
            raise ValueError('split name %s was not recognized.'
                             % dataset_name)
    
        csv_name = cyclegan_datasets.PATH_TO_CSV[dataset_name]
    
        image_i, image_j = _load_samples(
            csv_name, cyclegan_datasets.DATASET_TO_IMAGETYPE[dataset_name])
        inputs = {
            'image_i': image_i,
            'image_j': image_j
        }
    
        # Preprocessing:
        inputs['image_i'] = tf.image.resize_images(
            inputs['image_i'], [image_size_before_crop, image_size_before_crop])
        inputs['image_j'] = tf.image.resize_images(
            inputs['image_j'], [image_size_before_crop, image_size_before_crop])
    
        if do_flipping is True:
            inputs['image_i'] = tf.image.random_flip_left_right(inputs['image_i'])
            inputs['image_j'] = tf.image.random_flip_left_right(inputs['image_j'])
    
        inputs['image_i'] = tf.random_crop(
            inputs['image_i'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3])
        inputs['image_j'] = tf.random_crop(
            inputs['image_j'], [model.IMG_HEIGHT, model.IMG_WIDTH, 3])
    
        inputs['image_i'] = tf.subtract(tf.div(inputs['image_i'], 127.5), 1)
        inputs['image_j'] = tf.subtract(tf.div(inputs['image_j'], 127.5), 1)
    
        # Batch
        if do_shuffle is True:
            inputs['images_i'], inputs['images_j'] = tf.train.shuffle_batch(
                [inputs['image_i'], inputs['image_j']], 1, 5000, 100)
        else:
            inputs['images_i'], inputs['images_j'] = tf.train.batch(
                [inputs['image_i'], inputs['image_j']], 1)
    
        return inputs
    

    tf.train.shuffle_batch([example, label], batch_size, capacity, min_after_dequeue): capacity is the maximum number of elements the queue may hold; min_after_dequeue is the minimum number left in the queue after a dequeue, which controls how well successive batches are mixed.
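
    Spelled out with keyword arguments, the call in load_data reads (same values, keywords added for clarity):

    images_i, images_j = tf.train.shuffle_batch(
        [inputs['image_i'], inputs['image_j']],
        batch_size=1,            # one image pair per step
        capacity=5000,           # maximum elements buffered in the queue
        min_after_dequeue=100)   # keep >= 100 buffered so batches stay shuffled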

    main.py

    The code is easy to follow and the comments are clear.

    from datetime import datetime
    import json
    import numpy as np
    import os
    import random
    from scipy.misc import imsave
    
    import click
    import tensorflow as tf
    
    import cyclegan_datasets
    import data_loader, losses, model
    
    slim = tf.contrib.slim
    
    
    class CycleGAN:
        """The CycleGAN module."""
        ...
    
    1. Initialization
        def __init__(self, pool_size, lambda_a,
                     lambda_b, output_root_dir, to_restore,
                     base_lr, max_step, network_version,
                     dataset_name, checkpoint_dir, do_flipping, skip):
            current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    
            self._pool_size = pool_size
            self._size_before_crop = 286
            self._lambda_a = lambda_a  # lambda, the weight on the cycle-consistency loss
            self._lambda_b = lambda_b
            self._output_dir = os.path.join(output_root_dir, current_time)
            self._images_dir = os.path.join(self._output_dir, 'imgs')
            self._num_imgs_to_save = 20
            self._to_restore = to_restore
            self._base_lr = base_lr
            self._max_step = max_step
            self._network_version = network_version
            self._dataset_name = dataset_name
            self._checkpoint_dir = checkpoint_dir
            self._do_flipping = do_flipping
            self._skip = skip
    
            self.fake_images_A = np.zeros(
                (self._pool_size, 1, model.IMG_HEIGHT, model.IMG_WIDTH,
                 model.IMG_CHANNELS)
            )
            self.fake_images_B = np.zeros(
                (self._pool_size, 1, model.IMG_HEIGHT, model.IMG_WIDTH,
                 model.IMG_CHANNELS)
            )
        def model_setup(self):
            """
            This function sets up the model to train.
    
            self.input_A/self.input_B -> Set of training images.
            self.fake_A/self.fake_B -> Generated images by corresponding generator
            of input_A and input_B
            self.lr -> Learning rate variable
            self.cyc_A/ self.cyc_B -> Images generated after feeding
            self.fake_A/self.fake_B to corresponding generator.
            This is used to calculate the cyclic loss
            """
            self.input_a = tf.placeholder(
                tf.float32, [
                    1,
                    model.IMG_WIDTH,
                    model.IMG_HEIGHT,
                    model.IMG_CHANNELS
                ], name="input_A")
            self.input_b = tf.placeholder(
                tf.float32, [
                    1,
                    model.IMG_WIDTH,
                    model.IMG_HEIGHT,
                    model.IMG_CHANNELS
                ], name="input_B")
    
            self.fake_pool_A = tf.placeholder(
                tf.float32, [
                    None,
                    model.IMG_WIDTH,
                    model.IMG_HEIGHT,
                    model.IMG_CHANNELS
                ], name="fake_pool_A")
            self.fake_pool_B = tf.placeholder(
                tf.float32, [
                    None,
                    model.IMG_WIDTH,
                    model.IMG_HEIGHT,
                    model.IMG_CHANNELS
                ], name="fake_pool_B")
    
            self.global_step = slim.get_or_create_global_step()
    
            self.num_fake_inputs = 0
    
            self.learning_rate = tf.placeholder(tf.float32, shape=[], name="lr")
    
            inputs = {
                'images_a': self.input_a,
                'images_b': self.input_b,
                'fake_pool_a': self.fake_pool_A,
                'fake_pool_b': self.fake_pool_B,
            }
    
            outputs = model.get_outputs(
                inputs, network=self._network_version, skip=self._skip)
    
            self.prob_real_a_is_real = outputs['prob_real_a_is_real']
            self.prob_real_b_is_real = outputs['prob_real_b_is_real']
            self.fake_images_a = outputs['fake_images_a']
            self.fake_images_b = outputs['fake_images_b']
            self.prob_fake_a_is_real = outputs['prob_fake_a_is_real']
            self.prob_fake_b_is_real = outputs['prob_fake_b_is_real']
    
            self.cycle_images_a = outputs['cycle_images_a']
            self.cycle_images_b = outputs['cycle_images_b']
    
            self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real']
            self.prob_fake_pool_b_is_real = outputs['prob_fake_pool_b_is_real']
    
    2. Computing the losses


        def compute_losses(self):
            """
            In this function we are defining the variables for loss calculations
            and training model.
    
            d_loss_A/d_loss_B -> loss for discriminator A/B
            g_loss_A/g_loss_B -> loss for generator A/B
            *_trainer -> Various trainer for above loss functions
            *_summ -> Summary variables for above loss functions
            """
            cycle_consistency_loss_a = \
                self._lambda_a * losses.cycle_consistency_loss(
                    real_images=self.input_a, generated_images=self.cycle_images_a,
                )
            cycle_consistency_loss_b = \
                self._lambda_b * losses.cycle_consistency_loss(
                    real_images=self.input_b, generated_images=self.cycle_images_b,
                )
    
            lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
            lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)
    
            g_loss_A = \
                cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
            g_loss_B = \
                cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a
    
            d_loss_A = losses.lsgan_loss_discriminator(
                prob_real_is_real=self.prob_real_a_is_real,
                prob_fake_is_real=self.prob_fake_pool_a_is_real,
            )
            d_loss_B = losses.lsgan_loss_discriminator(
                prob_real_is_real=self.prob_real_b_is_real,
                prob_fake_is_real=self.prob_fake_pool_b_is_real,
            )
    
            optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)
    
            self.model_vars = tf.trainable_variables()
    
            d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
            g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]
            d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
            g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]
    
            self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
            self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
            self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
            self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)
    
            for var in self.model_vars:
                print(var.name)
    
            # Summary variables for tensorboard
            self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
            self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
            self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
            self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
    
    3. Pooling generated images
        def fake_image_pool(self, num_fakes, fake, fake_pool):
            """
            This function saves the generated image to corresponding
            pool of images.
    
            It keeps on filling the pool till it is full and then randomly
            selects an already stored image and replaces it with the new one.
            """
            if num_fakes < self._pool_size:
                fake_pool[num_fakes] = fake
                return fake
            else:
                p = random.random()
                if p > 0.5:
                    random_id = random.randint(0, self._pool_size - 1)
                    temp = fake_pool[random_id]
                    fake_pool[random_id] = fake
                    return temp
                else:
                    return fake
    
    4. Training
        def train(self):
            """Training Function."""
            # Load Dataset from the dataset folder
            self.inputs = data_loader.load_data(
                self._dataset_name, self._size_before_crop,
                True, self._do_flipping)
    
            # Build the network
            self.model_setup()
    
            # Loss function calculations
            self.compute_losses()
    
            # Initializing the global variables
            init = (tf.global_variables_initializer(),
                    tf.local_variables_initializer())
            saver = tf.train.Saver()
    
            max_images = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name]
    
            with tf.Session() as sess:
                sess.run(init)
    
                # Restore the model to run the model from last checkpoint
                if self._to_restore:
                    chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
                    saver.restore(sess, chkpt_fname)
    
                writer = tf.summary.FileWriter(self._output_dir)
    
                if not os.path.exists(self._output_dir):
                    os.makedirs(self._output_dir)
    
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)
    
                # Training Loop
                for epoch in range(sess.run(self.global_step), self._max_step):
                    print("In the epoch ", epoch)
                    saver.save(sess, os.path.join(
                        self._output_dir, "cyclegan"), global_step=epoch)
    
                    # Dealing with the learning rate as per the epoch number
                    if epoch < 100:
                        curr_lr = self._base_lr
                    else:
                        curr_lr = self._base_lr - \
                            self._base_lr * (epoch - 100) / 100
    
                    self.save_images(sess, epoch)
    
                    for i in range(0, max_images):
                        print("Processing batch {}/{}".format(i, max_images))
    
                        inputs = sess.run(self.inputs)
    
                        # Optimizing the G_A network
                        _, fake_B_temp, summary_str = sess.run(
                            [self.g_A_trainer,
                             self.fake_images_b,
                             self.g_A_loss_summ],
                            feed_dict={
                                self.input_a:
                                    inputs['images_i'],
                                self.input_b:
                                    inputs['images_j'],
                                self.learning_rate: curr_lr
                            }
                        )
                        writer.add_summary(summary_str, epoch * max_images + i)
    
                        fake_B_temp1 = self.fake_image_pool(
                            self.num_fake_inputs, fake_B_temp, self.fake_images_B)
    
                        # Optimizing the D_B network
                        _, summary_str = sess.run(
                            [self.d_B_trainer, self.d_B_loss_summ],
                            feed_dict={
                                self.input_a:
                                    inputs['images_i'],
                                self.input_b:
                                    inputs['images_j'],
                                self.learning_rate: curr_lr,
                                self.fake_pool_B: fake_B_temp1
                            }
                        )
                        writer.add_summary(summary_str, epoch * max_images + i)
    
                        # Optimizing the G_B network
                        _, fake_A_temp, summary_str = sess.run(
                            [self.g_B_trainer,
                             self.fake_images_a,
                             self.g_B_loss_summ],
                            feed_dict={
                                self.input_a:
                                    inputs['images_i'],
                                self.input_b:
                                    inputs['images_j'],
                                self.learning_rate: curr_lr
                            }
                        )
                        writer.add_summary(summary_str, epoch * max_images + i)
    
                        fake_A_temp1 = self.fake_image_pool(
                            self.num_fake_inputs, fake_A_temp, self.fake_images_A)
    
                        # Optimizing the D_A network
                        _, summary_str = sess.run(
                            [self.d_A_trainer, self.d_A_loss_summ],
                            feed_dict={
                                self.input_a:
                                    inputs['images_i'],
                                self.input_b:
                                    inputs['images_j'],
                                self.learning_rate: curr_lr,
                                self.fake_pool_A: fake_A_temp1
                            }
                        )
                        writer.add_summary(summary_str, epoch * max_images + i)
    
                        writer.flush()
                        self.num_fake_inputs += 1
    
                    sess.run(tf.assign(self.global_step, epoch + 1))
    
                coord.request_stop()
                coord.join(threads)
                writer.add_graph(sess.graph)
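
    The learning-rate schedule above keeps base_lr constant for the first 100 epochs and then decays it linearly to zero over the next 100. A standalone sketch of the same rule:

    def lr_at(epoch, base_lr=0.0002):
        """Learning-rate schedule used in train() (sketch)."""
        if epoch < 100:
            return base_lr
        return base_lr - base_lr * (epoch - 100) / 100

    print(lr_at(0), lr_at(150), lr_at(200))  # 0.0002 0.0001 0.0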
    
    5. Testing
        def test(self):
            """Test Function."""
            print("Testing the results")
    
            self.inputs = data_loader.load_data(
                self._dataset_name, self._size_before_crop,
                False, self._do_flipping)
    
            self.model_setup()
            saver = tf.train.Saver()
            init = tf.global_variables_initializer()
    
            with tf.Session() as sess:
                sess.run(init)
    
                chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
                saver.restore(sess, chkpt_fname)
    
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)
    
                self._num_imgs_to_save = cyclegan_datasets.DATASET_TO_SIZES[
                    self._dataset_name]
                self.save_images(sess, 0)
    
                coord.request_stop()
                coord.join(threads)
    
    6. Main entry point
    @click.command()
    @click.option('--to_train',
                  type=click.INT,
                  default=True,
                  help='Whether it is train or test.')
    @click.option('--log_dir',
                  type=click.STRING,
                  default=None,
                  help='Where the data is logged to.')
    @click.option('--config_filename',
                  type=click.STRING,
                  default='train',
                  help='The name of the configuration file.')
    @click.option('--checkpoint_dir',
                  type=click.STRING,
                  default='',
                  help='The name of the train/test split.')
    @click.option('--skip',
                  type=click.BOOL,
                  default=False,
                  help='Whether to add skip connection between input and output.')
    def main(to_train, log_dir, config_filename, checkpoint_dir, skip):
        """
    
        :param to_train: Specify whether it is training or testing. 1: training; 2:
         resuming from latest checkpoint; 0: testing.
        :param log_dir: The root dir to save checkpoints and imgs. The actual dir
        is the root dir appended by the folder with the name timestamp.
        :param config_filename: The configuration file.
        :param checkpoint_dir: The directory that saves the latest checkpoint. It
        only takes effect when to_train == 2.
        :param skip: A boolean indicating whether to add skip connection between
        input and output.
        """
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
    
        with open(config_filename) as config_file:
            config = json.load(config_file)
    
        lambda_a = float(config['_LAMBDA_A']) if '_LAMBDA_A' in config else 10.0
        lambda_b = float(config['_LAMBDA_B']) if '_LAMBDA_B' in config else 10.0
        pool_size = int(config['pool_size']) if 'pool_size' in config else 50
    
        to_restore = (to_train == 2)
        base_lr = float(config['base_lr']) if 'base_lr' in config else 0.0002
        max_step = int(config['max_step']) if 'max_step' in config else 200
        network_version = str(config['network_version'])
        dataset_name = str(config['dataset_name'])
        do_flipping = bool(config['do_flipping'])
    
        cyclegan_model = CycleGAN(pool_size, lambda_a, lambda_b, log_dir,
                                  to_restore, base_lr, max_step, network_version,
                                  dataset_name, checkpoint_dir, do_flipping, skip)
    
        if to_train > 0:
            cyclegan_model.train()
        else:
            cyclegan_model.test()
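
    The configuration file is plain JSON with the keys read above. A config consistent with the defaults could be generated like this (file name and values are hypothetical):

    import json

    config = {
        "_LAMBDA_A": 10.0,                  # weight of cycle loss A
        "_LAMBDA_B": 10.0,                  # weight of cycle loss B
        "pool_size": 50,                    # size of the fake-image pool
        "base_lr": 0.0002,                  # initial learning rate
        "max_step": 200,                    # number of epochs
        "network_version": "tensorflow",
        "dataset_name": "horse2zebra_train",
        "do_flipping": True,
    }
    with open("configs/exp_01.json", "w") as f:   # hypothetical path
        json.dump(config, f, indent=4)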
    
    7. Saving images
        def save_images(self, sess, epoch):
            """
            Saves input and output images.
    
            :param sess: The session.
            :param epoch: Current epoch.
            """
            if not os.path.exists(self._images_dir):
                os.makedirs(self._images_dir)
    
            names = ['inputA_', 'inputB_', 'fakeA_',
                     'fakeB_', 'cycA_', 'cycB_']
    
            with open(os.path.join(
                    self._output_dir, 'epoch_' + str(epoch) + '.html'
            ), 'w') as v_html:
                for i in range(0, self._num_imgs_to_save):
                    print("Saving image {}/{}".format(i, self._num_imgs_to_save))
                    inputs = sess.run(self.inputs)
                    fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run([
                        self.fake_images_a,
                        self.fake_images_b,
                        self.cycle_images_a,
                        self.cycle_images_b
                    ], feed_dict={
                        self.input_a: inputs['images_i'],
                        self.input_b: inputs['images_j']
                    })
    
                    tensors = [inputs['images_i'], inputs['images_j'],
                               fake_B_temp, fake_A_temp, cyc_A_temp, cyc_B_temp]
    
                    for name, tensor in zip(names, tensors):
                        image_name = name + str(epoch) + "_" + str(i) + ".jpg"
                        imsave(os.path.join(self._images_dir, image_name),
                               ((tensor[0] + 1) * 127.5).astype(np.uint8)
                               )
                        v_html.write(
                            "<img src=\"" +
                            os.path.join('imgs', image_name) + "\">"
                        )
                    v_html.write("<br>")
    

    Results

    After training for 100 epochs on my own server:
    Horse -> zebra -> horse (result images omitted)

    Zebra -> horse -> zebra (result images omitted)

    These are among the better results; some are painful to look at.

    Overall, horse -> zebra works far better than zebra -> horse, and the generated horses still carry plenty of stripes; I do not know whether longer training would help.
