使用TFGAN库之前需要:
import tensorflow as tf
tfgan = tf.contrib.gan
TFGAN库的训练有4个步骤,如下:
- 创建网络模型,使用tfgan.gan_model;
- 加入loss,使用tfgan.gan_loss;
- 创建train ops,使用tfgan.gan_train_loss;
- 运行train ops,使用tfgan,gan_train。
其中,tfgan.gan_model的使用要自己定义generator_fn和discriminator_fn(gan_model只是起个组装作用)。另外一个重要的点是loss的设定,TFGAN库提供了不少已经定义好的loss,包括:
- 原始GAN的损失函数,tfgan.losses.minimax_discriminator_loss和tfgan.losses.minimax_generator_loss
- 原始GAN的损失函数的改进版,tfgan.losses.modified_discriminator_loss和tfgan.losses.modified_generator_loss
- ACGAN的损失函数,tfgan.losses.acgan_discriminator_loss和tfgan.losses.acgan_generator_loss
- LSGAN的损失函数,tfgan.losses.least_squares_discriminator_loss和tfgan.losses.least_squares_generator_loss
- WGAN和GWAN-GP的损失函数,tfgan.losses.wasserstein_discriminator_loss、tfgan.losses.wasserstein_generator_loss和tfgan.wasserstein_gradient_penalty
- INFOGAN的损失函数(INFOGAN的损失函数是原始GAN模型的损失函数和互信息损失的结合),tfgan.losses.mutual_information_penalty
- CycleGAN的损失函数,tfgan.losses.cycle_consistency_loss
不过需要注意的是,tfgan.wasserstein_gradient_penalty、tfgan.losses.mutual_information_penalty和tfgan.losses.cycle_consistency_loss只是对应模型的损失函数的一部分,这几个函数都是封装在其他更高级的函数中。在tfgan.gan_loss()中有两个参数,一个是gradient_penalty_weight,对应的就是WGAN-GP中的GP,一个是mutual_information_penalty_weight,对应的就是INFOGAN中的互信息损失。而tfgan.cycle_consistency_loss则是封装在tfgan.cyclegan_loss中。
一般的GAN模型其实用tfgan.gan_model和tfgan.gan_loss已经够了,但是为了方便创建一些具有特别的损失函数形式的模型,TFGAN库还提供了其他的API。其中类似gan_model的还有infogan_model、acgan_model、cyclegan_model和stargan_model,类似gan_loss的还有cyclegan_model和stargan_loss。如果要创建一个普通的GAN模型,使用tfgan.gan_model和tfgan.gan_loss即可;如果要创建INFOGAN,使用tfgan.infogan_model和tfgan.gan_loss;如果要创建ACGAN,使用tfgan.acgan_model和tfgan.gan_loss;如果要创建CycleGAN,使用tfgan.cyclegan_model和tfgan.cyclegan_loss;如果要创建StarGAN,使用tfgan.stargan_model和tfgan.stargan_loss。
TODO:TFGAN的定义loss的py文件中还有一个tfgan.losses.combine_adversarial_loss,暂时还不知道有什么作用,待后面补充。
网友评论