美文网首页
keras 实现 GAN

keras 实现 GAN

作者: 走在成长的道路上 | 来源:发表于2018-11-27 11:44 被阅读0次

通过 Keras 实现 GAN ,其主要过程如下:

GAN训练过程分析

正如上图所示,通过让 Generator 和 Discriminator 交替训练,使生成的数据不断逼近真实数据的分布。这里 Generator 的输入为随机噪声。

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

# Select the channels-first image layout (what Keras 1 called 'th' ordering).
# set_image_data_format is the Keras 2 spelling; set_image_dim_ordering is a
# legacy alias that is not available in later Keras 2 releases.
K.set_image_data_format('channels_first')

# Seed NumPy's global RNG so that repeated runs draw the same random numbers.
# Drop this call if you want different results on every run.
np.random.seed(1000)

# Length of the random noise vector fed to the generator.
# A value of 10 gives slightly better results, but 100 is kept for
# consistency with other GAN implementations.
randomDim = 100

# Load MNIST, then rescale the training pixels with (x - 127.5) / 127.5 and
# flatten every image into a 784-element row.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.astype(np.float32)
X_train = (X_train - 127.5) / 127.5
X_train = X_train.reshape(60000, 784)

# Single Adam optimizer instance handed to every compile() call in this file.
adam = Adam(lr=0.0002, beta_1=0.5)

# Build the generator: a stack of Dense + LeakyReLU layers that maps a
# randomDim-long noise vector to a 784-value output squashed by tanh.
generator = Sequential([
    Dense(256, input_dim=randomDim,
          kernel_initializer=initializers.RandomNormal(stddev=0.02)),
    LeakyReLU(0.2),
    Dense(512),
    LeakyReLU(0.2),
    Dense(1024),
    LeakyReLU(0.2),
    Dense(784, activation='tanh'),
])
# Compile the generator.
generator.compile(loss='binary_crossentropy', optimizer=adam)

# Build the discriminator: Dense + LeakyReLU + Dropout stages that reduce a
# 784-value input to a single sigmoid output.
discriminator = Sequential([
    Dense(1024, input_dim=784,
          kernel_initializer=initializers.RandomNormal(stddev=0.02)),
    LeakyReLU(0.2),
    Dropout(0.3),
    Dense(512),
    LeakyReLU(0.2),
    Dropout(0.3),
    Dense(256),
    LeakyReLU(0.2),
    Dropout(0.3),
    Dense(1, activation='sigmoid'),
])
# Compile the discriminator.
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

# Combined network
# Mark the discriminator as non-trainable before the stacked model is built
# and compiled. NOTE(review): the intent is that training `gan` only updates
# the generator's weights - confirm against the Keras version in use.
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
# Compile the stacked noise -> generator -> discriminator model.
gan.compile(loss='binary_crossentropy', optimizer=adam)

# Loss histories: train() appends to these and plotLoss() plots them.
dLosses = []
gLosses = []

# Plot the loss recorded at the end of each epoch
def plotLoss(epoch):
    """Save a line plot of the recorded discriminator and generator losses.

    Reads the module-level ``dLosses`` and ``gLosses`` lists and writes the
    chart to ``images/gan_loss_epoch_<epoch>.png``.

    Args:
        epoch: Integer epoch number embedded in the output file name.
    """
    # savefig does not create missing directories.
    os.makedirs('images', exist_ok=True)

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.plot(dLosses, label='Discriminitive loss')
        plt.plot(gLosses, label='Generative loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig('images/gan_loss_epoch_%d.png' % epoch)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    """Sample the generator and save the results as a grid of images.

    Draws ``examples`` noise vectors from a standard normal distribution, runs
    them through the module-level ``generator``, reshapes the output to
    ``(examples, 28, 28)`` and writes the grid to
    ``images/gan_generated_image_epoch_<epoch>.png``.

    Args:
        epoch: Integer epoch number embedded in the output file name.
        examples: Number of images to generate.
        dim: ``(rows, columns)`` of the subplot grid.
        figsize: Figure size passed to ``plt.figure``.
    """
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    # savefig does not create missing directories.
    os.makedirs('images', exist_ok=True)

    fig = plt.figure(figsize=figsize)
    try:
        for i, image in enumerate(generatedImages):
            plt.subplot(dim[0], dim[1], i+1)
            plt.imshow(image, interpolation='nearest', cmap='gray_r')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)
    finally:
        # This function is called repeatedly during training; without an
        # explicit close every figure would stay in memory.
        plt.close(fig)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    """Write the generator and discriminator to HDF5 files under ``models/``.

    Args:
        epoch: Integer epoch number embedded in both output file names.
    """
    # Create the target directory up front; saving into a missing directory fails.
    os.makedirs('models', exist_ok=True)
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)

def train(epochs=1, batchSize=128):
    """Alternately train the discriminator and the generator on ``X_train``.

    Every batch step first trains the discriminator on real images mixed with
    freshly generated ones, then trains the stacked ``gan`` model on new noise
    labelled as real. After each epoch the losses of that epoch's last batch
    are appended to ``dLosses`` / ``gLosses``. Sample images and model
    snapshots are written on epoch 1 and on every 20th epoch, and the loss
    curve is plotted once after the final epoch.

    Args:
        epochs: Number of epochs to run. Values below 1 train nothing.
        batchSize: Number of real samples (and of generated samples) per step.

    Raises:
        ValueError: If ``batchSize`` is smaller than 1 or larger than the
            number of training samples.
    """
    sampleCount = X_train.shape[0]
    if not 1 <= batchSize <= sampleCount:
        # Outside this range there would be zero batches per epoch (or a
        # division by zero), leaving no loss values to record.
        raise ValueError(
            'batchSize must be between 1 and %d, got %r' % (sampleCount, batchSize))

    batchCount = sampleCount // batchSize
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    if epochs < 1:
        # No epoch runs, so there is no loss history to plot.
        return

    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batchCount)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, sampleCount, size=batchSize)]

            # Generate fake MNIST images with the current generator
            generatedImages = generator.predict(noise)
            # Stack real images first, generated images second
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data: start with everything fake (0)
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing: the real half gets 0.9 instead of 1.0
            yDis[:batchSize] = 0.9

            # Train discriminator on the mixed batch. The generator is only
            # used for prediction here and is not updated by this call.
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator: feed fresh noise through the stacked model and
            # label the outputs as real (1).
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            # Flag the discriminator as non-trainable again before the stacked
            # model is trained.
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

# Script entry point: run the full 200-epoch training with batches of 128.
if __name__ == '__main__':
    train(epochs=200, batchSize=128)

上述代码通过 Sequential 分别创建 Generator 和 Discriminator 网络,然后通过 Model 将二者结合在一起,从而构成了 GAN 网络的全貌。

参考

GAN-in-keras-on-mnist

相关文章

网友评论

      本文标题:keras 实现 GAN

      本文链接:https://www.haomeiwen.com/subject/mlqgqqtx.html