DCGAN简介
DCGAN在GAN的基础上优化了网络结构,加入了 conv、batch_norm 等层,使得网络更容易训练,网络结构如下:
注意:本图只是示例,与下面实际网络参数不对应。
Tensorflow实现DCGAN
from __future__ import division, print_function, absolute_import
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
导入数据集
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset from ./data/ with one-hot encoded labels.
mnist = input_data.read_data_sets(train_dir="./data/", one_hot=True)
Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz
设置参数
# Training hyper-parameters.
num_steps = 10000  # total number of training iterations
batch_size = 128  # number of samples per mini-batch
lr_generator = 2e-3  # learning rate of the generator optimizer
lr_discriminator = 2e-3  # learning rate of the discriminator optimizer

# Network parameters.
image_dim = 28 * 28 * 1  # 28x28 pixels, 1 channel -> 784 values per image
noise_dim = 100  # length of the noise vector fed to the generator
构建DCGAN网络
# Graph inputs.
# Generator input: a batch of noise vectors, shape (batch, noise_dim). The
# leading None dimension takes the size of whatever batch is fed at run time.
noise_input = tf.placeholder(dtype=tf.float32, shape=[None, noise_dim])
# Discriminator input: a batch of real images, shape (batch, 28, 28, 1).
real_image_input = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
# Boolean fed at run time; it is passed as `training=` to every batch
# normalization layer so they can tell training from inference.
is_training = tf.placeholder(dtype=tf.bool)
# LeakyReLU activation, used in the discriminator network.
# LeakyReLU is a variant of ReLU [^1]
def leakyrelu(x, alpha=0.2):
    """Apply a leaky ReLU to ``x``.

    The expression is the closed form of "``x`` when ``x >= 0``, otherwise
    ``alpha * x``": for non-negative ``x`` the two terms add up to ``x``, and
    for negative ``x`` they add up to ``alpha * x``.

    Args:
        x: A number or tensor-like object supporting ``*``, ``+`` and ``abs()``.
        alpha: Slope applied to negative inputs.

    Returns:
        The activated value, of the same kind as ``x``.
    """
    return 0.5 * (1 + alpha) * x + 0.5 * (1 - alpha) * abs(x)
# Generator network.
# Input: noise. Output: image.
def generator(x, reuse=False):
    """Build the generator: noise vectors in, 28x28x1 images out.

    Every batch normalization layer receives the module-level `is_training`
    placeholder as its `training` argument.

    Args:
        x: Noise tensor; in this file it is `noise_input`, shape
            (batch, noise_dim).
        reuse: Forwarded to `tf.variable_scope('Generator', ...)`; pass True
            to share the variables of an earlier call.

    Returns:
        A tensor of shape (batch, 28, 28, 1) with tanh-activated values,
        i.e. in [-1, 1].
    """
    with tf.variable_scope('Generator', reuse=reuse):
        # Fully connected layer with 7*7*128 units on top of the noise.
        net = tf.layers.dense(x, units=7 * 7 * 128)
        # batch_normalization normalizes along `axis` (default -1, the last
        # dimension); for this 2-D dense output that is one set of statistics
        # per feature, computed across the batch.
        net = tf.layers.batch_normalization(net, training=is_training)
        net = tf.nn.relu(net)
        # Reshape to 4-D (batch, height, width, channels) = (batch, 7, 7, 128).
        net = tf.reshape(net, shape=[-1, 7, 7, 128])

        # Transposed convolution 1: 64 output channels, 5x5 kernel, stride 2.
        # (batch, 7, 7, 128) -> (batch, 14, 14, 64)
        net = tf.layers.conv2d_transpose(net, 64, 5, strides=2, padding='same')
        # Batch normalization over the channel dimension.
        net = tf.layers.batch_normalization(net, training=is_training)
        net = tf.nn.relu(net)

        # Transposed convolution 2: 1 output channel, 5x5 kernel, stride 2.
        # (batch, 14, 14, 64) -> (batch, 28, 28, 1)
        net = tf.layers.conv2d_transpose(net, 1, 5, strides=2, padding='same')
        # tanh squashes the output into [-1, 1], the same range the real
        # images are rescaled to before they reach the discriminator.
        net = tf.nn.tanh(net)
        return net
# Discriminator network.
# Input: image. Output: prediction (real / fake image).
def discriminator(x, reuse=False):
    """Build the discriminator: images in, two-class logits out.

    Every batch normalization layer receives the module-level `is_training`
    placeholder as its `training` argument.

    Args:
        x: Image tensor. The flatten step hard-codes 7*7*128 features, so the
            input is expected to be (batch, 28, 28, 1), as both callers in
            this file provide.
        reuse: Forwarded to `tf.variable_scope('Discriminator', ...)`; pass
            True to share the variables of an earlier call.

    Returns:
        A tensor of shape (batch, 2) holding unnormalized logits. The losses
        in this file use label 1 for real images and label 0 for generated.
    """
    with tf.variable_scope('Discriminator', reuse=reuse):
        # Convolution 1: 64 filters, 5x5 kernel, stride 2.
        net = tf.layers.conv2d(x, 64, 5, strides=2, padding='same')
        net = tf.layers.batch_normalization(net, training=is_training)
        net = leakyrelu(net)

        # Convolution 2: 128 filters, 5x5 kernel, stride 2.
        net = tf.layers.conv2d(net, 128, 5, strides=2, padding='same')
        net = tf.layers.batch_normalization(net, training=is_training)
        net = leakyrelu(net)

        # Flatten to (batch, 7*7*128).
        net = tf.reshape(net, shape=[-1, 7*7*128])

        # Fully connected layer with 1024 units.
        net = tf.layers.dense(net, 1024)
        net = tf.layers.batch_normalization(net, training=is_training)
        net = leakyrelu(net)

        # Output layer: 2 classes, real and fake images.
        net = tf.layers.dense(net, 2)
    return net
# Generator graph: noise -> generated images.
gen_sample = generator(x=noise_input, reuse=False)

# Two discriminator branches sharing one set of variables: the first sees
# real images and creates the variables, the second sees generated images
# and reuses them.
disc_real = discriminator(x=real_image_input, reuse=False)
disc_fake = discriminator(x=gen_sample, reuse=True)

# Stacked generator -> discriminator branch (variables reused again); its
# output is what the generator loss is computed from.
stacked_gan = discriminator(x=gen_sample, reuse=True)
# Loss functions: sparse softmax cross-entropy over the discriminator logits,
# averaged over the batch.

# Real images are labelled 1.
real_labels = tf.ones([batch_size], dtype=tf.int32)
real_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=disc_real, labels=real_labels)
disc_loss_real = tf.reduce_mean(real_xent)

# Generated images are labelled 0.
fake_labels = tf.zeros([batch_size], dtype=tf.int32)
fake_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=disc_fake, labels=fake_labels)
disc_loss_fake = tf.reduce_mean(fake_xent)

# The discriminator loss is the sum of both parts.
disc_loss = disc_loss_real + disc_loss_fake

# Generator loss: the generator tries to fool the discriminator, so its
# samples are scored against label 1.
fool_labels = tf.ones([batch_size], dtype=tf.int32)
fool_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=stacked_gan, labels=fool_labels)
gen_loss = tf.reduce_mean(fool_xent)
# One Adam optimizer per network.
optimizer_gen = tf.train.AdamOptimizer(lr_generator, beta1=0.5, beta2=0.999)
optimizer_disc = tf.train.AdamOptimizer(lr_discriminator, beta1=0.5, beta2=0.999)

# Variables each optimizer is allowed to update.
# By default a TensorFlow optimizer updates every trainable variable, so each
# one is restricted to the variables created under its own network's scope.
# tf.get_collection returns the matching variables as a list.
trainable_key = tf.GraphKeys.TRAINABLE_VARIABLES
# Generator variables.
gen_vars = tf.get_collection(trainable_key, scope='Generator')
# Discriminator variables.
disc_vars = tf.get_collection(trainable_key, scope='Discriminator')
# Training ops.
# The UPDATE_OPS collection holds the batch normalization ops that update the
# moving mean / variance; they are not run unless something depends on them.
gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='Generator')
# `control_dependencies` makes sure `gen_update_ops` run before the
# `minimize` op (backprop) created inside the block.
with tf.control_dependencies(gen_update_ops):
    train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)

disc_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='Discriminator')
with tf.control_dependencies(disc_update_ops):
    train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

# Global variable initializer; created after both `minimize` calls so it also
# covers the variables the optimizers add.
init = tf.global_variables_initializer()
训练
# Start training.
# Open a new TF session and run the variable initializer.
sess = tf.Session()
sess.run(init)

# Alternate one discriminator update and one generator update per step.
for i in range(1, num_steps+1):
    # Prepare input data.
    # Get the next batch of MNIST data (only images are needed, not labels).
    batch_x, _ = mnist.train.next_batch(batch_size)
    batch_x = np.reshape(batch_x, newshape=[-1, 28, 28, 1])
    # Rescale to [-1, 1], the input range of the discriminator.
    batch_x = batch_x * 2. - 1.

    # Discriminator training.
    # Generate noise to feed to the generator.
    z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])
    _, dl = sess.run([train_disc, disc_loss], feed_dict={real_image_input: batch_x, noise_input: z, is_training:True})

    # Generator training.
    # Draw fresh noise for the generator update.
    z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])
    _, gl = sess.run([train_gen, gen_loss], feed_dict={noise_input: z, is_training:True})

    # Report both losses on the first step and every 500 steps.
    if i % 500 == 0 or i == 1:
        print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))
Step 1: Generator Loss: 4.064141, Discriminator Loss: 1.679586
Step 500: Generator Loss: 1.472707, Discriminator Loss: 0.974612
Step 1000: Generator Loss: 1.918907, Discriminator Loss: 0.964812
Step 1500: Generator Loss: 2.567637, Discriminator Loss: 0.717904
Step 2000: Generator Loss: 2.398796, Discriminator Loss: 0.512406
Step 2500: Generator Loss: 3.057401, Discriminator Loss: 1.235215
Step 3000: Generator Loss: 2.620444, Discriminator Loss: 0.539795
Step 3500: Generator Loss: 3.193395, Discriminator Loss: 0.265896
Step 4000: Generator Loss: 5.071162, Discriminator Loss: 0.409445
Step 4500: Generator Loss: 5.213869, Discriminator Loss: 0.203033
Step 5000: Generator Loss: 6.087250, Discriminator Loss: 0.350634
Step 5500: Generator Loss: 5.467363, Discriminator Loss: 0.424895
Step 6000: Generator Loss: 4.910432, Discriminator Loss: 0.196554
Step 6500: Generator Loss: 3.230242, Discriminator Loss: 0.268745
Step 7000: Generator Loss: 4.777361, Discriminator Loss: 0.676658
Step 7500: Generator Loss: 4.165446, Discriminator Loss: 0.150221
Step 8000: Generator Loss: 5.681596, Discriminator Loss: 0.108955
Step 8500: Generator Loss: 6.023059, Discriminator Loss: 0.114312
Step 9000: Generator Loss: 4.660669, Discriminator Loss: 0.182506
Step 9500: Generator Loss: 4.492438, Discriminator Loss: 0.411817
Step 10000: Generator Loss: 5.906080, Discriminator Loss: 0.088082
测试
# Testing.
# Generate images from noise with the generator network and tile them into an
# n x n grid of 28x28 digits.
n = 6
canvas = np.empty((28 * n, 28 * n))
for i in range(n):
    # Noise input: one batch of n vectors per grid row.
    z = np.random.uniform(-1., 1., size=[n, noise_dim])
    # Generate images from noise; is_training is False, so batch
    # normalization runs in inference mode.
    g = sess.run(gen_sample, feed_dict={noise_input: z, is_training:False})
    # Rescale values to the original [0, 1] (from tanh -> [-1, 1]).
    g = (g + 1.) / 2.
    # Reverse colours for better display.
    g = -1 * (g - 1)
    for j in range(n):
        # Draw the j-th generated digit into row i, column j of the canvas.
        canvas[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = g[j].reshape([28, 28])

plt.figure(figsize=(n, n))
plt.imshow(canvas, origin="upper", cmap="gray")
plt.show()
image
网友评论