美文网首页
利用GAN生成CIFAR10图片

利用GAN生成CIFAR10图片

作者: dataengineer | 来源:发表于2020-03-07 22:27 被阅读0次

要点

  • GAN (Generative Adversarial Networks,生成对抗网络)是一种可以训练生成模型(generative model)的架构
  • GAN由判别器(discriminator)和生成器(generator)两部分组成,判别器用于识别生成器结果的真实性,即生成的图片是真实的还是计算机生成的,而生成器则根据判别器的识别结果努力生成虚假(fake)但看似真实(plausible)的图片以欺骗判别器。在训练过程中,生成器会根据判别器的性能不断更新自己的模型权重,判别器和生成器不断地进行对抗,犹如一个动态的博弈过程,并最终趋于平衡,即判别器无法正确识别图片的真假,生成器生成的图片接近于真实的图片
  • 模型架构上,GAN是生成器和判别器的堆叠,生成器以隐空间(latent space)中的随机点为输入,输出结果为图片样本,而判别器以生成器输出的图片样本为输入,其输出结果是图片的真假,判别器的输出结果将用于更新生成器的模型权重
  • 当判别器被看作一个独立模型时,可单独对其进行训练,因为判别器只关心图片样本的真假。当判别器和生成器被看作一个整体时,判别器的各层在训练过程中需保持冻结,以避免被虚假样本过度训练。以上两点看似矛盾,实际上可通过tf.keras API巧妙地实现:一个模型可被训练还是已被冻结,这个属性只有模型被编译后才能影响模型,具体细节见代码部分
  • 当判别器和生成器被看作一个整体时,生成器输出的图片样本都要标记为“真”即class = 1。这样做的原因是,当判别器认为图片样本为“假”(即class = 0)或者为“真”的概率较低时,后向传播过程会视其为巨大误差,并据此更新生成器的模型权重以纠正这一误差,也就是让生成器更好地生成虚假样本
  • 判别器没有pooling layer,而是采用2x2的stride,其效果和pooling layer类似
  • 隐空间是一个向量空间(呈高斯分布),其本身没有意义,但是生成器可以赋予隐空间意义。经过训练后,隐向量空间可看作是生成图片的压缩表示
  • deconvolution有两种实现形式,第一种是先上采样再卷积(UpSampling2D→Conv2D),第二种是直接采用Conv2DTranspose,本文代码部分选择第二种形式
  • 将Conv2DTranspose层中的stride配置为2x2,可使输入的feature map的面积增大4倍,同时,将kernel size的大小设置为stride的倍数(比如4x4)还可避免出现checkerboard pattern
  • LeakyReLU的slope建议设为0.2
  • CIFAR-10数据集有60000张32x32彩色图片,包含10个分类,如青蛙、鸟、猫、船、飞机等等,由CIFAR(Canadian Institute For Advanced Research)开发,图片尺寸较小,主要用于计算机视觉研究

代码部分

# load required libraries

import tensorflow as tf
import matplotlib.pyplot as plt
from keras.utils.vis_utils import plot_model
import numpy as np
tf.__version__

'2.0.0'

# load CIFAR10 datasets

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# print the shapes of training and test data

x_train.shape, y_train.shape, x_test.shape, y_test.shape

((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))

# plot training data

fig = plt.gcf()
fig.set_size_inches(10,10)
for i in range(49):
    plt.subplot(7,7,1+i)
    plt.imshow(x_train[i])
image.png
# define the standalone discriminator model

def discriminator_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters = 64, kernel_size = (3,3), padding = 'same', input_shape = (32,32,3)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 128, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Conv2D(filters = 256, kernel_size = (3,3), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(units = 1, activation = 'sigmoid')
    ])
    model.compile(loss = 'binary_crossentropy', 
              optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5), metrics = ['accuracy'])
    
    return model

提示:判别器没有pooling layer,而是采用2*2的stride,其效果和pooling layer类似

# show the summary and graph of the discriminator model 

model = discriminator_model()

model.summary()
image.png
# convert unsigned int to float32

x_train = x_train.astype('float32')
x_train = (x_train - 127.5)/127.5

提示:生成器以tanh为激活函数,其生成的像素值范围为[-1,1],因此,真实图片的像素值范围也应从[0,255]标准化为[-1,1]

# generate points in latent space as the inputs of the generator

def generate_latent_points(latent_dim,n_samples):
    x_input = np.random.randn(latent_dim * n_samples)
    x_input = x_input.reshape(n_samples,latent_dim)
    return x_input    
# randomly select n real samples

def generate_real_samples(dataset, n_samples):
    # define random instances
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    x = dataset[ix]
    # generate class label (label = 1)
    y = np.ones((n_samples,1))
    return x,y

# generate n fake samples with class label

def generate_fake_samples(g_model, latent_dim, n_samples):
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    x = g_model.predict(x_input)
    # generate class label (label = 0)
    y = np.zeros((n_samples,1))
    return x,y
# define the standalone generator model

def generator_model(latent_dim):
    n_nodes = 256*4*4
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units = n_nodes, input_dim = latent_dim),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        tf.keras.layers.Reshape((4,4,256)),
        # upsample to 8*8
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # upsample to 16*16
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # upsample to 32*32
        tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4,4), padding = 'same', strides = (2,2)),
        tf.keras.layers.LeakyReLU(alpha = 0.2),
        # output layer
        tf.keras.layers.Conv2D(filters = 3, kernel_size = (3,3), activation = 'tanh', padding = 'same')      
    ])
    return model
# show the summary and graph of the generator model

model = generator_model(100)

model.summary()
image.png
# define gan model (only generator model can be updated)

def gan_model(g_model, d_model):
    # freeze discriminator model
    d_model.trainable = False
    
    model = tf.keras.models.Sequential([
        g_model,
        d_model
    ])
    
    model.compile(loss = 'binary_crossentropy', 
              optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5))
    
    return model
# show the summary and graph of the gan model

latent_dim = 100

g_model = generator_model(latent_dim)

d_model = discriminator_model()

gan_model = gan_model(g_model,d_model)

gan_model.summary()
image.png
# show and save the plots of generated images

def save_plot(examples, epoch, n = 7):
    # scale from [-1,1] to [0,1]
    examples = (examples + 1)/2.0
    # make plot
    for i in range(n*n):
        plt.subplot(n,n,i+1)
        plt.imshow(examples[i])
    
    # save plots
    filename = 'generated_plot_e%03d.png' % (epoch+1)
    plt.savefig(filename)

# evaluate discriminator model performance, display generated images, save generator model

def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples = 150):
    # prepare real samples
    x_real, y_real = generate_real_samples(dataset, n_samples)
    # evaluate discriminator on real samples
    _, acc_real = d_model.evaluate(x_real, y_real, verbose = 0)
    # prepare fake samples
    x_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_samples)
    # evaluate discriminator on fake samples
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose = 0)
    # display discriminator performance
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    
    # show and save the plots of generated images
    save_plot(x_fake, epoch)
    
    # save generator model tile file
    filename = 'generator_model_%03d.h5' % (epoch+1)
    g_model.save(filename)
# train gan model

def train_gan(g_model, d_model, gan_model, dataset, latent_dim, n_epochs = 20, n_batch = 128):
    bat_per_epoch = int(dataset.shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches
        for j in range(bat_per_epoch):
            # randomly select n real samples
            x_real, y_real = generate_real_samples(dataset, half_batch)
            # update standalone discriminator model
            d_loss1, _ = d_model.train_on_batch(x_real, y_real)
            # generate fake samples
            x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update standalone discriminator model again
            d_loss2, _ = d_model.train_on_batch(x_fake, y_fake)
            # generate points in latent space as the inputs of generator model
            x_gan = generate_latent_points(latent_dim, n_batch)
            # generate class label for fake samples (label = 1)
            y_gan = np.ones((n_batch,1))
            # update the generator model with discriminator model errors
            g_loss = gan_model.train_on_batch(x_gan, y_gan)
            # display the loss
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, j+1, bat_per_epoch, d_loss1, d_loss2, g_loss))
        
        # evaluate model performance every 5 epochs  
        if (i + 1)%5 == 0:
            summarize_performance(i, g_model, d_model, dataset, latent_dim)

提示:GAN的目标是让生成器生成“看似真实”的图片,然而这些图片的质量高低无法通过客观的误差指标来体现,只能由程序员进行人工判读。换言之,即程序员不检查图片的质量,就不知道什么时候该停止训练。例如,某一个epoch结束后,生成器输出的图片质量很高,此时若不停止训练,之后生成的图片质量会发生波动(GAN的对抗性导致每一个batch后生成器都会发生变化),也可能提升,也可能降低。因此,在实际训练过程中,程序员要周期性地评估判别器分辨真假图片的能力(即分类精度),也要周期性地生成图片并进行人工判读,还要周期性地保存生成器模型

# train gan model

train_gan(g_model, d_model, gan_model, x_train, latent_dim)

1, 1/390, d1=0.376, d2=0.280 g=1.740
1, 2/390, d1=0.351, d2=0.322 g=1.679
1, 3/390, d1=0.274, d2=0.299 g=1.866
1, 4/390, d1=0.301, d2=0.272 g=2.027
1, 5/390, d1=0.257, d2=0.230 g=2.256
1, 6/390, d1=0.204, d2=0.186 g=2.558
……
……
20, 387/390, d1=0.724, d2=0.636 g=0.865
20, 388/390, d1=0.665, d2=0.623 g=0.837
20, 389/390, d1=0.678, d2=0.717 g=0.867
20, 390/390, d1=0.718, d2=0.606 g=0.960
Accuracy real: 51%, fake: 89%

image.png
提示:本代码以20个epochs为示例,每5个epochs评估一次模型性能,20个epochs共评估模型性能4次,生成图片4副,保存模型4个。接下来,就可以用性能最好的生成器生成图片了。
# generate images with final generator model

model = tf.keras.models.load_model('generator_model_020.h5') # load model saved after 20 epochs

latent_points = generate_latent_points(100,100) # generate points in latent space

X = model.predict(latent_points) # generate images

X = (X + 1)/2.0 # scale the range from [-1,1] to [0,1]
# plot the images

fig = plt.gcf()
fig.set_size_inches(20,20)
for i in range(100):
    plt.subplot(10,10,1+i)
    plt.imshow(X[i])
image.png

相关文章

  • 利用GAN生成CIFAR10图片

    要点 GAN (Generative Adversarial Networks,生成对抗网络)是一种可以训练生成模...

  • 2020机器学习GAN(CGAN) (9)

    CGAN(Conditonal Generation GAN) 什么是 CGAN 呢?之前我们 GAN 生成图片是...

  • GAN简单的实现01

    目的: 得到一个生成模型G, 可以用来生成图片。 简介 生成式对抗网络(GAN, Generative Adver...

  • 关于故障诊断应用深度学习的一些想法

    利用深度学习模型进行端到端的故障诊断 利用深度学习的模型进行寿命预测 利用GAN(生成对抗网络)进行数据生成 进行...

  • 2020机器学习GAN(2)

    今天目标是介绍一下 GAN 是如何做到输出图片的这个样任务。在 GAN 中分别有两个阶段,在生成阶段就是固定生成器...

  • LSGAN:最小二乘生成对抗网络

    解决问题:解决传统 GAN 生成图片质量不高,训练不稳定的问题。 做法:将传统 GAN 的交叉熵损失函数换成最小二...

  • 使用MatConvNet搭建GAN网络

    概述 该代码使用MatConvNet在matlab上搭建GAN网络,用来生成手写数字图片。 MatConvNet是...

  • Gan网络

    本文所有图片来自北京睿邦科技有限公司Gan网络教学课程教师王昊然。 判别模型与生成模型 在说Gan网络之前有必要先...

  • 2020机器学习自编码器(autoencoder)(上)

    最近分享了 GAN 生成图片时候,在说到 Generator 如何自力更生不依赖 Discriminator 来生...

  • GAN- 生成式对抗网络

    到底什么是生成式对抗网络GAN? 针对数据量缺乏的场景,生成模型则可以帮助生成数据,提高数据数量,从而利用半监督学...

网友评论

      本文标题:利用GAN生成CIFAR10图片

      本文链接:https://www.haomeiwen.com/subject/fnhhdhtx.html