A CycleGAN-Based Method for Gender Swapping

Author: LuDon | Published: 2018-12-04 16:24

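    Introduction to GANs

    In recent years, GANs (generative adversarial networks) have been applied successfully to image generation, image editing, and representation learning. Minimizing the adversarial loss pushes the generated images to look realistic. The basic principle of a GAN is:

    • The generator G is the network that produces images. It takes random noise z as input and outputs an image G(z). Its goal is to generate images realistic enough to fool the discriminator D.
    • The discriminator D judges whether an image is real. Given an input image x, it outputs D(x), the probability that x is a real image. Its goal is to tell the generator's images apart from real ones as reliably as possible.


      [Figure: the GAN network]

    In the ideal case, the generator produces images that pass for real ones, and the discriminator can no longer tell whether an image was generated.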

    The GAN loss function is:
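
Written out in the standard minimax form, using the G, D, x and z defined above (the names p_data and p_z for the real-image and noise distributions are assumptions of this sketch):

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

D maximizes V by assigning high probability to real images and low probability to generated ones, while G minimizes V by making D(G(z)) large.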


    [Figure: the GAN loss function]
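
To make the loss concrete, the sketch below evaluates the standard GAN value function, E[log D(x)] + E[log(1 - D(G(z)))], on hand-picked discriminator outputs. The function name `gan_value` and the sample probabilities are illustrative assumptions, not values from this project.

```python
import math


def gan_value(d_real, d_fake):
    """Batch estimate of V(D, G).

    d_real holds D(x) for real images, d_fake holds D(G(z)) for generated
    images; every entry is a probability strictly between 0 and 1.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# A discriminator that separates real from generated images almost perfectly:
# V is slightly below 0, which is its upper bound.
strong_d = gan_value(d_real=[0.99, 0.98], d_fake=[0.01, 0.02])

# The ideal end state of training: D outputs 0.5 for every image,
# so V equals log(0.5) + log(0.5) = -2 * log(2).
fooled_d = gan_value(d_real=[0.5, 0.5], d_fake=[0.5, 0.5])

print(strong_d, fooled_d)
```

The discriminator pushes the value toward 0, and the generator pulls it back down; when D can do no better than guessing, the value settles at -2 log 2.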

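    How CycleGAN works

    Image-to-image translation

    In conventional CNN-based methods, image-to-image translation is carried out by a CNN that learns the transformation parameters.
    The CycleGAN algorithm used in this article instead generates the target image directly from the source image.

    CycleGAN

    Goal: learn the mapping between domain X and domain Y. The CycleGAN model contains two mappings, X->Y and Y->X, as shown in the figure below.


    [Figure: the CycleGAN network]

    The network has two generators, one for each direction of translation between the two domains, and one discriminator for each generator. The objective function has two terms:

    • Adversarial loss: constrains each generated image to look like an image from the target domain.
    [Figure: adversarial loss]
    • Cycle loss: keeps the two generators from contradicting each other, so that translating an image to the other domain and back recovers the original.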


      [Figure: cycle loss]
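
In the usual CycleGAN notation, the two terms and the full objective can be written as follows. The symbols are assumptions of this sketch: G maps X to Y, F maps Y to X, D_X and D_Y are the discriminators for the two domains, and lambda weights the cycle term.

```latex
\mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) =
  \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\log D_Y(y)\big]
  + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log\big(1 - D_Y(G(x))\big)\big]

\mathcal{L}_{\mathrm{cyc}}(G, F) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
  + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]

\mathcal{L}(G, F, D_X, D_Y) =
  \mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y)
  + \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X)
  + \lambda \, \mathcal{L}_{\mathrm{cyc}}(G, F)
```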

    In this project, CycleGAN is used to translate between the male and female domains.
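
The role of the cycle loss can be illustrated with a small sketch. It assumes an L1 (mean absolute difference) reconstruction error and uses toy functions on short lists of pixel values in place of trained generators; all names here are hypothetical.

```python
def l1_distance(a, b):
    """Mean absolute difference between two equally sized pixel lists."""
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def cycle_loss(x, y, g_xy, g_yx):
    """X -> Y -> X and Y -> X -> Y should both return to the starting image."""
    forward = l1_distance(g_yx(g_xy(x)), x)
    backward = l1_distance(g_xy(g_yx(y)), y)
    return forward + backward


# Toy stand-ins for the two generators (not trained networks).
def to_y(img):
    return [p + 0.5 for p in img]


def to_x_consistent(img):
    # Exactly undoes to_y, so a round trip recovers the input.
    return [p - 0.5 for p in img]


def to_x_contradicting(img):
    # Ignores its input, so a round trip cannot recover it.
    return [0.0 for _ in img]


x = [0.0, 0.25, 0.75]   # sample from domain X
y = [0.5, 0.75, 1.25]   # sample from domain Y

consistent = cycle_loss(x, y, to_y, to_x_consistent)
contradicting = cycle_loss(x, y, to_y, to_x_contradicting)
print(consistent, contradicting)
```

The pair of mutually inverse generators has zero cycle loss, while the pair in which one generator discards its input is penalized, which is the behavior the cycle term is meant to discourage.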

    Code walkthrough

    ### generator (layer summary: conv(kernel_h, kernel_w, out_channels), ngf = 32)
    conv(7, 7, 32)
    conv(3, 3, 64)
    conv(3, 3, 128)
    res_block * 9
    deconv(3, 3, 64)
    deconv(3, 3, 32)
    conv(7, 7, 3)
    
    ### discriminator (layer summary, ndf = 64)
    conv(4, 4, 64)
    conv(4, 4, 128)
    conv(4, 4, 256)
    conv(4, 4, 512)
    conv(4, 4, 1)
    
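    ### resnet_block
    import tensorflow as tf  # TensorFlow 1.x; layers.py uses tf.contrib
    import layers  # the layers.py module listed at the end of this section

    # Module-level settings used by the functions below. The values are
    # assumptions, chosen to agree with the layer summary and with the
    # 256x256 output shapes hard-coded in the generator.
    BATCH_SIZE = 1
    IMG_CHANNELS = 3
    ngf = 32  # base number of generator filters
    ndf = 64  # base number of discriminator filters

    def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
        """Two 3x3 convolutions plus a skip connection; spatial size is unchanged."""
        with tf.variable_scope(name):
            # Pad by 1 pixel so each 3x3 VALID convolution keeps the input size.
            out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
            out_res = layers.general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
            out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
            out_res = layers.general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)
            # Add the block input back in, then apply ReLU.
            return tf.nn.relu(out_res + inputres)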
    
    ### generator
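    def build_generator_resnet_9blocks_tf(inputgen, name="generator", skip=False):
        with tf.variable_scope(name):
            f = 7   # kernel size of the first and last convolutions
            ks = 3  # kernel size of the inner layers, also used as the pad width for c1
            padding = "REFLECT"

            # Reflect-pad by 3 so the 7x7 VALID convolution keeps the input size.
            pad_input = tf.pad(inputgen, [[0, 0], [ks, ks], [ks, ks], [0, 0]], padding)
            o_c1 = layers.general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02, name="c1")
            # Two stride-2 convolutions reduce height and width by a factor of 4.
            o_c2 = layers.general_conv2d(o_c1, ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")
            o_c3 = layers.general_conv2d(o_c2, ngf * 4, ks, ks, 2, 2, 0.02, "SAME", "c3")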
    
            o_r1 = build_resnet_block(o_c3, ngf * 4, "r1", padding)
            o_r2 = build_resnet_block(o_r1, ngf * 4, "r2", padding)
            o_r3 = build_resnet_block(o_r2, ngf * 4, "r3", padding)
            o_r4 = build_resnet_block(o_r3, ngf * 4, "r4", padding)
            o_r5 = build_resnet_block(o_r4, ngf * 4, "r5", padding)
            o_r6 = build_resnet_block(o_r5, ngf * 4, "r6", padding)
            o_r7 = build_resnet_block(o_r6, ngf * 4, "r7", padding)
            o_r8 = build_resnet_block(o_r7, ngf * 4, "r8", padding)
            o_r9 = build_resnet_block(o_r8, ngf * 4, "r9", padding)
    
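            # Two stride-2 transposed convolutions upsample back to the input size.
            # The shape argument is not used by layers.general_deconv2d.
            o_c4 = layers.general_deconv2d(o_r9, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c4")
            o_c5 = layers.general_deconv2d(o_c4, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5")
            # Final 7x7 convolution maps to image channels; tanh is applied afterwards.
            o_c6 = layers.general_conv2d(o_c5, IMG_CHANNELS, f, f, 1, 1, 0.02, "SAME", "c6", do_norm=False, do_relu=False)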
    
            if skip is True:
                out_gen = tf.nn.tanh(inputgen + o_c6, "t1")
            else:
                out_gen = tf.nn.tanh(o_c6, "t1")
    
            return out_gen
    
    ### discriminator
    def discriminator_tf(inputdisc, name="discriminator"):
        with tf.variable_scope(name):
            f = 4  # every layer uses a 4x4 kernel
            # Three stride-2 layers downsample by 8; leaky ReLU with slope 0.2.
            o_c1 = layers.general_conv2d(inputdisc, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2)
            o_c2 = layers.general_conv2d(o_c1, ndf * 2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2)
            o_c3 = layers.general_conv2d(o_c2, ndf * 4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)
            o_c4 = layers.general_conv2d(o_c3, ndf * 8, f, f, 1, 1, 0.02, "SAME", "c4", relufactor=0.2)
            # Final layer: one output channel, no normalization or activation,
            # so the result is a map of raw per-patch scores.
            o_c5 = layers.general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False, do_relu=False)
            return o_c5
    
    ### layers.py
    import tensorflow as tf

    def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
        """Leaky ReLU: x for x >= 0, leak * x otherwise (for 0 < leak < 1)."""
        with tf.variable_scope(name):
            if alt_relu_impl:
                # Algebraically equivalent form built from x and |x|.
                f1 = 0.5 * (1 + leak)
                f2 = 0.5 * (1 - leak)
                return f1 * x + f2 * tf.abs(x)
            else:
                return tf.maximum(x, leak * x)

    def instance_norm(x):
        """Normalize each sample and channel over its spatial axes (NHWC input)."""
        with tf.variable_scope("instance_norm"):
            epsilon = 1e-5
            mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
            # Learned per-channel scale and offset.
            scale = tf.get_variable(
                'scale', [x.get_shape()[-1]],
                initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))
            offset = tf.get_variable(
                'offset', [x.get_shape()[-1]],
                initializer=tf.constant_initializer(0.0))
            out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset
            return out

    def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,
                       padding="VALID", name="conv2d", do_norm=True, do_relu=True,
                       relufactor=0):
        """Convolution, then optional instance norm and (leaky) ReLU."""
        with tf.variable_scope(name):
            # Kernel size and stride are passed as [height, width].
            conv = tf.contrib.layers.conv2d(
                inputconv, o_d, [f_h, f_w], [s_h, s_w], padding,
                activation_fn=None,
                weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                biases_initializer=tf.constant_initializer(0.0))
            if do_norm:
                conv = instance_norm(conv)
            if do_relu:
                # relufactor == 0 selects a plain ReLU, any other value a leaky ReLU.
                if relufactor == 0:
                    conv = tf.nn.relu(conv, "relu")
                else:
                    conv = lrelu(conv, relufactor, "lrelu")
            return conv
    
    def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                         stddev=0.02, padding="VALID", name="deconv2d",
                         do_norm=True, do_relu=True, relufactor=0):
        """Transposed convolution, then optional instance norm and (leaky) ReLU."""
        with tf.variable_scope(name):
            # outshape is accepted but not used; the output size follows from
            # the stride and padding.
            conv = tf.contrib.layers.conv2d_transpose(
                inputconv, o_d, [f_h, f_w], [s_h, s_w], padding,
                activation_fn=None,
                weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                biases_initializer=tf.constant_initializer(0.0))
            if do_norm:
                conv = instance_norm(conv)
                # conv = tf.contrib.layers.batch_norm(conv, decay=0.9,
                # updates_collections=None, epsilon=1e-5, scale=True,
                # scope="batch_norm")
            if do_relu:
                if relufactor == 0:
                    conv = tf.nn.relu(conv, "relu")
                else:
                    conv = lrelu(conv, relufactor, "lrelu")
            return conv
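
As a check on the layer settings above, the following sketch traces the height (or width) of the feature map through the generator and the discriminator. It assumes a square 256x256 input, which is what the hard-coded deconvolution shapes suggest, and TensorFlow's usual size rules for SAME and VALID padding. The helper names are hypothetical, and no TensorFlow is needed to run it.

```python
import math


def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    if padding == "SAME":
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1  # "VALID"


def deconv_out(size, stride):
    """Output length of a SAME-padded transposed convolution."""
    return size * stride


def generator_sizes(size=256):
    """Spatial size after each generator stage, as (name, size) pairs."""
    trace = []
    size = conv_out(size + 2 * 3, 7, 1, "VALID")  # reflect pad 3, then 7x7
    trace.append(("c1", size))
    for name in ("c2", "c3"):  # stride-2 downsampling
        size = conv_out(size, 3, 2, "SAME")
        trace.append((name, size))
    for i in range(1, 10):  # residual blocks r1..r9
        for _ in range(2):  # pad 1, then 3x3 VALID, twice per block
            size = conv_out(size + 2 * 1, 3, 1, "VALID")
        trace.append(("r%d" % i, size))
    for name in ("c4", "c5"):  # stride-2 upsampling
        size = deconv_out(size, 2)
        trace.append((name, size))
    size = conv_out(size, 7, 1, "SAME")
    trace.append(("c6", size))
    return trace


def discriminator_sizes(size=256):
    """Spatial size after each discriminator layer, as (name, size) pairs."""
    trace = []
    for name, stride in (("c1", 2), ("c2", 2), ("c3", 2), ("c4", 1), ("c5", 1)):
        size = conv_out(size, 4, stride, "SAME")
        trace.append((name, size))
    return trace


gen = dict(generator_sizes(256))
disc = dict(discriminator_sizes(256))
print(gen["c1"], gen["c3"], gen["r9"], gen["c4"], gen["c6"])  # 256 64 64 128 256
print(disc["c3"], disc["c5"])  # 32 32
```

Under these assumptions the generator returns an image of the same size as its input, and the discriminator produces a 32x32 map of scores, one per image patch, instead of a single real-or-fake value.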
    

    Test results:


    [Figure: test results]

          Original article: https://www.haomeiwen.com/subject/lvgiqqtx.html