Building a GAN with Keras


Author: 一位学有余力的同学 | Published: 2021-03-18 13:28

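    Before starting, please make sure you understand how GANs work. Many bloggers have already explained this well, so I won't go into detail here. For videos, I recommend Prof. Hung-yi Lee's (李宏毅) course at National Taiwan University.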

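    A GAN has two main components: a generator and a discriminator. The generator produces fake data to "fool" the discriminator. The discriminator judges whether its input was produced by the generator. The two are trained in alternation until the generator produces data realistic enough to pass as real. The example below uses the MNIST dataset to train a GAN that generates handwritten digits.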
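For reference, the adversarial game described above is usually written as the following minimax objective. This is the standard GAN formulation, given as background rather than something the code states. D(x) is the discriminator's estimated probability that x is real, and G(z) is the sample the generator produces from noise z. The discriminator tries to maximize V, and the generator tries to minimize it.

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

The Keras code below trains the generator with the label 1 on its fake images. That corresponds to maximizing log D(G(z)) instead of minimizing log(1 - D(G(z))), which is the commonly used "non-saturating" variant of this objective.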

    Building the network models

    1. Generator

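    A neural network needs an input to produce an output, so to obtain generated (fake) data we must feed the model something. Here the input is a random noise vector of shape [100,], and the output is an image tensor of shape [28,28,1].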

        def build_generator(self):
            # input shape = [100,]
            # output shape = self.img_shape = (28, 28, 1)  (after the final Reshape)
            
            model = Sequential()
    
            model.add(Dense(256, input_dim=self.latent_dim))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))
            model.add(Dense(1024))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))
            # image_shape = [28,28,1]
            model.add(Dense(np.prod(self.img_shape), activation='tanh')) # np.prod() multiplies the shape dims: 28*28*1 = 784
            model.add(Reshape(self.img_shape))
    
            model.summary()
    
            noise = Input(shape=(self.latent_dim,))
            img = model(noise)
    
            return Model(noise, img)
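The following NumPy-only sketch runs the same forward pass to make the shapes concrete: 100 → 256 → 512 → 1024 → 784, then a reshape to 28×28×1. It only illustrates the shapes and is not the Keras model. The weights are arbitrary small random numbers, there are no bias terms, and BatchNormalization is left out. The helper names (`leaky_relu`, `dense`, `generator_forward`) are hypothetical.

```python
import numpy as np

LATENT_DIM = 100          # same as self.latent_dim
IMG_SHAPE = (28, 28, 1)   # same as self.img_shape

def leaky_relu(x, alpha=0.2):
    # Same negative slope as LeakyReLU(alpha=0.2) in the Keras model
    return np.where(x > 0, x, alpha * x)

def dense(x, n_out, rng):
    # Fully connected layer with random, untrained weights and no bias term
    w = rng.normal(0, 0.02, (x.shape[-1], n_out))
    return x @ w

def generator_forward(noise, rng):
    h = noise
    for units in (256, 512, 1024):
        h = leaky_relu(dense(h, units, rng))  # BatchNormalization omitted in this sketch
    out = np.tanh(dense(h, int(np.prod(IMG_SHAPE)), rng))  # 784 values in [-1, 1]
    return out.reshape((-1,) + IMG_SHAPE)  # (batch, 784) -> (batch, 28, 28, 1)

rng = np.random.default_rng(0)
noise = rng.normal(0, 1, (4, LATENT_DIM))  # like np.random.normal(0, 1, (batch_size, 100))
imgs = generator_forward(noise, rng)
print(imgs.shape)  # (4, 28, 28, 1)
```

The tanh on the last layer matters later: it keeps every generated pixel in [-1, 1], the same range the training images are rescaled to.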
    

    2. Discriminator

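    The discriminator takes an image of shape [28,28,1] as input. The image is either a real MNIST image or one produced by the generator. The output is a validity score between 0 and 1. A larger value means the discriminator considers the input more likely to be real data. A smaller value means it considers the input more likely to be generated.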

        def build_discriminator(self):
    
            model = Sequential()
    
            model.add(Flatten(input_shape=self.img_shape))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(256))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(1, activation='sigmoid'))
            model.summary()
    
            img = Input(shape=self.img_shape)
            validity = model(img)
    
            return Model(img, validity)
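A minimal NumPy sketch of the discriminator's forward pass shows why the output is a single number between 0 and 1 per image. The image is flattened to 784 values, passed through two LeakyReLU layers, and squashed by a sigmoid. As before, the random weights, the omitted biases and the function names are assumptions made for illustration only.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Same negative slope as LeakyReLU(alpha=0.2) in the Keras model
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    # Maps any real number into (0, 1)
    return 1.0 / (1.0 + np.exp(-x))

def discriminator_forward(imgs, rng):
    # Flatten: (batch, 28, 28, 1) -> (batch, 784)
    h = imgs.reshape(len(imgs), -1)
    for units in (512, 256):
        w = rng.normal(0, 0.02, (h.shape[1], units))  # random, untrained weights
        h = leaky_relu(h @ w)
    w_out = rng.normal(0, 0.02, (h.shape[1], 1))
    return sigmoid(h @ w_out)  # validity, shape (batch, 1)

rng = np.random.default_rng(1)
imgs = rng.uniform(-1, 1, (8, 28, 28, 1))  # stand-in images in the [-1, 1] range
validity = discriminator_forward(imgs, rng)
print(validity.shape)  # (8, 1)
```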
    

    3. Building the full model

            optimizer = Adam(0.0002, 0.5)
    
            # Build and compile the discriminator
            self.discriminator = self.build_discriminator()
            self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    
            # Build the generator
            self.generator = self.build_generator()
    
            # Feed noise to the generator to produce fake images
            z = Input(shape=(self.latent_dim,))
            img = self.generator(z)
    
            # Freeze the discriminator's weights inside the combined model
            self.discriminator.trainable = False
    
            # Feed the fake images to the discriminator
            validity = self.discriminator(img)
    
            # Stack the generator and discriminator into one model
            self.combined = Model(z, validity)
            self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
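The effect of `self.discriminator.trainable = False` can be illustrated with a toy, framework-free example. This is a sketch under simplifying assumptions, not Keras internals. It uses a 1-D "generator" G(z) = a*z + b and a 1-D "discriminator" D(x) = sigmoid(w*x + c). The combined loss -log D(G(z)) is differentiated with respect to all four parameters, so the gradient flows through D back into G. Only the generator's parameters a and b are updated, while w and c stay fixed, just as the discriminator's weights do inside `combined`. The Keras code also depends on call order: the discriminator is compiled before it is frozen, and `combined` is compiled after. In Keras 2.x the `trainable` flag takes effect at compile time, so the standalone discriminator still learns. Newer Keras versions may behave differently, so check the documentation for the version you use.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Toy 1-D "GAN" (illustrative assumption): G(z) = a*z + b, D(x) = sigmoid(w*x + c)
params = {"a": 0.5, "b": -1.0, "w": 1.5, "c": 0.2}
trainable_in_combined = {"a", "b"}  # discriminator params w, c are frozen

def combined_loss(p, z):
    # Binary cross-entropy against label 1: mean(-log D(G(z)))
    s = p["w"] * (p["a"] * z + p["b"]) + p["c"]
    return float(np.mean(-np.log(sigmoid(s))))

def combined_grads(p, z):
    g = p["a"] * z + p["b"]          # G(z)
    s = p["w"] * g + p["c"]          # D's logit
    ds = sigmoid(s) - 1.0            # dL/ds for L = -log(sigmoid(s))
    dg = ds * p["w"]                 # dL/dG(z): flows back through the frozen D
    return {
        "a": float(np.mean(dg * z)),
        "b": float(np.mean(dg)),
        "w": float(np.mean(ds * g)),  # computed, but discarded below
        "c": float(np.mean(ds)),      # computed, but discarded below
    }

z = np.random.default_rng(0).normal(0, 1, 64)  # noise batch
before = combined_loss(params, z)
grads = combined_grads(params, z)
lr = 0.1
# Gradient step that only touches the generator's parameters
new_params = {k: (v - lr * grads[k] if k in trainable_in_combined else v)
              for k, v in params.items()}
after = combined_loss(new_params, z)
print(before, after)
```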
    

    4. Training strategy

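    1. First train the discriminator. Feed it real images (label 1) and images produced by the generator (label 0) as two separate batches, then average the two losses. This step teaches the discriminator to tell real images from generated ones.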

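    2. Then train the generator. What actually gets trained is the combined model built above. Because the discriminator is frozen inside it, only the generator's weights are updated. The discriminator's predictions on the generated images are compared against the label 1. The closer they are to 1, the better the generator is at fooling the discriminator. Minimizing this loss pushes the generator toward images that look more like real ones.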
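Both steps can be checked by hand with binary cross-entropy, the loss both models are compiled with. The NumPy sketch below uses made-up discriminator outputs (hypothetical numbers, not training results). It shows that `d_loss` is the average of the real-batch and fake-batch losses, and that the generator's loss is the same cross-entropy computed against the label 1. The clipping with `eps` is a numerical safeguard of the kind Keras applies internally; its exact value here is an assumption.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean binary cross-entropy over the batch; clipping avoids log(0)
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

batch_size = 4
valid = np.ones((batch_size, 1))   # label 1 = real
fake = np.zeros((batch_size, 1))   # label 0 = generated

# Hypothetical discriminator outputs, made up for illustration
d_on_real = np.array([[0.9], [0.8], [0.95], [0.7]])
d_on_fake = np.array([[0.1], [0.3], [0.2], [0.05]])

# Step 1: the discriminator loss averages the real-batch and fake-batch losses
d_loss_real = binary_crossentropy(valid, d_on_real)
d_loss_fake = binary_crossentropy(fake, d_on_fake)
d_loss = 0.5 * (d_loss_real + d_loss_fake)

# Step 2: the generator loss scores D(G(z)) against label 1, i.e. mean(-log D(G(z)))
g_loss = binary_crossentropy(valid, d_on_fake)

print(round(d_loss, 3), round(g_loss, 3))  # 0.184 2.028
```

In the Keras code, `train_on_batch` on the discriminator returns `[loss, accuracy]` because it was compiled with `metrics=['accuracy']`. That is why `d_loss` there is a two-element array indexed as `d_loss[0]` and `d_loss[1]`. The sketch tracks only the loss.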

        def train(self, epochs, batch_size=128, sample_interval=50):
    
            # Load the dataset
            (X_train, _), (_, _) = mnist.load_data() # returns ((train images, train labels), (test images, test labels)) as tuples
            # X_train.shape = (60000, 28, 28)
            
            # Rescale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
            X_train = X_train / 127.5 - 1.
            X_train = np.expand_dims(X_train, axis=3) # add a channel axis ---> (60000, 28, 28, 1)
    
            # Adversarial ground truths
            valid = np.ones((batch_size, 1))
            fake = np.zeros((batch_size, 1))
    
            for epoch in range(epochs):
    
                # ---------------------
                #  Train Discriminator
                # ---------------------
    
                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size) # batch_size random integers in [0, 60000)
                imgs = X_train[idx] # a random batch of batch_size real images
    
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # random noise input of shape (batch_size, 100)
    
                # Generate a batch of new images
                gen_imgs = self.generator.predict(noise)
    
                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # real images, all labelled 1 (valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) # generated images, all labelled 0 (fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    
                # ---------------------
                #  Train Generator
                # ---------------------
    
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
    
                # Train the generator (to have the discriminator label samples as valid)
                g_loss = self.combined.train_on_batch(noise, valid)
    
                # Plot the progress
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
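The rescaling steps can be verified in isolation. The NumPy sketch below uses random uint8 arrays as a stand-in for `mnist.load_data()`, which is an assumption that avoids downloading the dataset. It checks that `x / 127.5 - 1` maps [0, 255] onto [-1, 1], the output range of the generator's tanh. It also checks that `0.5 * x + 0.5`, used in `sample_images()`, maps the result back to [0, 1] for plotting.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for X_train from mnist.load_data(): uint8 pixels in [0, 255]
x = rng.integers(0, 256, size=(10, 28, 28)).astype(np.uint8)
x[0, 0, 0], x[0, 0, 1] = 0, 255  # make sure both extremes occur

x_scaled = x / 127.5 - 1.0                   # [0, 255] -> [-1, 1]
x_scaled = np.expand_dims(x_scaled, axis=3)  # (10, 28, 28) -> (10, 28, 28, 1)

x_plot = 0.5 * x_scaled + 0.5                # [-1, 1] -> [0, 1], as in sample_images()

print(x_scaled.shape, x_scaled.min(), x_scaled.max())  # (10, 28, 28, 1) -1.0 1.0
```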
    

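    5. GAN network structure

    The output of model.summary() is shown below. The first table is the discriminator (built first in __init__), and the second is the generator.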

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    flatten_1 (Flatten)          (None, 784)               0         
    _________________________________________________________________
    dense_1 (Dense)              (None, 512)               401920    
    _________________________________________________________________
    leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
    _________________________________________________________________
    dense_2 (Dense)              (None, 256)               131328    
    _________________________________________________________________
    leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
    _________________________________________________________________
    dense_3 (Dense)              (None, 1)                 257       
    =================================================================
    Total params: 533,505
    Trainable params: 533,505
    Non-trainable params: 0
    _________________________________________________________________
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_4 (Dense)              (None, 256)               25856     
    _________________________________________________________________
    leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 256)               1024      
    _________________________________________________________________
    dense_5 (Dense)              (None, 512)               131584    
    _________________________________________________________________
    leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
    _________________________________________________________________
    batch_normalization_2 (Batch (None, 512)               2048      
    _________________________________________________________________
    dense_6 (Dense)              (None, 1024)              525312    
    _________________________________________________________________
    leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0         
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 1024)              4096      
    _________________________________________________________________
    dense_7 (Dense)              (None, 784)               803600    
    _________________________________________________________________
    reshape_1 (Reshape)          (None, 28, 28, 1)         0         
    =================================================================
    Total params: 1,493,520
    Trainable params: 1,489,936
    Non-trainable params: 3,584
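The parameter counts in these summaries follow directly from the layer sizes. A Dense layer has n_in × n_out weights plus n_out biases. A BatchNormalization layer has 4 values per unit: gamma and beta are trainable, while the moving mean and moving variance are not, which accounts for the 3,584 non-trainable parameters. The short Python check below reproduces the totals. The helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    # weight matrix (n_in x n_out) plus one bias per output unit
    return n_in * n_out + n_out

def batchnorm_params(n):
    # gamma + beta (trainable) + moving_mean + moving_variance (non-trainable)
    return 4 * n, 2 * n  # (total, non-trainable)

# Discriminator: Flatten(784) -> 512 -> 256 -> 1
disc_total = dense_params(784, 512) + dense_params(512, 256) + dense_params(256, 1)

# Generator: 100 -> 256 -> 512 -> 1024 (each followed by BatchNorm) -> 784
gen_total, gen_frozen, n_in = 0, 0, 100
for units in (256, 512, 1024):
    bn_total, bn_frozen = batchnorm_params(units)
    gen_total += dense_params(n_in, units) + bn_total
    gen_frozen += bn_frozen
    n_in = units
gen_total += dense_params(n_in, 28 * 28 * 1)

print(disc_total, gen_total, gen_total - gen_frozen, gen_frozen)
# 533505 1493520 1489936 3584
```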
    

    Complete code

    from keras.datasets import mnist
    from keras.layers import Input, Dense, Reshape, Flatten
    from keras.layers import BatchNormalization, LeakyReLU
    from keras.models import Sequential, Model
    from keras.optimizers import Adam
    
    import matplotlib.pyplot as plt
    
    import os
    
    import numpy as np
    
    class GAN():
        def __init__(self):
            self.img_rows = 28
            self.img_cols = 28
            self.channels = 1
            self.img_shape = (self.img_rows, self.img_cols, self.channels)
            self.latent_dim = 100
    
            optimizer = Adam(0.0002, 0.5)
    
            # Build and compile the discriminator
            self.discriminator = self.build_discriminator()
            self.discriminator.compile(loss='binary_crossentropy',
                optimizer=optimizer,
                metrics=['accuracy'])
    
            # Build the generator
            self.generator = self.build_generator()
    
            # The generator takes noise as input and generates imgs
            z = Input(shape=(self.latent_dim,))
            img = self.generator(z)
    
            # For the combined model we will only train the generator
            self.discriminator.trainable = False
    
            # The discriminator takes generated images as input and determines validity
            validity = self.discriminator(img)
    
            # The combined model  (stacked generator and discriminator)
            # Trains the generator to fool the discriminator
            self.combined = Model(z, validity)
            self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    
        # Build the generator
        def build_generator(self):
            # input shape = [100,]
            # output shape = self.img_shape = (28, 28, 1)  (after the final Reshape)
            
            model = Sequential()
    
            model.add(Dense(256, input_dim=self.latent_dim))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))
            model.add(Dense(1024))
            model.add(LeakyReLU(alpha=0.2))
            model.add(BatchNormalization(momentum=0.8))
            # image_shape = [28,28,1]
            model.add(Dense(np.prod(self.img_shape), activation='tanh')) # np.prod() multiplies the shape dims: 28*28*1 = 784
            model.add(Reshape(self.img_shape))
    
            model.summary()
    
            noise = Input(shape=(self.latent_dim,))
            img = model(noise)
    
            return Model(noise, img)
        
        # Build the discriminator
        def build_discriminator(self):
    
            model = Sequential()
    
            model.add(Flatten(input_shape=self.img_shape))
            model.add(Dense(512))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(256))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(1, activation='sigmoid'))
            model.summary()
    
            img = Input(shape=self.img_shape)
            validity = model(img)
    
            return Model(img, validity)
    
        def train(self, epochs, batch_size=128, sample_interval=50):
    
            # Load the dataset
            (X_train, _), (_, _) = mnist.load_data() # returns ((train images, train labels), (test images, test labels)) as tuples
            # X_train.shape = (60000, 28, 28)
            
            # Rescale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
            X_train = X_train / 127.5 - 1.
            X_train = np.expand_dims(X_train, axis=3) # add a channel axis ---> (60000, 28, 28, 1)
    
            # Adversarial ground truths
            valid = np.ones((batch_size, 1))
            fake = np.zeros((batch_size, 1))
    
            for epoch in range(epochs):
    
                # ---------------------
                #  Train Discriminator
                # ---------------------
    
                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size) # batch_size random integers in [0, 60000)
                imgs = X_train[idx] # a random batch of batch_size real images
    
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # random noise input of shape (batch_size, 100)
    
                # Generate a batch of new images
                gen_imgs = self.generator.predict(noise)
    
                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # real images, all labelled 1 (valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) # generated images, all labelled 0 (fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    
                # ---------------------
                #  Train Generator
                # ---------------------
    
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
    
                # Train the generator (to have the discriminator label samples as valid)
                g_loss = self.combined.train_on_batch(noise, valid)
    
                # Plot the progress
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
    
                # If at save interval => save generated image samples
                if epoch % sample_interval == 0:
                    self.sample_images(epoch)
    
        def sample_images(self, epoch):
            r, c = 5, 5
            noise = np.random.normal(0, 1, (r * c, self.latent_dim))
            gen_imgs = self.generator.predict(noise)
    
            # Rescale images 0 - 1
            gen_imgs = 0.5 * gen_imgs + 0.5
    
            fig, axs = plt.subplots(r, c)
            cnt = 0
            for i in range(r):
                for j in range(c):
                    axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                    axs[i,j].axis('off')
                    cnt += 1
            fig.savefig("images/%d.png" % epoch)
            plt.close()
    
    if __name__ == '__main__':
        if not os.path.exists("./images"):
            os.makedirs("./images")
        gan = GAN()
        gan.train(epochs=30000, batch_size=32, sample_interval=200) # sample_interval => save sample images every 200 iterations
    

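    Below are the images produced by the generator at epochs 0, 10000, 20000 and 29800. In this code, each "epoch" is a single training step on one batch.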


    [Figure: generator samples at epochs 0, 10000, 20000 and 29800 (0.png, 10000.png, 20000.png, 29800.png)]
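Each loop iteration draws only one random batch. A quick calculation with the arguments passed to `gan.train()` shows how much data 30,000 iterations correspond to. It also shows why the last saved image is 29800.png rather than 30000.png.

```python
epochs, batch_size, sample_interval = 30000, 32, 200  # arguments passed in __main__
n_train = 60000                                         # size of the MNIST training set

images_drawn = epochs * batch_size     # random draws, with replacement
passes = images_drawn / n_train        # equivalent number of full passes over MNIST
saved = [e for e in range(epochs) if e % sample_interval == 0]  # iterations that call sample_images()

print(images_drawn, passes, len(saved), saved[-1])  # 960000 16.0 150 29800
```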


    Article link: https://www.haomeiwen.com/subject/bjuucltx.html