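A while ago I read about SRGAN. Its main purpose is image super-resolution, i.e. raising the resolution of an image. The network has been out for quite a while (2016) and is one of the must-know models when getting started with GANs. The model itself is fairly simple, and when I went through its code today it was easy to follow, so below I'll explain the model and how the code works~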
As usual, first the GitHub code and the paper: Code | Paper
I. Model architecture
![](https://img.haomeiwen.com/i5356150/699431956f0bca39.png)
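The figure above shows the structure of the generator and the discriminator. One thing to add: training uses pairs of images, a high-resolution (HR) version and a low-resolution (LR) version of the same picture. The generator takes the LR image as input, and the discriminator sees HR images, either real or generated.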
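To make the HR/LR pairing concrete, here is a minimal NumPy sketch that derives a 64×64 LR image from a 256×256 HR image by averaging each 4×4 block. This is an illustrative assumption. The paper downsamples with a bicubic kernel, and the repository's `DataLoader` (not shown in this post) does its own resizing. `make_lr` is a hypothetical helper.

```python
import numpy as np

def make_lr(hr, factor=4):
    """Downsample an (H, W, C) HR image by averaging factor x factor blocks.

    Hypothetical helper: a box filter stands in for the paper's bicubic
    downsampling; it is not the repository's DataLoader.
    """
    h, w, c = hr.shape
    assert h % factor == 0 and w % factor == 0, "size must be divisible by factor"
    blocks = hr.reshape(h // factor, factor, w // factor, factor, c)
    return blocks.mean(axis=(1, 3))

# A random "HR" image in [-1, 1], the value range the code below works in
rng = np.random.default_rng(0)
imgs_hr = rng.uniform(-1.0, 1.0, size=(256, 256, 3))
imgs_lr = make_lr(imgs_hr)
print(imgs_hr.shape, imgs_lr.shape)  # (256, 256, 3) (64, 64, 3)
```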
1. Generator: architecture and code
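The generator takes the low-resolution input, processes it with convolutions and residual blocks at the input resolution, and then upsamples the result to produce the high-resolution image. No downsampling happens anywhere: every convolution in the generator uses `strides=1` with `padding='same'`. The output has 16 times as many pixels as the input (width and height are each 4 times larger). The Keras code below makes the structure clearer.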
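A quick way to check the "no downsampling, ×4 upsampling" claim is to trace the tensor shape through the layers. The sketch below is hand-written bookkeeping, assuming channels-last tensors, `padding='same'` and the default 16 residual blocks. It does not call Keras, and `trace_generator` is a hypothetical helper.

```python
def trace_generator(lr_shape=(64, 64, 3), n_residual_blocks=16, n_upsample=2):
    """Return (stage description, (H, W, C)) for each stage of the generator."""
    h, w, c = lr_shape
    steps = [("input", (h, w, c))]
    steps.append(("conv 9x9, 64 filters, stride 1", (h, w, 64)))
    for i in range(n_residual_blocks):
        # Two stride-1 'same' convs plus a skip connection: shape unchanged
        steps.append((f"residual block {i + 1}", (h, w, 64)))
    steps.append(("conv 3x3 + BN + skip from c1", (h, w, 64)))
    for i in range(n_upsample):
        h, w = h * 2, w * 2  # UpSampling2D(size=2) doubles height and width
        steps.append((f"upsample {i + 1} + conv 3x3, 256 filters", (h, w, 256)))
    steps.append(("conv 9x9, tanh", (h, w, c)))
    return steps

steps = trace_generator()
for name, shape in steps[:2] + steps[-3:]:
    print(name, shape)
out_h, out_w, _ = steps[-1][1]
print("pixel ratio:", (out_h * out_w) // (64 * 64))  # 16
```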
```python
def build_generator(self):

    def residual_block(layer_input, filters):
        """Residual block described in paper"""
        # Note: the paper orders this as Conv -> BN -> PReLU; this code applies ReLU before BN
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
        d = Activation('relu')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Add()([d, layer_input])  # skip connection
        return d

    def deconv2d(layer_input):
        """Layers used during upsampling (x2 height and width)"""
        # The paper uses sub-pixel convolution here; this code uses UpSampling2D + Conv2D
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
        u = Activation('relu')(u)
        return u

    # Low resolution image input
    img_lr = Input(shape=self.lr_shape)

    # Pre-residual block
    c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
    c1 = Activation('relu')(c1)

    # Propagate through residual blocks (all stride 1: spatial size unchanged)
    r = residual_block(c1, self.gf)
    for _ in range(self.n_residual_blocks - 1):
        r = residual_block(r, self.gf)

    # Post-residual block, plus a long skip connection back to c1
    c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
    c2 = BatchNormalization(momentum=0.8)(c2)
    c2 = Add()([c2, c1])

    # Upsampling: two x2 steps give x4 width and height
    u1 = deconv2d(c2)
    u2 = deconv2d(u1)

    # Generate high resolution output
    gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

    return Model(img_lr, gen_hr)
```
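`UpSampling2D(size=2)` has no weights. In its default nearest-neighbour mode it simply copies each pixel into a 2×2 block, and the following `Conv2D` learns to refine the result. A minimal NumPy equivalent, assuming a channels-last (H, W, C) array (`upsample_nearest` is a hypothetical helper):

```python
import numpy as np

def upsample_nearest(x, size=2):
    """Nearest-neighbour upsampling of an (H, W, C) array: every pixel is
    copied into a size x size block."""
    return np.repeat(np.repeat(x, size, axis=0), size, axis=1)

x = np.arange(4, dtype=float).reshape(2, 2, 1)
y = upsample_nearest(x)
print(y[..., 0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```

The SRGAN paper itself uses sub-pixel convolution (pixel shuffle) for this step; the Keras code above swaps in plain upsampling followed by a learned convolution.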
2. Discriminator: architecture and code
![](https://img.haomeiwen.com/i5356150/4dfe7151157c3072.png)
```python
def build_discriminator(self):

    def d_block(layer_input, filters, strides=1, bn=True):
        """Discriminator layer"""
        d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if bn:
            d = BatchNormalization(momentum=0.8)(d)
        return d

    # Input img
    d0 = Input(shape=self.hr_shape)

    # Four stride-2 blocks halve the spatial size four times (256 -> 16)
    d1 = d_block(d0, self.df, bn=False)
    d2 = d_block(d1, self.df, strides=2)
    d3 = d_block(d2, self.df*2)
    d4 = d_block(d3, self.df*2, strides=2)
    d5 = d_block(d4, self.df*4)
    d6 = d_block(d5, self.df*4, strides=2)
    d7 = d_block(d6, self.df*8)
    d8 = d_block(d7, self.df*8, strides=2)

    # No Flatten: Dense acts on the last axis only (a 1x1 conv at every position),
    # so the output is a patch map of shape (hr_height/16, hr_width/16, 1)
    d9 = Dense(self.df*16)(d8)
    d10 = LeakyReLU(alpha=0.2)(d9)
    validity = Dense(1, activation='sigmoid')(d10)

    return Model(d0, validity)
```
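One detail worth checking in `build_discriminator`: there is no `Flatten` before `d9`, so `Dense` is applied to the 4-D output of `d8`. Keras `Dense` contracts only the last axis of such an input, which makes it a 1×1 convolution applied at every spatial position. That is why the output is a patch map rather than a single score. The NumPy sketch below checks this equivalence with random weights. It assumes a 256×256 input (so `d8` is 16×16×512 for one image) and leaves out the sigmoid.

```python
import numpy as np

rng = np.random.default_rng(0)
# d8 for one 256x256 image: a 16x16 grid of 512-dim features (df * 8 with df = 64)
d8 = rng.normal(size=(16, 16, 512))
W = rng.normal(size=(512, 1))  # kernel of a Dense(1) layer (random here)
b = np.zeros(1)

# Dense on a rank-3 input: dot product along the last axis only
dense_out = np.tensordot(d8, W, axes=([-1], [0])) + b

# The same computation done position by position, i.e. a 1x1 convolution
per_pixel = np.empty((16, 16, 1))
for i in range(16):
    for j in range(16):
        per_pixel[i, j] = d8[i, j] @ W + b

print(dense_out.shape)                    # (16, 16, 1)
print(np.allclose(dense_out, per_pixel))  # True
```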
3. Training procedure
Each iteration trains the discriminator first, then the generator.
1) Discriminator
```python
# ----------------------
#  Train Discriminator
# ----------------------

# Sample images and their conditioning counterparts
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

# From low res. image generate high res. version
fake_hr = self.generator.predict(imgs_lr)

valid = np.ones((batch_size,) + self.disc_patch)
fake = np.zeros((batch_size,) + self.disc_patch)

# Train the discriminators (original images = real / generated = Fake)
d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
```
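We first feed the low-resolution images into the generator to get `fake_hr`, the generated high-resolution images. The labels are `valid` (the label for real HR images) and `fake` (the label for generated HR images). As in CycleGAN, the labels are built in PatchGAN form; I covered PatchGAN in my CycleGAN post.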
The size of `disc_patch` is computed as follows:
```python
patch = int(self.hr_height / 2**4)
self.disc_patch = (patch, patch, 1)
```
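The image is not literally cut into 16 pieces. The discriminator has four stride-2 convolutions, so its output is 2^4 = 16 times smaller than the input in each dimension: for a 256×256 HR image, `disc_patch` is (16, 16, 1). Each of these 16 × 16 = 256 outputs can be thought of as looking at one region of the image (its receptive field) and judging whether that region comes from a real or a generated high-resolution image.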
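Both the patch size and the region each output sees can be computed by hand. The sketch below walks the eight `d_block` convolutions (all 3×3, strides alternating 1 and 2). It uses `out = ceil(in / stride)` for `padding='same'` and the standard receptive-field recurrence, and it ignores the per-position `Dense` layers because they change neither number. `patch_and_receptive_field` is a hypothetical helper.

```python
import math

# (kernel_size, stride) of the eight d_block convolutions, d1 ... d8
D_LAYERS = [(3, 1), (3, 2)] * 4

def patch_and_receptive_field(hr_size=256, layers=D_LAYERS):
    size, rf, jump = hr_size, 1, 1
    for k, s in layers:
        size = math.ceil(size / s)  # padding='same': out = ceil(in / stride)
        rf += (k - 1) * jump        # receptive field grows by (k - 1) * current jump
        jump *= s                   # spacing between outputs, in input pixels
    return size, rf, jump

patch, rf, jump = patch_and_receptive_field()
print(patch, patch * patch)  # 16 256
print(rf, jump)              # 61 16
```

So each of the 256 outputs sees a 61×61 pixel window, and neighbouring outputs are 16 pixels apart: the windows overlap rather than tiling the image into disjoint pieces.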
The loss used here is MSE:
```python
self.discriminator.compile(loss='mse',
                           optimizer=optimizer,
                           metrics=['accuracy'])
```
The loss is computed separately on the real HR images and on the generated HR images, and the two are averaged to give the discriminator loss.
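The arithmetic can be reproduced with NumPy. The discriminator outputs below are made-up constants (0.9 on real patches, 0.2 on generated ones), used only to show how the two MSE terms are averaged:

```python
import numpy as np

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

batch_size, disc_patch = 1, (16, 16, 1)
valid = np.ones((batch_size,) + disc_patch)  # labels for real HR images
fake = np.zeros((batch_size,) + disc_patch)  # labels for generated HR images

# Hypothetical discriminator outputs (per-patch sigmoid scores)
d_real = np.full((batch_size,) + disc_patch, 0.9)
d_fake = np.full((batch_size,) + disc_patch, 0.2)

d_loss_real = mse(d_real, valid)  # (0.9 - 1)^2 = 0.01
d_loss_fake = mse(d_fake, fake)   # (0.2 - 0)^2 = 0.04
d_loss = 0.5 * (d_loss_real + d_loss_fake)
print(round(d_loss, 4))  # 0.025
```

In the real code, `train_on_batch` returns `[loss, accuracy]` because the discriminator was compiled with `metrics=['accuracy']`, so `np.add` averages both entries.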
2) Generator
```python
# ------------------
#  Train Generator
# ------------------

# Sample images and their conditioning counterparts
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

# The generators want the discriminators to label the generated images as real
valid = np.ones((batch_size,) + self.disc_patch)

# Extract ground truth image features using pre-trained VGG19 model
image_features = self.vgg.predict(imgs_hr)

# Train the generators
g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])
```
To see what these labels correspond to, look at how the combined model is built:
```python
# High res. and low res. image inputs
img_hr = Input(shape=self.hr_shape)
img_lr = Input(shape=self.lr_shape)

# Generate high res. version from low res.
fake_hr = self.generator(img_lr)

# Extract image features of the generated img
fake_features = self.vgg(fake_hr)

# For the combined model we will only train the generator
self.discriminator.trainable = False

# Discriminator determines validity of generated high res. images
validity = self.discriminator(fake_hr)

self.combined = Model([img_lr, img_hr], [validity, fake_features])
```
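Now it is easy to follow. When training the generator, we first freeze the discriminator (`self.discriminator.trainable = False`). Because this flag is set after the discriminator itself was compiled, it only affects the combined model, and the discriminator can still be trained on its own. The low-resolution image is passed through the generator to get `fake_hr`. Then `fake_hr` and the original HR image `imgs_hr` are each passed through the VGG feature extractor, giving `fake_features` and `image_features`, and the MSE between these features is the content loss. Sounds easy, right? The code is short too. There is a second loss to keep in mind, though: the discriminator's output `validity` on the generated image is compared with `valid` (all ones) using binary cross-entropy. This pushes the generator toward images that the discriminator judges to be real high-resolution images. So the generator is trained with two losses, weighted `1e-3` (adversarial) and `1` (VGG content) via `loss_weights` in the full code below~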
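The two generator losses and their weights (`loss_weights=[1e-3, 1]`) can be written out in NumPy as below. The tensors are made up for illustration: a discriminator that outputs 0.5 everywhere, and VGG features that differ from their targets by a constant 0.1. The feature shape (64, 64, 256) is what VGG19's block 3 produces for a 256×256 input. The `bce` helper is an approximation of Keras' `binary_crossentropy` (mean over elements, with clipping), not its actual implementation.

```python
import numpy as np

def bce(pred, target, eps=1e-7):
    """Binary cross-entropy averaged over all elements (approximates Keras')."""
    pred = np.clip(pred, eps, 1 - eps)
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

rng = np.random.default_rng(0)
validity = np.full((1, 16, 16, 1), 0.5)             # D's output on fake_hr (made up)
valid = np.ones_like(validity)                      # generator wants D to say "real"
image_features = rng.normal(size=(1, 64, 64, 256))  # VGG features of imgs_hr (made up)
fake_features = image_features + 0.1                # VGG features of fake_hr (made up)

adv_loss = bce(validity, valid)                     # -log(0.5), about 0.693
content_loss = mse(fake_features, image_features)   # 0.1 ** 2 = 0.01
g_loss = 1e-3 * adv_loss + 1.0 * content_loss       # loss_weights=[1e-3, 1]
print(round(adv_loss, 3), round(content_loss, 4), round(g_loss, 5))  # 0.693 0.01 0.01069
```

With these weights the content loss dominates; the adversarial term acts as a small push toward outputs the discriminator accepts.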
II. Full training code
"""
Super-resolution of CelebA using Generative Adversarial Networks.
The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0
Instrustion on running the script:
1. Download the dataset from the provided link
2. Save the folder 'img_align_celeba' to 'datasets/'
4. Run the sript using command 'python srgan.py'
"""
from __future__ import print_function, division
import scipy
from keras.datasets import mnist
import keras_contrib.layers.normalization.instancenormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os
import keras.backend as K
class SRGAN():
def __init__(self):
# Input shape
self.channels = 3
self.lr_height = 64 # Low resolution height
self.lr_width = 64 # Low resolution width
self.lr_shape = (self.lr_height, self.lr_width, self.channels)
self.hr_height = self.lr_height*4 # High resolution height
self.hr_width = self.lr_width*4 # High resolution width
self.hr_shape = (self.hr_height, self.hr_width, self.channels)
# Number of residual blocks in the generator
self.n_residual_blocks = 16
optimizer = Adam(0.0002, 0.5)
# We use a pre-trained VGG19 model to extract image features from the high resolution
# and the generated high resolution images and minimize the mse between them
self.vgg = self.build_vgg()
self.vgg.trainable = False
self.vgg.compile(loss='mse',
optimizer=optimizer,
metrics=['accuracy'])
# Configure data loader
self.dataset_name = 'img_align_celeba'
self.data_loader = DataLoader(dataset_name=self.dataset_name,
img_res=(self.hr_height, self.hr_width))
# Calculate output shape of D (PatchGAN)
patch = int(self.hr_height / 2**4)
self.disc_patch = (patch, patch, 1)
# Number of filters in the first layer of G and D
self.gf = 64
self.df = 64
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='mse',
optimizer=optimizer,
metrics=['accuracy'])
# Build the generator
self.generator = self.build_generator()
# High res. and low res. images
img_hr = Input(shape=self.hr_shape)
img_lr = Input(shape=self.lr_shape)
# Generate high res. version from low res.
fake_hr = self.generator(img_lr)
# Extract image features of the generated img
fake_features = self.vgg(fake_hr)
# For the combined model we will only train the generator
self.discriminator.trainable = False
# Discriminator determines validity of generated high res. images
validity = self.discriminator(fake_hr)
self.combined = Model([img_lr, img_hr], [validity, fake_features])
self.combined.compile(loss=['binary_crossentropy', 'mse'],
loss_weights=[1e-3, 1],
optimizer=optimizer)
def build_vgg(self):
"""
Builds a pre-trained VGG19 model that outputs image features extracted at the
third block of the model
"""
vgg = VGG19(weights="imagenet")
# Set outputs to outputs of last conv. layer in block 3
# See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
vgg.outputs = [vgg.layers[9].output]
img = Input(shape=self.hr_shape)
# Extract image features
img_features = vgg(img)
return Model(img, img_features)
def build_generator(self):
def residual_block(layer_input, filters):
"""Residual block described in paper"""
d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
d = Activation('relu')(d)
d = BatchNormalization(momentum=0.8)(d)
d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
d = BatchNormalization(momentum=0.8)(d)
d = Add()([d, layer_input])
return d
def deconv2d(layer_input):
"""Layers used during upsampling"""
u = UpSampling2D(size=2)(layer_input)
u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
u = Activation('relu')(u)
return u
# Low resolution image input
img_lr = Input(shape=self.lr_shape)
# Pre-residual block
c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
c1 = Activation('relu')(c1)
# Propogate through residual blocks
r = residual_block(c1, self.gf)
for _ in range(self.n_residual_blocks - 1):
r = residual_block(r, self.gf)
# Post-residual block
c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
c2 = BatchNormalization(momentum=0.8)(c2)
c2 = Add()([c2, c1])
# Upsampling
u1 = deconv2d(c2)
u2 = deconv2d(u1)
# Generate high resolution output
gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
return Model(img_lr, gen_hr)
def build_discriminator(self):
def d_block(layer_input, filters, strides=1, bn=True):
"""Discriminator layer"""
d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
d = LeakyReLU(alpha=0.2)(d)
if bn:
d = BatchNormalization(momentum=0.8)(d)
return d
# Input img
d0 = Input(shape=self.hr_shape)
d1 = d_block(d0, self.df, bn=False)
d2 = d_block(d1, self.df, strides=2)
d3 = d_block(d2, self.df*2)
d4 = d_block(d3, self.df*2, strides=2)
d5 = d_block(d4, self.df*4)
d6 = d_block(d5, self.df*4, strides=2)
d7 = d_block(d6, self.df*8)
d8 = d_block(d7, self.df*8, strides=2)
d9 = Dense(self.df*16)(d8)
d10 = LeakyReLU(alpha=0.2)(d9)
validity = Dense(1, activation='sigmoid')(d10)
return Model(d0, validity)
def train(self, epochs, batch_size=1, sample_interval=50):
start_time = datetime.datetime.now()
for epoch in range(epochs):
# ----------------------
# Train Discriminator
# ----------------------
# Sample images and their conditioning counterparts
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
# From low res. image generate high res. version
fake_hr = self.generator.predict(imgs_lr)
valid = np.ones((batch_size,) + self.disc_patch)
fake = np.zeros((batch_size,) + self.disc_patch)
# Train the discriminators (original images = real / generated = Fake)
d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# ------------------
# Train Generator
# ------------------
# Sample images and their conditioning counterparts
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
# The generators want the discriminators to label the generated images as real
valid = np.ones((batch_size,) + self.disc_patch)
# Extract ground truth image features using pre-trained VGG19 model
image_features = self.vgg.predict(imgs_hr)
# Train the generators
g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])
elapsed_time = datetime.datetime.now() - start_time
# Plot the progress
print ("%d time: %s" % (epoch, elapsed_time))
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
self.sample_images(epoch)
def sample_images(self, epoch):
os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
r, c = 2, 2
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)
fake_hr = self.generator.predict(imgs_lr)
# Rescale images 0 - 1
imgs_lr = 0.5 * imgs_lr + 0.5
fake_hr = 0.5 * fake_hr + 0.5
imgs_hr = 0.5 * imgs_hr + 0.5
# Save generated images and the high resolution originals
titles = ['Generated', 'Original']
fig, axs = plt.subplots(r, c)
cnt = 0
for row in range(r):
for col, image in enumerate([fake_hr, imgs_hr]):
axs[row, col].imshow(image[row])
axs[row, col].set_title(titles[col])
axs[row, col].axis('off')
cnt += 1
fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))
plt.close()
# Save low resolution images for comparison
for i in range(r):
fig = plt.figure()
plt.imshow(imgs_lr[i])
fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))
plt.close()
if __name__ == '__main__':
gan = SRGAN()
gan.train(epochs=30000, batch_size=1, sample_interval=50)