美文网首页
CycleGAN 代码

CycleGAN 代码

作者: 晨光523152 | 来源:发表于2020-05-06 10:40 被阅读0次

五一前读完了 CycleGAN 的论文,今天把代码补上。遗憾的是,这个程序虽然能运行,但由于计算资源有限,我没法跑完 200 个 epoch,所以没办法展示最终的结果。

TensorFlow 官方给出了代码,但其中生成器用的是(pix2pix 中改进的)U-Net,判别器用的是 PatchGAN,与原文有所出入;其中对抗损失函数部分与原文出入较大:官方教程使用的是二元交叉熵损失,而原文使用的是最小二乘(LSGAN)损失。
tensorflow 官方的传送门:https://www.tensorflow.org/tutorials/generative/cyclegan

我用的代码如下:

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.regularizers import l2
import numpy as np
import math
# Small constant added to the variance inside the hand-written instance
# normalization (sqrt(var + epsilon)) so the division never hits zero.
epsilon = 1e-5
def conv(numout, kernel_size=3, strides=1, kernel_regularizer=0.0005,
         padding='same', use_bias=False, name='conv'):
    """Build a Conv2D layer with an L2 kernel regularizer and N(0, 0.1) init.

    Args:
        numout: number of output filters.
        kernel_size: convolution window size.
        strides: convolution stride(s).
        kernel_regularizer: L2 factor (a float; wrapped in ``l2(...)`` here).
        padding: Keras padding mode, e.g. 'same' or 'valid'.
        use_bias: whether the layer adds a bias term.
        name: Keras layer name.

    Returns:
        An unbuilt ``tf.keras.layers.Conv2D``.
    """
    layer_kwargs = dict(
        name=name,
        filters=numout,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
        kernel_regularizer=l2(kernel_regularizer),
        kernel_initializer=tf.random_normal_initializer(stddev=0.1),
    )
    return tf.keras.layers.Conv2D(**layer_kwargs)

def convt(numout, kernel_size=3, strides=1, kernel_regularizer=0.0005,
          padding='same', use_bias=False, name='conv'):
    """Build a Conv2DTranspose layer with an L2 kernel regularizer and N(0, 0.1) init.

    Args:
        numout: number of output filters.
        kernel_size: transposed-convolution window size.
        strides: upsampling stride(s).
        kernel_regularizer: L2 factor (a float; wrapped in ``l2(...)`` here).
        padding: Keras padding mode, e.g. 'same' or 'valid'.
        use_bias: whether the layer adds a bias term.
        name: Keras layer name.

    Returns:
        An unbuilt ``tf.keras.layers.Conv2DTranspose``.
    """
    layer_kwargs = dict(
        name=name,
        filters=numout,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=use_bias,
        kernel_regularizer=l2(kernel_regularizer),
        kernel_initializer=tf.random_normal_initializer(stddev=0.1),
    )
    return tf.keras.layers.Conv2DTranspose(**layer_kwargs)

def bn(name, momentum=0.9):
    """Return a Keras BatchNormalization layer called *name* with the given momentum."""
    layer = tf.keras.layers.BatchNormalization(momentum=momentum, name=name)
    return layer
class c7s1_k(keras.Model):
    """``c7s1-k`` block: reflect pad 3 -> 7x7 stride-1 'valid' conv -> norm -> activation.

    The 3-pixel reflect padding on both spatial axes exactly compensates for
    the 'valid' 7x7 convolution, so height and width are preserved.

    Args:
        scope: model name.
        k: number of output channels.
        reg: L2 regularization factor for the convolution kernel.
        norm: 'instance' -> instance normalization with learned per-channel
            scale/offset; 'bn' -> batch normalization; any other value -> no
            normalization.
    """

    def __init__(self, scope: str = "c7s1_k", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(c7s1_k, self).__init__(name=scope)
        self.conv1 = conv(numout=k, kernel_size=7, kernel_regularizer=reg, padding='valid', name='conv')
        self.norm = norm
        # Compare strings with ``==``. ``is`` tests object identity: it only
        # works for interned literals (and is a SyntaxWarning since Python 3.8).
        # An equal string built at runtime would silently get no normalization.
        if norm == 'instance':
            # NOTE(review): scale is drawn from N(0, 0.1), so normalized
            # activations start close to zero; instance-norm scale is often
            # initialized around 1 -- confirm this is intended.
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')

    def call(self, x, training=False, activation='Relu'):
        """Apply the block.

        Args:
            x: 4-D input; the padding pattern and the moments over axes
                [1, 2] assume (batch, height, width, channels) layout.
            training: forwarded to batch normalization.
            activation: 'Relu' selects ReLU; any other value selects tanh.
        """
        x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')
        x = self.conv1(x)
        if self.norm == 'instance':
            # Instance norm: per-sample, per-channel statistics over H and W.
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        if activation == 'Relu':
            x = tf.nn.relu(x)
        else:
            x = tf.nn.tanh(x)
        return x

class dk(keras.Model):
    """``dk`` downsampling block: 3x3 stride-2 'same' conv -> norm -> ReLU.

    The stride-2 'same' convolution halves height and width (rounding up).

    Args:
        scope: model name.
        k: number of output channels.
        reg: L2 regularization factor for the convolution kernel.
        norm: 'instance', 'bn', or any other value for no normalization.
    """

    def __init__(self, scope: str = "dk", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(dk, self).__init__(name=scope)
        self.norm = norm
        self.conv1 = conv(numout=k, kernel_size=3, strides=[2, 2], kernel_regularizer=reg, padding='same', name='conv')
        # ``==`` instead of ``is``: string identity is an interning accident.
        if norm == 'instance':
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')

    def call(self, x, training=False):
        """Downsample ``x`` (assumed NHWC); ``training`` is forwarded to batch norm."""
        x = self.conv1(x)
        if self.norm == 'instance':
            # Instance norm: per-sample, per-channel statistics over H and W.
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        return x

class Rk(keras.Model):
    """``Rk`` residual block with two reflect-padded 3x3 convs and a skip connection.

    Layout: pad 1 -> conv -> norm -> ReLU -> pad 1 -> conv -> norm, then the
    block input is added back. There is no activation after the addition.
    The 1-pixel reflect padding plus 'valid' 3x3 convs preserve the spatial
    size, so ``k`` must equal the input channel count for the addition.

    Args:
        scope: model name.
        k: number of channels (input and output).
        reg: L2 regularization factor for both convolution kernels.
        norm: 'instance', 'bn', or any other value for no normalization.
    """

    def __init__(self, scope: str = "Rk", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(Rk, self).__init__(name=scope)
        self.norm = norm
        self.conv1 = conv(numout=k, kernel_size=3, kernel_regularizer=reg, padding='valid', name='layer1/conv')
        # ``==`` instead of ``is``: string identity is an interning accident.
        if norm == 'instance':
            self.scale1 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer1/scale')
            self.offset1 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer1/offset')
            self.scale2 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer2/scale')
            self.offset2 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer2/offset')
        elif norm == 'bn':
            self.bn1 = bn(name='layer1/bn')
            self.bn2 = bn(name='layer2/bn')
        self.conv2 = conv(numout=k, kernel_size=3, kernel_regularizer=reg, padding='valid', name='layer2/conv')

    def call(self, x, training=False):
        """Return ``x + residual(x)`` for an NHWC input (layout assumed from the padding)."""
        inputs = x
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale1 * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset1
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
        x = self.conv2(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale2 * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset2
        elif self.norm == 'bn':
            x = self.bn2(x, training=training)
        return x + inputs

class n_res_blocks(keras.Model):
    """A chain of ``n`` ``Rk`` residual blocks with ``k`` channels, applied in order.

    Sub-blocks are named ``Rk_1`` ... ``Rk_n``.
    """

    def __init__(self, scope: str = "n_res_blocks", n: int = 6, k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(n_res_blocks, self).__init__(name=scope)
        self.group = [
            Rk(scope='Rk_' + str(index), k=k, reg=reg, norm=norm)
            for index in range(1, n + 1)
        ]
        self.norm = norm

    def call(self, x, training=False):
        """Feed ``x`` through every residual block in sequence."""
        for block in self.group:
            x = block(x, training=training)
        return x

class uk(keras.Model):
    """``uk`` upsampling block: 3x3 stride-2 'same' transposed conv -> norm -> ReLU.

    The stride-2 'same' transposed convolution doubles height and width.
    An earlier (removed, previously commented-out) variant instead resized
    the input and applied a reflect-padded 'valid' 3x3 convolution.

    Args:
        scope: model name.
        k: number of output channels.
        reg: L2 regularization factor for the kernel.
        norm: 'instance', 'bn', or any other value for no normalization.
    """

    def __init__(self, scope: str = "uk", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(uk, self).__init__(name=scope)
        self.norm = norm
        self.conv1 = convt(numout=k, kernel_size=3, strides=[2, 2], kernel_regularizer=reg, padding='same', name='conv')
        # ``==`` instead of ``is``: string identity is an interning accident.
        if norm == 'instance':
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')

    def call(self, x, training=False):
        """Upsample ``x`` (assumed NHWC); ``training`` is forwarded to batch norm."""
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        return x

class Ck(keras.Model):
    """``Ck`` discriminator block: 3x3 'same' conv (stride ``stride``) -> norm -> LeakyReLU.

    Args:
        scope: model name. NOTE(review): the default "uk" looks copy-pasted
            from the ``uk`` class; the Discriminator in this file always passes
            an explicit scope, so the default is left unchanged for signature
            stability.
        k: number of output channels.
        stride: convolution stride along both spatial axes.
        reg: L2 regularization factor for the kernel.
        norm: 'instance', 'bn', or any other value for no normalization.
    """

    def __init__(self, scope: str = "uk", k: int = 16, stride: int = 2, reg: float = 0.0005, norm: str = "instance"):
        super(Ck, self).__init__(name=scope)
        self.norm = norm
        self.conv1 = conv(numout=k, kernel_size=3, strides=[stride, stride], kernel_regularizer=reg, padding='same', name='conv')
        # ``==`` instead of ``is``: string identity is an interning accident.
        if norm == 'instance':
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')

    def call(self, x, training=False, slope=0.2):
        """Apply the block; ``slope`` is the LeakyReLU negative slope."""
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.leaky_relu(x, slope)
        return x

class last_conv(keras.Model):
    """Final discriminator layer: 4x4 stride-1 'same' conv down to one channel.

    Args:
        scope: model name.
        reg: L2 regularization factor for the kernel.
    """

    def __init__(self, scope: str = "last_conv", reg: float = 0.0005):
        super(last_conv, self).__init__(name=scope)
        self.conv1 = conv(numout=1, kernel_size=4, kernel_regularizer=reg, padding='same', name='conv')

    def call(self, x, use_sigmoid=False):
        """Return the raw one-channel map, or its sigmoid when ``use_sigmoid`` is True.

        Previously the sigmoid was stored in an unused local and the raw
        output was always returned, so ``use_sigmoid=True`` had no effect.
        The default (False) path is unchanged.
        """
        x = self.conv1(x)
        if use_sigmoid:
            return tf.nn.sigmoid(x)
        return x

class Discriminator(keras.Model):
    """Fully convolutional discriminator: C64 -> C128 -> C256 -> C512 -> 1-channel conv.

    Each ``Ck`` block uses its default stride of 2, so the returned score map
    is 1/16 of the input height and width (rounded up), with one channel.

    Args:
        scope: model name.
        reg: L2 regularization factor passed to every convolution.
        norm: normalization mode forwarded to each ``Ck`` block.
    """

    def __init__(self, scope: str = "Discriminator", reg: float = 0.0005, norm: str = "instance"):
        super(Discriminator, self).__init__(name=scope)
        # Attribute names are kept as-is: object-based checkpoints key on them.
        self.ck1 = Ck(scope="C64", k=64, reg=reg, norm=norm)
        self.ck2 = Ck(scope="C128", k=128, reg=reg, norm=norm)
        self.ck3 = Ck(scope="C256", k=256, reg=reg, norm=norm)
        self.ck4 = Ck(scope="C512", k=512, reg=reg, norm=norm)
        self.last_conv = last_conv(scope="output", reg=reg)

    def call(self, x, training=False, use_sigmoid=False, slope=0.2):
        """Score ``x``; ``slope`` is the LeakyReLU slope shared by all Ck blocks."""
        for block in (self.ck1, self.ck2, self.ck3, self.ck4):
            x = block(x, training=training, slope=slope)
        return self.last_conv(x, use_sigmoid=use_sigmoid)

class Generator(keras.Model):
    """Residual-block generator.

    Pipeline: c7s1 (ngf) -> d (2*ngf) -> d (4*ngf) -> 8 or 6 residual blocks
    (4*ngf) -> u (2*ngf) -> u (ngf) -> c7s1 to 3 channels with tanh.

    NOTE: the attribute/scope names (c7s1_32, d64, d128, u64, u32) are fixed
    labels; the actual channel widths scale with ``ngf``.

    Args:
        scope: model name.
        ngf: base number of filters.
        reg: L2 regularization factor passed to every convolution.
        norm: normalization mode forwarded to every block (including the
            output block).
        more: True -> 8 residual blocks, False -> 6.
    """

    def __init__(self, scope: str = "Generator", ngf: int = 64, reg: float = 0.0005, norm: str = "instance", more: bool = True):
        super(Generator, self).__init__(name=scope)
        shared = dict(reg=reg, norm=norm)
        self.c7s1_32 = c7s1_k(scope="c7s1_32", k=ngf, **shared)
        self.d64 = dk(scope="d64", k=2 * ngf, **shared)
        self.d128 = dk(scope="d128", k=4 * ngf, **shared)
        res_scope, res_count = ("8_res_blocks", 8) if more else ("6_res_blocks", 6)
        self.res_output = n_res_blocks(scope=res_scope, n=res_count, k=4 * ngf, **shared)
        self.u64 = uk(scope="u64", k=2 * ngf, **shared)
        self.u32 = uk(scope="u32", k=ngf, **shared)
        self.outconv = c7s1_k(scope="output", k=3, **shared)

    def call(self, x, training=False):
        """Translate ``x`` to the other domain; output passes through tanh."""
        x = self.c7s1_32(x, training=training, activation='Relu')
        for stage in (self.d64, self.d128, self.res_output, self.u64, self.u32):
            x = stage(x, training=training)
        return self.outconv(x, training=training, activation='tanh')
# Load the CycleGAN horse2zebra dataset from TensorFlow Datasets.
# trainA/testA are bound to the horse splits (domain X) and trainB/testB to the
# zebra splits (domain Y). With as_supervised=True the elements are presumably
# (image, label) pairs -- the preprocess_* functions mapped over these splits
# take two arguments. ``metadata`` is not used further in this file.
import tensorflow_datasets as tfds
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

def decodefortrain(img):
    """Load a PNG from the path ``img`` and apply random training augmentation.

    Steps: read + decode as 3-channel RGB -> float32 -> resize to a random
    fraction s in [0.25, 0.5) of 1024x2048 (height x width) -> random crop to
    256x512 -> random horizontal flip -> scale [0, 255] to [-1, 1].

    NOTE(review): this expects a file path; the tfds splits loaded in this
    file appear to hold decoded images instead, and the 2:1 crop geometry does
    not match 256x256 horse2zebra images -- confirm the intended input.
    """
    raw_bytes = tf.io.read_file(img)
    image = tf.cast(tf.image.decode_png(raw_bytes, channels=3), dtype=tf.float32)

    # Random target size: floor(s * 1024) x floor(s * 2048).
    scale = tf.random.uniform([1], minval=0.25, maxval=0.5, dtype=tf.float32)
    new_size = tf.cast(
        tf.concat([tf.floor(scale * 1024), tf.floor(scale * 2048)], 0),
        dtype=tf.int32,
    )
    # method=0 is ResizeMethod.BILINEAR in tf.compat.v1.
    image = tf.compat.v1.image.resize_images(image, new_size, method=0, align_corners=True)

    image = tf.image.random_crop(image, [256, 512, 3])
    image = tf.image.random_flip_left_right(image)
    return (image / 255) * 2 - 1


def source_data(batchsize=1, files=None):
    """Build the domain-X training pipeline: shuffle(50) -> decodefortrain -> batch.

    The original always called ``from_tensor_slices(train_horses)``, which
    raises because ``train_horses`` is already a ``tf.data.Dataset``; a
    Dataset is now used directly.

    Args:
        batchsize: number of images per batch.
        files: image file paths (anything ``tf.data.Dataset.from_tensor_slices``
            accepts) or a ``tf.data.Dataset`` of paths. Defaults to the
            module-level ``train_horses``.

    Returns:
        A batched ``tf.data.Dataset``.

    NOTE(review): ``decodefortrain`` reads a file path, while the tfds splits
    in this file presumably hold decoded (image, label) pairs, so the default
    only works if ``train_horses`` is rebound to a dataset of paths.
    """
    if files is None:
        files = train_horses
    if isinstance(files, tf.data.Dataset):
        paths = files
    else:
        paths = tf.data.Dataset.from_tensor_slices(files)
    return paths.shuffle(50).map(decodefortrain).batch(batchsize)
    
def target_data(batchsize=1, files=None):
    """Build the domain-Y training pipeline: shuffle(50) -> decodefortrain -> batch.

    The original always called ``from_tensor_slices(train_zebras)``, which
    raises because ``train_zebras`` is already a ``tf.data.Dataset``; a
    Dataset is now used directly.

    Args:
        batchsize: number of images per batch.
        files: image file paths (anything ``tf.data.Dataset.from_tensor_slices``
            accepts) or a ``tf.data.Dataset`` of paths. Defaults to the
            module-level ``train_zebras``.

    Returns:
        A batched ``tf.data.Dataset``.

    NOTE(review): ``decodefortrain`` reads a file path, while the tfds splits
    in this file presumably hold decoded (image, label) pairs, so the default
    only works if ``train_zebras`` is rebound to a dataset of paths.
    """
    if files is None:
        files = train_zebras
    if isinstance(files, tf.data.Dataset):
        paths = files
    else:
        paths = tf.data.Dataset.from_tensor_slices(files)
    return paths.shuffle(50).map(decodefortrain).batch(batchsize)

# Preprocessing for the tfds horse2zebra splits, following the official
# TensorFlow CycleGAN tutorial. ``preprocess_image_train`` and
# ``preprocess_image_test`` were referenced but never defined in the listing,
# which made this section fail with a NameError.
IMG_HEIGHT = 256
IMG_WIDTH = 256


def _normalize(image):
    """Cast to float32 and map pixel values from [0, 255] to [-1, 1]."""
    image = tf.cast(image, tf.float32)
    return (image / 127.5) - 1


def _random_jitter(image):
    """Resize to 286x286 (nearest neighbour), random-crop to 256x256, random mirror."""
    image = tf.image.resize(image, [286, 286],
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
    return tf.image.random_flip_left_right(image)


def preprocess_image_train(image, label):
    """Training transform for an (image, label) element: jitter + normalize; label dropped."""
    return _normalize(_random_jitter(image))


def preprocess_image_test(image, label):
    """Test transform for an (image, label) element: normalize only; label dropped."""
    return _normalize(image)


# Training splits: cache() BEFORE the random map. Caching after it would
# store one set of random crops/flips and replay it every epoch, silently
# disabling augmentation.
train_horses = train_horses.cache().map(
    preprocess_image_train).shuffle(
    50).batch(1)

train_zebras = train_zebras.cache().map(
    preprocess_image_train).shuffle(
    50).batch(1)

# Test splits: the transform is deterministic, so caching after it is fine.
test_horses = test_horses.map(
    preprocess_image_test).cache().shuffle(
    50).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test).cache().shuffle(
    50).batch(1)
def lr_sch(epoch):
    """Learning rate for the 0-based epoch index ``epoch``.

    For 1-based epochs 1..100 the rate is 2e-4. From epoch 101 it decays
    linearly and reaches 2e-8 at epoch 200.

    Returns:
        A scalar float32 tensor.
    """
    epoch_number = epoch + 1
    if epoch_number > 100:
        remaining = 1 - (epoch_number - 100) / 100
        return tf.constant(2e-4 - 2e-8) * remaining + 2e-8
    return tf.constant(2e-4)

def train_step(G,F,D_Y,D_X,source,target,generator_g_optimizer,generator_f_optimizer,discriminator_x_optimizer,discriminator_y_optimizer,train_loss,lambda1,lambda2):
    """Run one CycleGAN optimization step on a (source, target) batch pair.

    G maps domain X (``source``) to domain Y (``target``); F maps Y to X.
    D_X scores X images and D_Y scores Y images. Losses (least-squares GAN):

    * adversarial: mean((D_Y(G(x)) - 1)^2) for G, mean((D_X(F(y)) - 1)^2) for F;
    * cycle: lambda1 * mean|F(G(x)) - x| + lambda2 * mean|G(F(y)) - y|,
      added to both generator losses;
    * identity: 0.5 * lambda2 * mean|G(y) - y| for G and
      0.5 * lambda1 * mean|F(x) - x| for F;
    * discriminators: 0.5 * mean((D(real) - 1)^2 + D(fake)^2).

    Layer regularization losses (``G.losses`` etc.) are not added.

    Args:
        G, F: generators X->Y and Y->X.
        D_Y, D_X: discriminators for domains Y and X.
        source, target: batches from X and Y.
        generator_g_optimizer, generator_f_optimizer,
        discriminator_x_optimizer, discriminator_y_optimizer: one optimizer
            per model.
        train_loss: four running-mean metrics, updated with the G, F, D_X and
            D_Y losses in that order.
        lambda1, lambda2: cycle-consistency weights for X and Y.

    Returns:
        (fake_y, cycled_x, fake_x, cycled_y, same_x, same_y,
         total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss,
         generator_g_optimizer.iterations)
    """
    # persistent=True: four separate gradient calls are made on this tape.
    with tf.GradientTape(persistent=True) as tape:
        # Generators.
        fake_y = G(source, training=True)      # X -> G(X)
        cycled_x = F(fake_y, training=True)    # G(X) -> F(G(X))

        fake_x = F(target, training=True)      # Y -> F(Y)
        cycled_y = G(fake_x, training=True)    # F(Y) -> G(F(Y))
        # Outputs for the identity loss.
        same_x = F(source, training=True)
        same_y = G(target, training=True)
        # Discriminators.
        disc_real_x = D_X(source, training=True)   # D_X(X)
        disc_real_y = D_Y(target, training=True)   # D_Y(Y)

        disc_fake_x = D_X(fake_x, training=True)   # D_X(F(Y))
        disc_fake_y = D_Y(fake_y, training=True)   # D_Y(G(X))
        # LSGAN generator losses: (D(fake) - 1)^2.
        gen_g_loss = tf.reduce_mean(tf.math.squared_difference(disc_fake_y, 1))  # (D_Y(G(X)) - 1)^2
        gen_f_loss = tf.reduce_mean(tf.math.squared_difference(disc_fake_x, 1))  # (D_X(F(Y)) - 1)^2
        # Cycle-consistency loss: |F(G(X)) - X| * lambda1 + |G(F(Y)) - Y| * lambda2.
        total_cycle_loss = tf.reduce_mean(tf.abs(source - cycled_x)) * lambda1 + tf.reduce_mean(tf.abs(target - cycled_y)) * lambda2
        # Total generator loss = adversarial + cycle + identity.
        total_gen_g_loss = gen_g_loss + total_cycle_loss + tf.reduce_mean(tf.abs(target - same_y)) * 0.5 * lambda2
        total_gen_f_loss = gen_f_loss + total_cycle_loss + tf.reduce_mean(tf.abs(source - same_x)) * 0.5 * lambda1
        # LSGAN discriminator losses.
        disc_x_loss = 0.5 * tf.reduce_mean(tf.math.squared_difference(disc_real_x, 1) + tf.math.square(disc_fake_x))
        disc_y_loss = 0.5 * tf.reduce_mean(tf.math.squared_difference(disc_real_y, 1) + tf.math.square(disc_fake_y))
        # Regularization losses are intentionally not included:
        #loss_reg = tf.reduce_sum(G.losses)+tf.reduce_sum(F.losses)+tf.reduce_sum(D_Y.losses)+tf.reduce_sum(D_X.losses)

    train_loss[0](total_gen_g_loss)
    train_loss[1](total_gen_f_loss)
    train_loss[2](disc_x_loss)
    train_loss[3](disc_y_loss)
    generator_g_gradients = tape.gradient(total_gen_g_loss, G.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, F.trainable_variables)
    discriminator_x_gradients = tape.gradient(disc_x_loss, D_X.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, D_Y.trainable_variables)
    # A persistent tape keeps its recorded resources alive until dropped.
    del tape

    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, G.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, F.trainable_variables))
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, D_X.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, D_Y.trainable_variables))

    return fake_y, cycled_x, fake_x, cycled_y, same_x, same_y, total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss, generator_g_optimizer.iterations
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
weight_decay = 0          # L2 factor passed as ``reg`` to every conv (0 disables it)
learning_rate = 2e-4      # initial Adam learning rate
batch_size = 1
# NOTE: lr_sch only starts decaying after epoch 100, so with epoch = 100 the
# learning rate stays at 2e-4 throughout; the full schedule needs epoch = 200.
epoch = 100
#log_dir = FLAGS.logdir
use_lsgan = True          # informational only: train_step always uses LSGAN losses
norm = 'instance'
lambda1 = 10              # cycle-consistency weight for X -> Y -> X
lambda2 = 10              # cycle-consistency weight for Y -> X -> Y
beta1 = 0.5               # Adam beta_1
ngf = 64                  # base number of generator filters

# G: X -> Y (horse -> zebra), F: Y -> X; D_Y scores Y images, D_X scores X images.
G = Generator('G', ngf, weight_decay, norm=norm, more=True)
F = Generator('F', ngf, weight_decay, norm=norm, more=True)
D_Y = Discriminator('D_Y', reg=weight_decay, norm=norm)
D_X = Discriminator('D_X', reg=weight_decay, norm=norm)

# Run one dummy batch through every model so all variables exist before
# training (and before any checkpoint restore).
forbuild = np.random.rand(1, 256, 256, 3).astype(np.float32)
built = G(forbuild)
built = F(forbuild)
built = D_Y(forbuild)
built = D_X(forbuild)

# The original listing used undefined names ``source1`` / ``target1``; the
# preprocessed, batched tfds training splits fill that role. Each "epoch"
# below iterates 20 passes over the horse split.
#source_data=source_data(batch_size)
source_data = train_horses.repeat(20)
#target_data=target_data(batch_size)
target_it = iter(train_zebras.repeat(-1))

train_loss_G = tf.keras.metrics.Mean('train_loss_G', dtype=tf.float32)
train_loss_F = tf.keras.metrics.Mean('train_loss_F', dtype=tf.float32)
train_loss_DX = tf.keras.metrics.Mean('train_loss_DX', dtype=tf.float32)
train_loss_DY = tf.keras.metrics.Mean('train_loss_DY', dtype=tf.float32)
train_loss = [train_loss_G, train_loss_F, train_loss_DX, train_loss_DY]

# Use the hyper-parameters declared above instead of repeating literals.
generator_g_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)
generator_f_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)

discriminator_x_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)
discriminator_y_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)
optimizers = (generator_g_optimizer, generator_f_optimizer,
              discriminator_x_optimizer, discriminator_y_optimizer)

#train_summary_writer = tf.summary.create_file_writer(log_dir)

#ckpt = tf.train.Checkpoint(G=G,F=F,D_X=D_X,D_Y=D_Y,generator_g_optimizer=generator_g_optimizer,generator_f_optimizer=generator_f_optimizer,discriminator_x_optimizer=discriminator_x_optimizer,discriminator_y_optimizer=discriminator_y_optimizer)
#ckpt_manager = tf.train.CheckpointManager(ckpt, log_dir, max_to_keep=10)
start = 0
#if ckpt_manager.latest_checkpoint:
    #ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    #start=int(ckpt_manager.latest_checkpoint.split('-')[-1])
    #print ('Latest checkpoint restored!!')
for ep in range(start, epoch):
    print('Epoch:' + str(ep + 1))
    # Apply this epoch's rate before training on it. The original updated
    # with lr_sch(ep) *after* epoch ep (one epoch late) and never applied the
    # rate for ``start``, which matters when resuming from a checkpoint.
    lr = lr_sch(ep)
    for optimizer in optimizers:
        optimizer.learning_rate = lr
    for source in source_data:
        target = next(target_it)
        (fake_y, cycled_x, fake_x, cycled_y, same_x, same_y,
         total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss,
         steps) = train_step(G, F, D_Y, D_X, source, target,
                             generator_g_optimizer, generator_f_optimizer,
                             discriminator_x_optimizer, discriminator_y_optimizer,
                             train_loss, lambda1, lambda2)
        print('Step: ' + str(steps.numpy()) + ' , G loss: ' + str(total_gen_g_loss.numpy()) + ' , F loss: ' + str(total_gen_f_loss.numpy()) + ' , D_X loss: ' + str(disc_x_loss.numpy()) + ' , D_Y loss: ' + str(disc_y_loss.numpy()))

    # Reset all four running means; the original skipped train_loss[0] (G).
    for metric in train_loss:
        metric.reset_states()
    #ckpt_save_path = ckpt_manager.save()
print("Training is over!")

我用的代码是借鉴大神们写的,参考的资料如下:
https://github.com/Katexiang/CycleGAN

相关文章

网友评论

      本文标题:CycleGAN 代码

      本文链接:https://www.haomeiwen.com/subject/vzeightx.html