2018-03-13-Cycle GAN with tensor


Author: 老王_5b2d | Published: 2018-03-14 15:40


    class CycleGAN


    +discriminator (ResNet/VGG19)
    +generator (ResNet/VGG19)
    +criterion (MAE/LS)
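The criterion slot can be read as two separate losses. One is mean absolute error (L1), which the CycleGAN paper uses for cycle consistency. The other is the least-squares loss from LSGAN, which the paper uses for the adversarial terms. The sketch below is minimal and uses flat lists of floats in place of tensors so it runs without TensorFlow. `mae_criterion` and `ls_criterion` are hypothetical names.

```python
# Hypothetical stand-ins for the "criterion" component. They work on flat
# lists of floats instead of tensors, so the sketch runs without TensorFlow.


def mae_criterion(pred, target):
    """Mean absolute error (L1): average of |pred - target|."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def ls_criterion(pred, target):
    """Least-squares (LSGAN-style) loss: average of (pred - target) ** 2."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


if __name__ == "__main__":
    scores = [0.9, 0.2, 0.6]    # made-up discriminator outputs
    ones = [1.0] * len(scores)  # the "real" label for every score
    print(mae_criterion(scores, ones))
    print(ls_criterion(scores, ones))
```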


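    +fake_A_from_real_B # B2A generator applied to real_B
    +fake_A_from_fake_B # B2A generator applied to fake_B_from_real_A, i.e. the reconstruction of real_A
    +fake_B_from_real_A # A2B generator applied to real_A
    +fake_B_from_fake_A # A2B generator applied to fake_A_from_real_B, i.e. the reconstruction of real_B
    # What is the difference between fake_A_from_real_B / fake_A_from_fake_B and fake_A_sample?
    # -> fake_A_sample is a placeholder: in train.py it is fed with the fake_A images returned by the G update, and only the discriminator uses it.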
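The four generator outputs come from just two networks applied in a cycle. The toy sketch below shows which input each output is computed from. It uses hypothetical stand-ins `G_A2B` and `G_B2A`, which are simple invertible functions on numbers and not real generators. Because these stand-ins are exact inverses, the reconstructions equal the inputs. That is the behaviour the cycle-consistency loss pushes the real networks towards.

```python
# Toy illustration of the four generator outputs. G_A2B and G_B2A are
# hypothetical stand-ins for the two generator networks, chosen here as
# exact inverses of each other so the cycle is perfectly consistent.


def G_A2B(x):
    """Hypothetical A -> B generator (stand-in for a real network)."""
    return 2.0 * x + 1.0


def G_B2A(y):
    """Hypothetical B -> A generator, the exact inverse of G_A2B."""
    return (y - 1.0) / 2.0


def forward(real_A, real_B):
    # Translations: one generator call on each real batch.
    fake_B_from_real_A = [G_A2B(a) for a in real_A]
    fake_A_from_real_B = [G_B2A(b) for b in real_B]
    # Reconstructions: send each fake back to its source domain.
    fake_A_from_fake_B = [G_B2A(b) for b in fake_B_from_real_A]
    fake_B_from_fake_A = [G_A2B(a) for a in fake_A_from_real_B]
    return {
        "fake_B_from_real_A": fake_B_from_real_A,
        "fake_A_from_real_B": fake_A_from_real_B,
        "fake_A_from_fake_B": fake_A_from_fake_B,
        "fake_B_from_fake_A": fake_B_from_fake_A,
    }
```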

    in discriminator

    +D_A_real # call the discriminator D_A on real_A
    ++D_A_real_loss # call the criterion to compare D_A_real with the ground-truth label "real" (1)
    +D_B_real # call the discriminator D_B on real_B
    ++D_B_real_loss # call the criterion to compare D_B_real with the ground-truth label "real" (1)
    +D_A_fake # call D_A on fake_A_from_real_B
    ++D_A_fake_loss # call the criterion to compare D_A_fake with the ground-truth label "fake" (0)
    #D_A_loss = f(D_A_real_loss, D_A_fake_loss)
    +D_B_fake # call D_B on fake_B_from_real_A
    ++D_B_fake_loss # call the criterion to compare D_B_fake with the ground-truth label "fake" (0)
    #D_B_loss = f(D_B_real_loss, D_B_fake_loss)
    +D_A_fake_sample # call D_A on fake_A_sample
    ++D_A_fake_sample_loss # call the criterion to compare D_A_fake_sample with the ground-truth label "fake" (0)
    +D_B_fake_sample # call D_B on fake_B_sample
    ++D_B_fake_sample_loss # call the criterion to compare D_B_fake_sample with the ground-truth label "fake" (0)
    #D_loss = f(D_A_loss, D_B_loss)
    #Why is there no D_B_fake_from_fake_A? The reconstructions fake_A_from_fake_B and fake_B_from_fake_A are never passed to a discriminator; they only appear in the cycle-consistency terms of the generator loss.
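The notes leave the combining function f open. The sketch below shows one concrete choice. It assumes a least-squares criterion with labels 1 (real) and 0 (fake). It averages the real and fake terms for each discriminator, since the CycleGAN paper divides the discriminator objective by 2, and then sums the two discriminators. Scores are flat lists of floats, and the function names are hypothetical.

```python
def ls_criterion(scores, label):
    """Least-squares loss of each discriminator score against a constant label."""
    return sum((s - label) ** 2 for s in scores) / len(scores)


def discriminator_loss(scores_real, scores_fake):
    """Loss for one discriminator (D_A or D_B)."""
    real_loss = ls_criterion(scores_real, 1.0)  # real images -> label 1
    fake_loss = ls_criterion(scores_fake, 0.0)  # generated images -> label 0
    # Assumed form of f: the average of the two terms (objective divided by 2).
    return (real_loss + fake_loss) / 2.0


def total_discriminator_loss(D_A_real, D_A_fake, D_B_real, D_B_fake):
    """Hypothetical D_loss = f(D_A_loss, D_B_loss), assumed to be a plain sum."""
    D_A_loss = discriminator_loss(D_A_real, D_A_fake)
    D_B_loss = discriminator_loss(D_B_real, D_B_fake)
    return D_A_loss + D_B_loss
```

A perfect discriminator (real scores all 1, fake scores all 0) gets a loss of 0. In the train.py step below, the fake scores would be computed from fake_A_sample and fake_B_sample.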

    in generator

    G_loss_A2B # call the criterion to compare D_B_fake with the "real" label (1), because the generator's goal is to fool D_B into scoring fake_B_from_real_A as real
    so G_loss_A2B = f((D_B_fake, real label), (λ, real_A, fake_A_from_fake_B), (λ, real_B, fake_B_from_fake_A)) # λ controls the relative importance of the two objectives: fooling D_B versus cycle consistency (fake_A_from_fake_B should be close to real_A, fake_B_from_fake_A close to real_B)

    G_loss_B2A # f(D_A_fake, real_A, fake_A_from_fake_B, real_B, fake_B_from_fake_A)
    G_loss_All # f(D_A_fake, D_B_fake, real_A, fake_A_from_fake_B, real_B, fake_B_from_fake_A)
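The three generator losses can be sketched under the same assumptions: a least-squares adversarial term, L1 cycle terms, and flat lists instead of tensors. `LAMBDA = 10` is the value reported in the CycleGAN paper. The function names are hypothetical.

```python
LAMBDA = 10.0  # weight of the cycle-consistency terms (value from the CycleGAN paper)


def ls_to_real(scores):
    """Least-squares distance of discriminator scores from the "real" label 1."""
    return sum((s - 1.0) ** 2 for s in scores) / len(scores)


def l1(x, y):
    """Mean absolute difference between two equally long lists."""
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def generator_losses(D_A_fake, D_B_fake, real_A, fake_A_from_fake_B,
                     real_B, fake_B_from_fake_A, lam=LAMBDA):
    """Return (G_loss_A2B, G_loss_B2A, G_loss_All)."""
    # Cycle consistency: each reconstruction should match its original.
    cycle = lam * l1(real_A, fake_A_from_fake_B) + lam * l1(real_B, fake_B_from_fake_A)
    G_loss_A2B = ls_to_real(D_B_fake) + cycle  # A2B generator tries to fool D_B
    G_loss_B2A = ls_to_real(D_A_fake) + cycle  # B2A generator tries to fool D_A
    G_loss_All = ls_to_real(D_A_fake) + ls_to_real(D_B_fake) + cycle
    return G_loss_A2B, G_loss_B2A, G_loss_All
```

In this sketch the cycle term is shared by all three losses, matching the notes above where both G_loss_A2B and G_loss_B2A take both reconstructions.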

    loss function
    full objective
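For reference, the full objective from the CycleGAN paper is written out below in the notation of these notes. G is the A2B generator, F is the B2A generator, and D_A and D_B are the two discriminators. With this notation, F(G(a)) is fake_A_from_fake_B and G(F(b)) is fake_B_from_fake_A. The paper sets λ = 10. Its least-squares variant of the adversarial term is shown for the A→B direction.

```latex
\mathcal{L}(G, F, D_A, D_B) = \mathcal{L}_{\mathrm{GAN}}(G, D_B, A, B)
  + \mathcal{L}_{\mathrm{GAN}}(F, D_A, B, A) + \lambda \, \mathcal{L}_{\mathrm{cyc}}(G, F)

\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{a}\left[ \lVert F(G(a)) - a \rVert_1 \right]
  + \mathbb{E}_{b}\left[ \lVert G(F(b)) - b \rVert_1 \right]

G^{*}, F^{*} = \arg\min_{G, F} \max_{D_A, D_B} \mathcal{L}(G, F, D_A, D_B)

% least-squares form, A -> B direction
\text{G minimizes: } \mathbb{E}_{a}\left[ (D_B(G(a)) - 1)^2 \right]
\qquad
\text{D}_B \text{ minimizes: } \mathbb{E}_{b}\left[ (D_B(b) - 1)^2 \right] + \mathbb{E}_{a}\left[ D_B(G(a))^2 \right]
```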
    training details
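The `self.lr` value fed in train.py has to come from a schedule. The CycleGAN paper trains with Adam and batch size 1. It keeps the learning rate at 0.0002 for the first 100 epochs and then decays it linearly to zero over the next 100. The sketch below implements that schedule for a 0-indexed epoch. `learning_rate` is a hypothetical helper, and the epoch counts are parameters.

```python
def learning_rate(epoch, base_lr=0.0002, total_epochs=200, decay_start=100):
    """Learning rate for a 0-indexed epoch.

    Constant at base_lr for the first `decay_start` epochs, then decreased
    linearly so that it would reach 0 at epoch `total_epochs`.
    """
    if epoch < decay_start:
        return base_lr
    return base_lr * (total_epochs - epoch) / (total_epochs - decay_start)


# Per epoch, the value fed as self.lr in train.py would then be:
# lr = learning_rate(epoch)
```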

    summary and record

    in discriminator
    +D_A_real_loss_summary # call tf.summary.scalar to record D_A_real_loss
    +D_A_fake_loss_summary # call tf.summary.scalar to record D_A_fake_loss
    +D_A_loss_summary # call tf.summary.scalar to record D_A_loss
    +D_B_real_loss_summary # call tf.summary.scalar to record D_B_real_loss
    +D_B_fake_loss_summary # call tf.summary.scalar to record D_B_fake_loss
    +D_B_loss_summary # call tf.summary.scalar to record D_B_loss
    +D_loss_summary # call tf.summary.scalar to record D_loss

    in generator
    +G_loss_A2B_summary # call tf.summary.scalar to record G_loss_A2B
    +G_loss_B2A_summary # call tf.summary.scalar to record G_loss_B2A
    +G_loss_summary # call tf.summary.scalar to record the total generator loss G_loss_All

    in train.py

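    # Update the G network and keep its fake outputs for the D step
    fake_A, fake_B, _, summary_str = self.sess.run(
        [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
        feed_dict={self.real_data: batch_images, self.lr: lr})
    # summary_str now holds the G summaries; write it out before it is overwritten below,
    # e.g. writer.add_summary(summary_str, step) with a tf.summary.FileWriter

    # Update the D network, feeding the fakes from the G step into the sample placeholders
    _, summary_str = self.sess.run(
        [self.d_optim, self.d_sum],
        feed_dict={self.real_data: batch_images,
                   self.fake_A_sample: fake_A,
                   self.fake_B_sample: fake_B,
                   self.lr: lr})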
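In the snippet above, fake_A and fake_B go from the G step straight into the D step. The CycleGAN paper instead updates the discriminators with a history of previously generated images (a buffer of 50) to reduce model oscillation. The sketch below shows such a buffer. It assumes one pool per domain, and each generated batch would pass through the pool between the two `sess.run` calls, e.g. `fake_A = pool_A.query(fake_A)`. `ImagePool`, `query` and `pool_A` are hypothetical names.

```python
import random


class ImagePool:
    """History buffer of generated images (hypothetical helper).

    Until the buffer is full, every new image is stored and returned as-is.
    Afterwards, with probability 0.5 the new image replaces a randomly chosen
    stored one and that older image is returned; otherwise the new image is
    returned unchanged.
    """

    def __init__(self, max_size=50, rng=None):
        self.max_size = max_size
        self.images = []
        self.rng = rng if rng is not None else random.Random()

    def query(self, image):
        if self.max_size <= 0:
            return image  # pool disabled: pass the image through
        if len(self.images) < self.max_size:
            self.images.append(image)  # still filling the buffer
            return image
        if self.rng.random() > 0.5:
            idx = self.rng.randrange(self.max_size)
            older = self.images[idx]
            self.images[idx] = image  # keep the new image for later
            return older  # the discriminator sees an older fake
        return image
```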

    to be continued



