Reference: GitHub
Class CycleGAN
_init
+discriminator (ResNet/VGG19)
++D_A
++D_B
+generator (ResNet/VGG19)
++G_A2B
++G_B2A
+criterion (MAE/LS)
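The two criterion options can be sketched in plain Python. This is an illustration only: it assumes MAE means the mean absolute error and LS means the least-squares (mean squared error) loss used by LSGAN. The function names `mae_criterion` and `ls_criterion` are hypothetical, and real code would work on tensors rather than flat lists.

```python
def mae_criterion(pred, target):
    """Mean absolute error between two equally sized flat lists."""
    assert len(pred) == len(target) and pred
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def ls_criterion(pred, target):
    """Least-squares (mean squared error) loss, as used by LSGAN."""
    assert len(pred) == len(target) and pred
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


scores = [0.9, 0.2, 0.6]           # hypothetical discriminator outputs
real_labels = [1.0] * len(scores)  # "real" target label
print(mae_criterion(scores, real_labels))
print(ls_criterion(scores, real_labels))
```

The least-squares version penalizes large errors more strongly than MAE, which is why it gives smoother gradients for scores far from the target.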
_judge
+real_A
++loss
+real_B
++loss
+fake_A_from_real_B # G_B2A(real_B): translate real B into a fake A
+fake_A_from_fake_B # G_B2A(fake_B_from_real_A): reconstruct A through the cycle A -> B -> A
++loss
+fake_B_from_real_A # G_A2B(real_A): translate real A into a fake B
++loss
+fake_B_from_fake_A # G_A2B(fake_A_from_real_B): reconstruct B through the cycle B -> A -> B
++loss
+fake_A_sample
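#What is the difference between fake_A_from_real_B / fake_A_from_fake_B and fake_A_sample? The first two are tensors computed inside the graph from real_data, so gradients flow back into the generators. fake_A_sample is a placeholder: in train.py it is fed with the fake_A array returned by the generator step, so the discriminator update treats those images as fixed inputs and does not change the generators.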
+fake_B_sample
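The naming of the generated images can be made concrete with a toy data-flow sketch. The two "generators" below are hypothetical stand-ins (simple shifts on flat lists of pixel values), chosen only so the cycle is easy to check; they are not the ResNet/VGG19 networks listed under _init.

```python
# Hypothetical toy generators. They are exact inverses only so the cycle is
# easy to see; learned generators are at best approximately inverse.
def G_A2B(images):
    return [x + 1.0 for x in images]


def G_B2A(images):
    return [x - 1.0 for x in images]


real_A = [0.0, 0.5]
real_B = [2.0, 3.0]

fake_B_from_real_A = G_A2B(real_A)              # translation A -> B
fake_A_from_fake_B = G_B2A(fake_B_from_real_A)  # reconstruction A -> B -> A
fake_A_from_real_B = G_B2A(real_B)              # translation B -> A
fake_B_from_fake_A = G_A2B(fake_A_from_real_B)  # reconstruction B -> A -> B

print(fake_A_from_fake_B)  # equals real_A for these toy generators
```

With real networks the reconstructions only approximate real_A and real_B; the cycle-consistency terms in the generator loss measure exactly that gap.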
in discriminator
+D_A_real # D_A(real_A): discriminator A scores the real A images
++D_A_real_loss # criterion(D_A_real, 1): real images should be scored as real (label 1)
+D_B_real # D_B(real_B): discriminator B scores the real B images
++D_B_real_loss # criterion(D_B_real, 1): real images should be scored as real (label 1)
+D_A_fake # D_A(fake_A_from_real_B): discriminator A scores the generated A images
++D_A_fake_loss # criterion(D_A_fake, 0): generated images should be scored as fake (label 0)
#D_A_loss = f(D_A_real_loss, D_A_fake_loss)
+D_B_fake # D_B(fake_B_from_real_A): discriminator B scores the generated B images
++D_B_fake_loss # criterion(D_B_fake, 0): generated images should be scored as fake (label 0)
#D_B_loss = f(D_B_real_loss, D_B_fake_loss)
+D_A_fake_sample # D_A(fake_A_sample): discriminator A scores the fed-in generated A samples
++D_A_fake_sample_loss # criterion(D_A_fake_sample, 0): samples should be scored as fake (label 0)
+D_B_fake_sample # D_B(fake_B_sample): discriminator B scores the fed-in generated B samples
++D_B_fake_sample_loss # criterion(D_B_fake_sample, 0): samples should be scored as fake (label 0)
#D_loss = f(D_A_loss, D_B_loss)
#Why is there no discriminator output for fake_B_from_fake_A (or fake_A_from_fake_B)? Because the reconstructions are only compared with real_B (or real_A) in the cycle-consistency terms of the generator loss; they are never passed to a discriminator.
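The notes leave the combining function f unspecified. A minimal sketch, assuming the fake term uses the fake_*_sample scores (the ones the D update in train.py feeds in), that each discriminator averages its real and fake terms (the CycleGAN paper divides the D objective by 2 to slow D relative to G), and that D_loss is the sum over both discriminators. The helper names and scores are hypothetical.

```python
def ls_criterion(pred, target):
    """Least-squares criterion (mean squared error) on flat lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def discriminator_loss(d_real, d_fake_sample, criterion=ls_criterion):
    """Loss of one discriminator (D_A or D_B).

    d_real: scores for real images, pushed towards the real label 1.
    d_fake_sample: scores for generated samples, pushed towards 0.
    The two terms are averaged, i.e. the objective is halved.
    """
    real_loss = criterion(d_real, [1.0] * len(d_real))
    fake_loss = criterion(d_fake_sample, [0.0] * len(d_fake_sample))
    return (real_loss + fake_loss) / 2


# Hypothetical discriminator scores.
D_A_loss = discriminator_loss(d_real=[0.8, 1.0], d_fake_sample=[0.2, 0.0])
D_B_loss = discriminator_loss(d_real=[1.0, 1.0], d_fake_sample=[0.0, 0.0])
D_loss = D_A_loss + D_B_loss
print(D_A_loss, D_B_loss, D_loss)
```

D_B here is a perfect discriminator on this toy batch, so its loss is zero and D_loss comes entirely from D_A.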
in generator
G_loss_A2B # criterion(D_B_fake, 1): compare D_B_fake (= D_B(fake_B_from_real_A)) with the real label, because the goal of the generator is to fool D_B into scoring its output as real.
so G_loss_A2B = criterion(D_B_fake, 1) + λ·MAE(real_A, fake_A_from_fake_B) + λ·MAE(real_B, fake_B_from_fake_A) # λ controls the relative importance of the adversarial term and the cycle-consistency terms; a larger λ pushes the reconstructions fake_A_from_fake_B and fake_B_from_fake_A closer to real_A and real_B
G_loss_B2A # criterion(D_A_fake, 1) + the same two λ-weighted cycle terms (real_A vs fake_A_from_fake_B, real_B vs fake_B_from_fake_A)
G_loss_All # criterion(D_A_fake, 1) + criterion(D_B_fake, 1) + the same two λ-weighted cycle terms
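The three generator losses can be sketched as below, assuming a least-squares adversarial criterion, MAE (L1) for the cycle terms, and λ = 10 as in the CycleGAN paper. All function names and input values are hypothetical; the point is which terms go into which loss.

```python
def mae(a, b):
    """Mean absolute error (L1), used for the cycle-consistency terms."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def ls_criterion(pred, target):
    """Least-squares adversarial criterion."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cycle_loss(real_A, fake_A_from_fake_B, real_B, fake_B_from_fake_A):
    """Sum of both reconstruction errors (A -> B -> A and B -> A -> B)."""
    return mae(real_A, fake_A_from_fake_B) + mae(real_B, fake_B_from_fake_A)


def generator_losses(DA_fake, DB_fake, cyc, lam=10.0):
    """Return (G_loss_A2B, G_loss_B2A, G_loss_All).

    The generators want the discriminators to output the real label 1
    for generated images, so the adversarial target is all ones.
    """
    adv_A2B = ls_criterion(DB_fake, [1.0] * len(DB_fake))  # G_A2B tries to fool D_B
    adv_B2A = ls_criterion(DA_fake, [1.0] * len(DA_fake))  # G_B2A tries to fool D_A
    g_a2b = adv_A2B + lam * cyc
    g_b2a = adv_B2A + lam * cyc
    g_all = adv_A2B + adv_B2A + lam * cyc
    return g_a2b, g_b2a, g_all


cyc = cycle_loss(real_A=[0.0, 1.0], fake_A_from_fake_B=[0.0, 0.5],
                 real_B=[1.0, 1.0], fake_B_from_fake_A=[1.0, 1.0])
G_loss_A2B, G_loss_B2A, G_loss_All = generator_losses(
    DA_fake=[0.5, 0.5], DB_fake=[1.0, 0.0], cyc=cyc)
print(G_loss_A2B, G_loss_B2A, G_loss_All)  # 3.0 2.75 3.25
```

Both directional losses share the same cycle term, so G_loss_All counts it once while adding both adversarial terms.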
loss function
full objective
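For reference, the full objective from the CycleGAN paper (Zhu et al., 2017) is written below. Reading it against the names above (our mapping, not stated in the notes): G = G_A2B, F = G_B2A, D_X = D_A, D_Y = D_B, X = domain A, Y = domain B. The paper sets λ = 10 and, in practice, replaces the negative log-likelihood in the GAN term with a least-squares loss.

```latex
\mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log\big(1 - D_Y(G(x))\big)\big]

\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]

\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X) + \lambda\, \mathcal{L}_{\mathrm{cyc}}(G, F)

G^{*}, F^{*} = \arg\min_{G, F}\, \max_{D_X, D_Y}\, \mathcal{L}(G, F, D_X, D_Y)
```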
training details
summary and record
in discriminator
+D_A_real_loss_summary # call tf.summary.scalar to summarize and record D_A_real_loss
+D_A_fake_loss_summary # call tf.summary.scalar to summarize and record D_A_fake_loss
+D_A_loss_summary # call tf.summary.scalar to summarize and record D_A_loss
+D_B_real_loss_summary # call tf.summary.scalar to summarize and record D_B_real_loss
+D_B_fake_loss_summary # call tf.summary.scalar to summarize and record D_B_fake_loss
+D_B_loss_summary # call tf.summary.scalar to summarize and record D_B_loss
+D_loss_summary # call tf.summary.scalar to summarize and record D_loss
in generator
+G_loss_A2B_summary # call tf.summary.scalar to summarize and record G_loss_A2B
+G_loss_B2A_summary # call tf.summary.scalar to summarize and record G_loss_B2A
+G_loss_summary # call tf.summary.scalar to summarize and record G_loss
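The scalar summaries end up grouped per network: train.py fetches `self.g_sum` for the generator step and `self.d_sum` for the discriminator step. The pure-Python recorder below is a hypothetical stand-in (not a TensorFlow API) that only illustrates this grouping: each group names the losses it records, and one call records the whole group at a given step.

```python
from collections import defaultdict


class ScalarRecorder:
    """Hypothetical stand-in for grouped scalar summaries.

    Each group (e.g. "d_sum" or "g_sum") lists the loss names it records,
    loosely mirroring one merged summary op per network.
    """

    def __init__(self, groups):
        self.groups = groups
        self.history = defaultdict(list)  # name -> [(step, value), ...]

    def record(self, group, losses, step):
        for name in self.groups[group]:
            self.history[name].append((step, losses[name]))


recorder = ScalarRecorder({
    "d_sum": ["D_A_real_loss", "D_A_fake_loss", "D_A_loss",
              "D_B_real_loss", "D_B_fake_loss", "D_B_loss", "D_loss"],
    "g_sum": ["G_loss_A2B", "G_loss_B2A", "G_loss"],
})
recorder.record("g_sum", {"G_loss_A2B": 3.0, "G_loss_B2A": 2.75, "G_loss": 3.25}, step=1)
print(recorder.history["G_loss"])  # [(1, 3.25)]
```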
in train.py
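# Update G network and record fake outputs
# (self.fake_A / self.fake_B correspond to fake_A_from_real_B / fake_B_from_real_A)
fake_A, fake_B, _, summary_str = self.sess.run(
    [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
    feed_dict={self.real_data: batch_images, self.lr: lr})
# Write the G summaries before summary_str is overwritten below
# (assumes a tf.summary.FileWriter in self.writer and a step counter `counter`)
self.writer.add_summary(summary_str, counter)

# Update D network, feeding the generated images into the sample placeholders
_, summary_str = self.sess.run(
    [self.d_optim, self.d_sum],
    feed_dict={self.real_data: batch_images,
               self.fake_A_sample: fake_A,
               self.fake_B_sample: fake_B,
               self.lr: lr})
self.writer.add_summary(summary_str, counter)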
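In the train.py snippet the fakes from the G step are fed straight into fake_A_sample and fake_B_sample. The CycleGAN paper instead updates the discriminators with a buffer of the 50 previously generated images. A minimal sketch of such a buffer follows, assuming the swap rule of the authors' public PyTorch implementation (once full, return a random stored image with probability 0.5 and store the new one in its place). The class and method names are hypothetical.

```python
import random


class ImagePool:
    """History buffer of generated images (hypothetical sketch).

    Until the buffer is full, every new image is stored and returned as is.
    Afterwards, with probability 0.5 a randomly chosen stored image is
    returned and replaced by the new one; otherwise the new image is
    returned unchanged.
    """

    def __init__(self, max_size=50, rng=None):
        self.max_size = max_size
        self.images = []
        self.rng = rng or random.Random()

    def query(self, image):
        if self.max_size <= 0:            # pool disabled: pass through
            return image
        if len(self.images) < self.max_size:
            self.images.append(image)
            return image
        if self.rng.random() > 0.5:
            idx = self.rng.randrange(self.max_size)
            old = self.images[idx]
            self.images[idx] = image      # store the new image in its place
            return old
        return image


pool = ImagePool(max_size=2, rng=random.Random(0))
samples = [pool.query(f"fake_{i}") for i in range(5)]
print(samples)
```

In the loop above, such a pool would be queried between the two `sess.run` calls, on the fake_A and fake_B arrays returned by the G step, before they are fed to the sample placeholders.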
to be continued