美文网首页
基于CycleGAN的性别变换方法

基于CycleGAN的性别变换方法

作者: LuDon | 来源:发表于2018-12-04 16:24 被阅读45次

GAN的简介

近年来,GAN(生成对抗式网络)成功地应用于图像生成、图像编辑和和表达学习等方面。最小化对抗损失使得生成的图像看起来真实。GAN的基本原理为:

  • 生成器G是生成图片的网络,接收一个随机的噪声z,生成图片G(z)。其目标是尽量生成真实的图片去欺骗判别网络D。
  • 判别器D是判别一张图片是否为真实。输入一张图片x,输出D(x)为x为真实图片的概率。其目的是尽量把生成器生成的图片和真实的图片区别出来。


    GAN网络

在理想情况下,生成器可以生成足以以假乱真的图片。而判别器难以辨别生成器生成的图片是否为真。

GAN的损失函数为:


GAN的损失函数

CycleGAN原理

图像与图像之间的变换

在传统的CNN方法中,图像与图像之间的变换是通过CNN来学习转移参数。
而本文的cycleGAN算法可以直接从一个图像生成另一个图像来实现图像之间的变换。

CycleGAN

目的:学习域X与域Y之间的映射关系。在CycleGAN模型中包括两个映射:X->Y, Y->X。如下图所示。


CycleGAN网络

在该网络中,存在两个域之间分别转换的生成器,以及每个生成器对应的判别器。目标函数中包括两项:

  • 对抗损失:使用控制生成的图像为目标域的图像。
对抗损失
  • cycle loss:为了防止两个生成器之间是相互矛盾的。


    cycle损失

在本项目中用来实现男女性别两个域之间的转换。

代码解析

### generator
conv(7, 7, 32)
conv(3, 3, 64)
conv(3, 3, 128)
res_block * 6 
deconv(3, 3, 64)
deconv(3, 3, 32)
conv(7, 7, 3)

### discriminator
conv(3, 3, 64)
conv(3, 3, 128)
conv(3, 3, 256)
conv(3, 3, 512)
conv(4, 4, 512)
### resnet_block
def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
    with tf.variable_scope(name):
        out_res = tf.pad(inputres, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")
        out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], padding)
        out_res = layers.general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)
        return tf.nn.relu(out_res + inputres)
### generator
def build_generator_resnet_9blocks_tf(inputgen, name="generator", skip=False):
    with tf.variable_scope(name):
        f = 7
        ks = 3
        padding = "REFLECT"

        pad_input = tf.pad(inputgen, [[0, 0], [ks, ks], [ ks, ks], [0, 0]], padding)
        o_c1 = layers.general_conv2d(pad_input, ngf, f, f, 1, 1, 0.02, name="c1")
        o_c2 = layers.general_conv2d(o_c1, ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c2")
        o_c3 = layers.general_conv2d(o_c2, ngf * 4, ks, ks, 2, 2, 0.02, "SAME", "c3")

        o_r1 = build_resnet_block(o_c3, ngf * 4, "r1", padding)
        o_r2 = build_resnet_block(o_r1, ngf * 4, "r2", padding)
        o_r3 = build_resnet_block(o_r2, ngf * 4, "r3", padding)
        o_r4 = build_resnet_block(o_r3, ngf * 4, "r4", padding)
        o_r5 = build_resnet_block(o_r4, ngf * 4, "r5", padding)
        o_r6 = build_resnet_block(o_r5, ngf * 4, "r6", padding)
        o_r7 = build_resnet_block(o_r6, ngf * 4, "r7", padding)
        o_r8 = build_resnet_block(o_r7, ngf * 4, "r8", padding)
        o_r9 = build_resnet_block(o_r8, ngf * 4, "r9", padding)

        o_c4 = layers.general_deconv2d(o_r9, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2, ks, ks, 2, 2, 0.02, "SAME", "c4")
        o_c5 = layers.general_deconv2d(o_c4, [BATCH_SIZE, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02,"SAME", "c5")
        o_c6 = layers.general_conv2d(o_c5, IMG_CHANNELS, f, f, 1, 1, 0.02, "SAME", "c6",do_norm=False, do_relu=False)

        if skip is True:
            out_gen = tf.nn.tanh(inputgen + o_c6, "t1")
        else:
            out_gen = tf.nn.tanh(o_c6, "t1")

        return out_gen
### discriminator
def discriminator_tf(inputdisc, name="discriminator"):
    with tf.variable_scope(name):
        f = 4
        o_c1 = layers.general_conv2d(inputdisc, ndf, f, f, 2, 2,0.02, "SAME", "c1", do_norm=False, relufactor=0.2)
        o_c2 = layers.general_conv2d(o_c1, ndf * 2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2)
        o_c3 = layers.general_conv2d(o_c2, ndf * 4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)
        o_c4 = layers.general_conv2d(o_c3, ndf * 8, f, f, 1, 1,0.02, "SAME", "c4", relufactor=0.2)
        o_c5 = layers.general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False, do_relu=False
        )
        return o_c5
### layers.py
import tensorflow as tf
def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    with tf.variable_scope(name):
        if alt_relu_impl:
            f1 = 0.5 * (1 + leak)
            f2 = 0.5 * (1 - leak)
            return f1 * x + f2 * abs(x)
        else:
            return tf.maximum(x, leak * x)
def instance_norm(x):
    with tf.variable_scope("instance_norm"):
        epsilon = 1e-5
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        scale = tf.get_variable('scale', [x.get_shape()[-1]], initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02
        ))
        offset = tf.get_variable('offset', [x.get_shape()[-1]], initializer=tf.constant_initializer(0.0)
        )
        out = scale * tf.div(x - mean, tf.sqrt(var + epsilon)) + offset
        return out
def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,
                   padding="VALID", name="conv2d", do_norm=True, do_relu=True,
                   relufactor=0):
    with tf.variable_scope(name):

        conv = tf.contrib.layers.conv2d( inputconv, o_d, f_w, s_w, padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev
            ), biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            conv = instance_norm(conv)
        if do_relu:
            if(relufactor == 0):
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")
        return conv

def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                     stddev=0.02, padding="VALID", name="deconv2d",
                     do_norm=True, do_relu=True, relufactor=0):
    with tf.variable_scope(name):

        conv = tf.contrib.layers.conv2d_transpose(inputconv, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None,weights_initializer=tf.truncated_normal_initializer(stddev=stddev), biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            conv = instance_norm(conv)
            # conv = tf.contrib.layers.batch_norm(conv, decay=0.9,
            # updates_collections=None, epsilon=1e-5, scale=True,
            # scope="batch_norm")
        if do_relu:
            if(relufactor == 0):
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")
        return conv

测试的结果:


测试的结果

相关文章

  • 基于CycleGAN的性别变换方法

    GAN的简介 近年来,GAN(生成对抗式网络)成功地应用于图像生成、图像编辑和和表达学习等方面。最小化对抗损失使得...

  • 自然语言处理——7.8 词性标注方法

    · 基于规则的词性标注方法· 基于统计模型的词性标注方法· 规则和统计方法相结合的词性标注方法· 基于有限状态变换...

  • Deep-Learning-with-PyTorch-2.2.2

    2.2.2 循环GAN(CycleGAN) 这个概念的一个有趣的演变是CycleGAN。 CycleGAN可以将一...

  • 行人检测之初识

    行人检测,现在有基于全局特征的方法,基于人体部位的,基于立体的。 基于全局的是从边缘特征,形状特征,统计特征或变换...

  • 每天一个知识点(九)

    车道线的检测方法有很多,基于视觉的车道线检测有一下几种:基于霍夫之间检测、基于LSD直线检测、基于俯视图变换的车道...

  • "ColorGAN":paper code学

    2018-09-03关于CycleGAN优化发现 Augmented CycleGAN: Learning Man...

  • 跑通你的第一个深度学习程序(含环境搭建)。2021-09-15

    测试项目为CycleGAN项目源地址:https://github.com/xhujoy/CycleGAN-ten...

  • 6 变治法

    本章讨论的是一组设计方法,它们都基于变换的思想,我们把这种通用技术成为变治法。 根据我们对问题的变换方案,变治思想...

  • 列举几种常见的图像数据增广

    基于图像处理: 几何变换( Geometric Transformations)翻转变换( Flipping)移动...

  • CycleGAN

    前段时间看了一下CycleGan网络,最近还是想稍微总结一下,毕竟如果不记下的话很容易忘记。个人觉得玩Gan网络就...

网友评论

      本文标题:基于CycleGAN的性别变换方法

      本文链接:https://www.haomeiwen.com/subject/lvgiqqtx.html