Before the May Day holiday I read through the CycleGAN paper; today I am filling in the code. Unfortunately, although the program runs, my limited compute meant I could not finish the full 200 epochs, so I cannot show the final results here.
TensorFlow provides an official tutorial, but it reuses the U-Net-based generator and discriminator from pix2pix, which departs from the original paper; the adversarial loss in particular differs considerably (the tutorial uses binary cross-entropy rather than the paper's least-squares loss).
The official TensorFlow tutorial: https://www.tensorflow.org/tutorials/generative/cyclegan
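For reference, the paper trains with a least-squares (LSGAN) adversarial objective rather than the usual negative log-likelihood: D_Y is trained to minimize E_y[(D_Y(y) - 1)^2] + E_x[D_Y(G(x))^2], while G minimizes E_x[(D_Y(G(x)) - 1)^2] (and symmetrically for F and D_X). The train_step below follows this form.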
The code I used is as follows:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.regularizers import l2
import numpy as np

epsilon = 1e-5  # numerical stabilizer for instance normalization

def conv(numout, kernel_size=3, strides=1, kernel_regularizer=0.0005, padding='same', use_bias=False, name='conv'):
    return tf.keras.layers.Conv2D(name=name, filters=numout, kernel_size=kernel_size, strides=strides,
                                  padding=padding, use_bias=use_bias, kernel_regularizer=l2(kernel_regularizer),
                                  kernel_initializer=tf.random_normal_initializer(stddev=0.1))

def convt(numout, kernel_size=3, strides=1, kernel_regularizer=0.0005, padding='same', use_bias=False, name='conv'):
    return tf.keras.layers.Conv2DTranspose(name=name, filters=numout, kernel_size=kernel_size, strides=strides,
                                           padding=padding, use_bias=use_bias, kernel_regularizer=l2(kernel_regularizer),
                                           kernel_initializer=tf.random_normal_initializer(stddev=0.1))

def bn(name, momentum=0.9):
    return tf.keras.layers.BatchNormalization(name=name, momentum=momentum)
class c7s1_k(keras.Model):
    """7x7 convolution with stride 1 and reflection padding, then norm and ReLU/tanh."""
    def __init__(self, scope: str = "c7s1_k", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(c7s1_k, self).__init__(name=scope)
        self.conv1 = conv(numout=k, kernel_size=7, kernel_regularizer=reg, padding='valid', name='conv')
        self.norm = norm
        if norm == 'instance':  # compare strings with ==, not `is`
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')
    def call(self, x, training=False, activation='Relu'):
        x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], 'REFLECT')
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        if activation == 'Relu':
            x = tf.nn.relu(x)
        else:
            x = tf.nn.tanh(x)
        return x
class dk(keras.Model):
    """3x3 convolution with stride 2 (downsampling), then norm and ReLU."""
    def __init__(self, scope: str = "dk", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(dk, self).__init__(name=scope)
        self.norm = norm
        self.conv1 = conv(numout=k, kernel_size=3, strides=[2, 2], kernel_regularizer=reg, padding='same', name='conv')
        if norm == 'instance':
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')
    def call(self, x, training=False):
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        return x
class Rk(keras.Model):
    """Residual block: two 3x3 convolutions with reflection padding and a skip connection."""
    def __init__(self, scope: str = "Rk", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(Rk, self).__init__(name=scope)
        self.norm = norm
        self.conv1 = conv(numout=k, kernel_size=3, kernel_regularizer=reg, padding='valid', name='layer1/conv')
        if norm == 'instance':
            self.scale1 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer1/scale')
            self.offset1 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer1/offset')
            self.scale2 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer2/scale')
            self.offset2 = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='layer2/offset')
        elif norm == 'bn':
            self.bn1 = bn(name='layer1/bn')
            self.bn2 = bn(name='layer2/bn')
        self.conv2 = conv(numout=k, kernel_size=3, kernel_regularizer=reg, padding='valid', name='layer2/conv')
    def call(self, x, training=False):
        inputs = x
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale1 * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset1
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], 'REFLECT')
        x = self.conv2(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale2 * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset2
        elif self.norm == 'bn':
            x = self.bn2(x, training=training)
        return x + inputs
class n_res_blocks(keras.Model):
    def __init__(self, scope: str = "n_res_blocks", n: int = 6, k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(n_res_blocks, self).__init__(name=scope)
        self.group = []
        self.norm = norm
        for i in range(n):
            self.group.append(Rk(scope='Rk_' + str(i + 1), k=k, reg=reg, norm=norm))
    def call(self, x, training=False):
        for block in self.group:
            x = block(x, training=training)
        return x
class uk(keras.Model):
    """3x3 transposed convolution with stride 2 (upsampling), then norm and ReLU."""
    def __init__(self, scope: str = "uk", k: int = 16, reg: float = 0.0005, norm: str = "instance"):
        super(uk, self).__init__(name=scope)
        self.norm = norm
        # An alternative is resize (e.g. bilinear) followed by a stride-1 conv; a transposed conv is used here.
        self.conv1 = convt(numout=k, kernel_size=3, strides=[2, 2], kernel_regularizer=reg, padding='same', name='conv')
        if norm == 'instance':
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')
    def call(self, x, training=False):
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        return x
class Ck(keras.Model):
    """Discriminator block: 3x3 convolution (the paper uses 4x4) with the given stride, norm, LeakyReLU."""
    def __init__(self, scope: str = "Ck", k: int = 16, stride: int = 2, reg: float = 0.0005, norm: str = "instance"):
        super(Ck, self).__init__(name=scope)  # original default scope was "uk", copied from the class above
        self.norm = norm
        self.conv1 = conv(numout=k, kernel_size=3, strides=[stride, stride], kernel_regularizer=reg, padding='same', name='conv')
        if norm == 'instance':
            self.scale = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='scale')
            self.offset = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.1)(shape=[k]), name='offset')
        elif norm == 'bn':
            self.bn1 = bn(name='bn')
    def call(self, x, training=False, slope=0.2):
        x = self.conv1(x)
        if self.norm == 'instance':
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            x = self.scale * tf.math.divide(x - mean, tf.math.sqrt(var + epsilon)) + self.offset
        elif self.norm == 'bn':
            x = self.bn1(x, training=training)
        x = tf.nn.leaky_relu(x, slope)
        return x
class last_conv(keras.Model):
    """Final 4x4 convolution producing the 1-channel PatchGAN output."""
    def __init__(self, scope: str = "last_conv", reg: float = 0.0005):
        super(last_conv, self).__init__(name=scope)
        self.conv1 = conv(numout=1, kernel_size=4, kernel_regularizer=reg, padding='same', name='conv')
    def call(self, x, use_sigmoid=False):
        x = self.conv1(x)
        if use_sigmoid:
            x = tf.nn.sigmoid(x)  # the original computed the sigmoid but discarded it and returned the logits
        return x
class Discriminator(keras.Model):
    def __init__(self, scope: str = "Discriminator", reg: float = 0.0005, norm: str = "instance"):
        super(Discriminator, self).__init__(name=scope)
        self.ck1 = Ck(scope="C64", k=64, reg=reg, norm=norm)
        self.ck2 = Ck(scope="C128", k=128, reg=reg, norm=norm)
        self.ck3 = Ck(scope="C256", k=256, reg=reg, norm=norm)
        self.ck4 = Ck(scope="C512", k=512, reg=reg, norm=norm)
        self.last_conv = last_conv(scope="output", reg=reg)
    def call(self, x, training=False, use_sigmoid=False, slope=0.2):
        x = self.ck1(x, training=training, slope=slope)
        x = self.ck2(x, training=training, slope=slope)
        x = self.ck3(x, training=training, slope=slope)
        x = self.ck4(x, training=training, slope=slope)
        x = self.last_conv(x, use_sigmoid=use_sigmoid)
        return x
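With all four Ck blocks at stride 2, a 256x256 input yields a 16x16x1 output map, so each output unit scores one patch of the input image; that is the PatchGAN idea. Note that, as I read the paper, the C512 block there uses stride 1 (which gives the 70x70 receptive field), so the effective patch size in this implementation is somewhat larger.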
class Generator(keras.Model):
    """ResNet generator: c7s1-32, d64, d128, 6 or 8 residual blocks (the paper uses 9 for 256x256 inputs), u64, u32, c7s1-3."""
    def __init__(self, scope: str = "Generator", ngf: int = 64, reg: float = 0.0005, norm: str = "instance", more: bool = True):
        super(Generator, self).__init__(name=scope)
        self.c7s1_32 = c7s1_k(scope="c7s1_32", k=ngf, reg=reg, norm=norm)
        self.d64 = dk(scope="d64", k=2 * ngf, reg=reg, norm=norm)
        self.d128 = dk(scope="d128", k=4 * ngf, reg=reg, norm=norm)
        if more:
            self.res_output = n_res_blocks(scope="8_res_blocks", n=8, k=4 * ngf, reg=reg, norm=norm)
        else:
            self.res_output = n_res_blocks(scope="6_res_blocks", n=6, k=4 * ngf, reg=reg, norm=norm)
        self.u64 = uk(scope="u64", k=2 * ngf, reg=reg, norm=norm)
        self.u32 = uk(scope="u32", k=ngf, reg=reg, norm=norm)
        self.outconv = c7s1_k(scope="output", k=3, reg=reg, norm=norm)
    def call(self, x, training=False):
        x = self.c7s1_32(x, training=training, activation='Relu')
        x = self.d64(x, training=training)
        x = self.d128(x, training=training)
        x = self.res_output(x, training=training)
        x = self.u64(x, training=training)
        x = self.u32(x, training=training)
        x = self.outconv(x, training=training, activation='tanh')
        return x
import tensorflow_datasets as tfds
dataset, metadata = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
# The helpers below build a pipeline from image file paths (sized for 1024x2048 inputs);
# they are not used in the horse2zebra run below, which preprocesses the tfds tensors directly.
def decodefortrain(img):
    img = tf.io.read_file(img)  # read the file at the given path into a string tensor
    img = tf.image.decode_png(img, channels=3)  # decode into a 3-channel RGB image
    img = tf.cast(img, dtype=tf.float32)  # images decode as uint8; cast to float32 before training
    scale = tf.random.uniform([1], minval=0.25, maxval=0.5, dtype=tf.float32)
    hi = tf.floor(scale * 1024)  # round down
    wi = tf.floor(scale * 2048)
    s = tf.concat([hi, wi], 0)  # concatenate along axis 0
    s = tf.cast(s, dtype=tf.int32)
    img = tf.compat.v1.image.resize_images(img, s, method=0, align_corners=True)  # random rescale
    img = tf.image.random_crop(img, [256, 512, 3])  # random crop to a fixed size
    img = tf.image.random_flip_left_right(img)  # random horizontal flip
    img = (img / 255) * 2 - 1  # normalize to [-1, 1]
    return img
def source_data(batchsize=1):
    return tf.data.Dataset.from_tensor_slices(train_horses).shuffle(50).map(decodefortrain).batch(batchsize)
def target_data(batchsize=1):
    return tf.data.Dataset.from_tensor_slices(train_zebras).shuffle(50).map(decodefortrain).batch(batchsize)
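The pipeline below maps preprocess_image_train and preprocess_image_test over the tfds datasets, but the original listing never defines them. A minimal sketch, assuming the same random-jitter-and-normalize preprocessing as the official tutorial (resize to 286x286, random-crop back to 256x256, random flip, scale to [-1, 1]):

def normalize(image):
    # scale uint8 pixels to [-1, 1] to match the generator's tanh output
    image = tf.cast(image, tf.float32)
    return (image / 127.5) - 1

def random_jitter(image):
    # resize to 286x286, random-crop back to 256x256, then random mirror
    image = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.image.random_crop(image, size=[256, 256, 3])
    return tf.image.random_flip_left_right(image)

def preprocess_image_train(image, label):  # as_supervised=True yields (image, label) pairs
    return normalize(random_jitter(image))

def preprocess_image_test(image, label):
    return normalize(image)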
train_horses = train_horses.map(preprocess_image_train).cache().shuffle(50).batch(1)
train_zebras = train_zebras.map(preprocess_image_train).cache().shuffle(50).batch(1)
test_horses = test_horses.map(preprocess_image_test).cache().shuffle(50).batch(1)
test_zebras = test_zebras.map(preprocess_image_test).cache().shuffle(50).batch(1)
def lr_sch(epoch):
    # constant 2e-4 for the first 100 epochs, then linear decay towards 2e-8
    if epoch + 1 <= 100:
        lr = tf.constant(2e-4)
    else:
        lr = tf.constant(2e-4 - 2e-8) * (1 - (epoch + 1 - 100) / 100) + 2e-8
    return lr
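This matches the paper's schedule up to the floor value: lr_sch(0) through lr_sch(99) return 2e-4; after that the rate decays linearly, e.g. lr_sch(149) is roughly 1e-4, bottoming out at 2e-8 at lr_sch(199).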
def train_step(G, F, D_Y, D_X, source, target, generator_g_optimizer, generator_f_optimizer,
               discriminator_x_optimizer, discriminator_y_optimizer, train_loss, lambda1, lambda2):
    with tf.GradientTape(persistent=True) as tape:
        # generators
        fake_y = G(source, training=True)    # X -> G(X)
        cycled_x = F(fake_y, training=True)  # G(X) -> F(G(X))
        fake_x = F(target, training=True)    # Y -> F(Y)
        cycled_y = G(fake_x, training=True)  # F(Y) -> G(F(Y))
        # for the identity loss
        same_x = F(source, training=True)
        same_y = G(target, training=True)
        # discriminators
        disc_real_x = D_X(source, training=True)  # X -> D_X(X)
        disc_real_y = D_Y(target, training=True)  # Y -> D_Y(Y)
        disc_fake_x = D_X(fake_x, training=True)  # F(Y) -> D_X(F(Y))
        disc_fake_y = D_Y(fake_y, training=True)  # G(X) -> D_Y(G(X))
        # generator adversarial loss (LSGAN)
        gen_g_loss = tf.reduce_mean(tf.math.squared_difference(disc_fake_y, 1))  # (D_Y(G(X)) - 1)^2
        gen_f_loss = tf.reduce_mean(tf.math.squared_difference(disc_fake_x, 1))  # (D_X(F(Y)) - 1)^2
        # cycle-consistency loss: |F(G(X)) - X| * lambda1 + |G(F(Y)) - Y| * lambda2
        total_cycle_loss = tf.reduce_mean(tf.abs(source - cycled_x)) * lambda1 + tf.reduce_mean(tf.abs(target - cycled_y)) * lambda2
        # total generator loss = adversarial loss + cycle loss + identity loss
        total_gen_g_loss = gen_g_loss + total_cycle_loss + tf.reduce_mean(tf.abs(target - same_y)) * 0.5 * lambda2
        total_gen_f_loss = gen_f_loss + total_cycle_loss + tf.reduce_mean(tf.abs(source - same_x)) * 0.5 * lambda1
        # discriminator loss (LSGAN)
        disc_x_loss = 0.5 * tf.reduce_mean(tf.math.squared_difference(disc_real_x, 1) + tf.math.square(disc_fake_x))
        disc_y_loss = 0.5 * tf.reduce_mean(tf.math.squared_difference(disc_real_y, 1) + tf.math.square(disc_fake_y))
        # regularization loss (unused)
        #loss_reg = tf.reduce_sum(G.losses) + tf.reduce_sum(F.losses) + tf.reduce_sum(D_Y.losses) + tf.reduce_sum(D_X.losses)
    train_loss[0](total_gen_g_loss)
    train_loss[1](total_gen_f_loss)
    train_loss[2](disc_x_loss)
    train_loss[3](disc_y_loss)
    generator_g_gradients = tape.gradient(total_gen_g_loss, G.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, F.trainable_variables)
    discriminator_x_gradients = tape.gradient(disc_x_loss, D_X.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, D_Y.trainable_variables)
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, G.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, F.trainable_variables))
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, D_X.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, D_Y.trainable_variables))
    return fake_y, cycled_x, fake_x, cycled_y, same_x, same_y, total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss, generator_g_optimizer.iterations
weight_decay = 0
learning_rate = 2e-4
batch_size = 1
epoch = 100  # the paper trains for 200 epochs; lr_sch only starts decaying after epoch 100
#log_dir = FLAGS.logdir
use_lsgan = True
norm = 'instance'
lambda1 = 10
lambda2 = 10
beta1 = 0.5
ngf = 64
G = Generator('G', ngf, weight_decay, norm=norm, more=True)
F = Generator('F', ngf, weight_decay, norm=norm, more=True)
D_Y = Discriminator('D_Y', reg=weight_decay, norm=norm)
D_X = Discriminator('D_X', reg=weight_decay, norm=norm)
# run one dummy batch through each model so all variables get built
forbuild = np.random.rand(1, 256, 256, 3).astype(np.float32)
built = G(forbuild)
built = F(forbuild)
built = D_Y(forbuild)
built = D_X(forbuild)
#source_data = source_data(batch_size)
source_data = train_horses.repeat(20)  # the original referenced an undefined `source1`; use the pipeline above
#target_data = target_data(batch_size)
target_it = iter(train_zebras.repeat(-1))  # likewise, the original's `target1` was undefined
train_loss_G = tf.keras.metrics.Mean('train_loss_G', dtype=tf.float32)
train_loss_F = tf.keras.metrics.Mean('train_loss_F', dtype=tf.float32)
train_loss_DX = tf.keras.metrics.Mean('train_loss_DX', dtype=tf.float32)
train_loss_DY = tf.keras.metrics.Mean('train_loss_DY', dtype=tf.float32)
train_loss = [train_loss_G, train_loss_F, train_loss_DX, train_loss_DY]
generator_g_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)
generator_f_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)
discriminator_x_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)
discriminator_y_optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=beta1)
#train_summary_writer = tf.summary.create_file_writer(log_dir)
#ckpt = tf.train.Checkpoint(G=G, F=F, D_X=D_X, D_Y=D_Y, generator_g_optimizer=generator_g_optimizer, generator_f_optimizer=generator_f_optimizer, discriminator_x_optimizer=discriminator_x_optimizer, discriminator_y_optimizer=discriminator_y_optimizer)
#ckpt_manager = tf.train.CheckpointManager(ckpt, log_dir, max_to_keep=10)
start = 0
lr = lr_sch(start)
# restore the latest checkpoint, if any:
#if ckpt_manager.latest_checkpoint:
#    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
#    start = int(ckpt_manager.latest_checkpoint.split('-')[-1])
#    lr = lr_sch(start)
#    print('Latest checkpoint restored!!')
for ep in range(start, epoch, 1):
    print('Epoch: ' + str(ep + 1))
    for step, source in enumerate(source_data):
        target = next(target_it)
        fake_y, cycled_x, fake_x, cycled_y, same_x, same_y, total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss, steps = train_step(
            G, F, D_Y, D_X, source, target, generator_g_optimizer, generator_f_optimizer,
            discriminator_x_optimizer, discriminator_y_optimizer, train_loss, lambda1, lambda2)
        print('Step: ' + str(steps.numpy()) + ' , G loss: ' + str(total_gen_g_loss.numpy())
              + ' , F loss: ' + str(total_gen_f_loss.numpy()) + ' , D_X loss: ' + str(disc_x_loss.numpy())
              + ' , D_Y loss: ' + str(disc_y_loss.numpy()))
    train_loss[0].reset_states()  # the original reset only three of the four metrics
    train_loss[1].reset_states()
    train_loss[2].reset_states()
    train_loss[3].reset_states()
    lr = lr_sch(ep)
    generator_g_optimizer.learning_rate = lr
    generator_f_optimizer.learning_rate = lr
    discriminator_x_optimizer.learning_rate = lr
    discriminator_y_optimizer.learning_rate = lr
    #ckpt_save_path = ckpt_manager.save()
print("Training is over!")
This code draws heavily on implementations by others; reference:
https://github.com/Katexiang/CycleGAN