A while ago I tested DCGAN on MNIST generation and found the final images were not that good: convergence was slow and the generated samples contained a disproportionate number of zeros, among other problems. Here I test how WGAN performs instead.
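For reference, the key ingredient tested here is the gradient penalty from the WGAN-GP paper, which the gradient_penalty function below implements: the discriminator is penalized whenever the norm of its gradient, taken at points interpolated between real and fake images, drifts away from 1. Roughly:

$$\mathcal{L}_{GP} = \lambda \,\mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big], \qquad \hat{x} = t\,x_{real} + (1 - t)\,x_{fake}, \; t \sim U[0, 1]$$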
The code is as follows:
import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers, losses
from PIL import Image
import numpy as np
import os
def save_result(val_out, val_block_size, image_path, color_mode):
    def preprocess(img):
        # map tanh output in [-1, 1] to uint8 pixels in [0, 255]
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        return img

    preprocessed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocessed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocessed[b, :, :, :]), axis=1)
        # concat image row to final_image
        if (b + 1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)
            # reset single row
            single_row = np.array([])
    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)
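# (Hypothetical standalone usage of save_result, not part of the original
#  script: it expects a [N, H, W, C] float array in [-1, 1] and tiles it into
#  a val_block_size-wide grid. Uncomment to try it on random noise.)
# demo = np.random.uniform(-1.0, 1.0, size=(100, 28, 28, 1)).astype(np.float32)
# save_result(demo, 10, 'save_result_demo.png', color_mode='P')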
def gradient_penalty(discriminator, batch_x, fake_image):
    batchsz = batch_x.shape[0]
    # random interpolation coefficient, one per sample: [b, 1, 1] => [b, 28, 28]
    t = tf.random.uniform([batchsz, 1, 1])
    t = tf.broadcast_to(t, batch_x.shape)
    # interpolate between real and fake images
    interpolate = t * batch_x + (1 - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch([interpolate])
        d_interpolate_logits = discriminator(interpolate, training=True)
    grads = tape.gradient(d_interpolate_logits, interpolate)
    # grads: [b, 28, 28] => [b, -1]
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1)  # [b]
    # penalize the deviation of the gradient norm from 1
    gp = tf.reduce_mean((gp - 1) ** 2)
    return gp
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
# inspect the MNIST data: (60000, 28, 28) train, (10000, 28, 28) test
print(x_train.shape, x_train.dtype, y_train.shape, x_test.shape, y_test.shape)
# generator upsampling path: 4 * 4 * 7 => 28 * 28 * 1
generator = Sequential([
    layers.Dense(4 * 4 * 7, activation=tf.nn.leaky_relu),
    layers.Reshape(target_shape=(4, 4, 7)),
    layers.Conv2DTranspose(14, 5, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2DTranspose(5, 3, 1, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2DTranspose(1, 4, 2, activation=tf.nn.tanh),
    layers.Reshape(target_shape=(28, 28)),
])
discriminator = Sequential([
    layers.Reshape((28, 28, 1)),
    layers.Conv2D(3, 4, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(12, 3, 1, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(28, 5, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(1)
])
# hyperparameters
dim_h = 100
epochs = int(9e+7)  # effectively: run until manually stopped
batch_size = 128
learning_rate = 2e-3

def preprocess(pre_x, pre_y):
    # scale pixels to [-1, 1] to match the generator's tanh output
    pre_x = tf.cast(pre_x, dtype=tf.float32) / 127.5 - 1.
    pre_y = tf.cast(pre_y, dtype=tf.int32)
    return pre_x, pre_y
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
    .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
    .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
generator.build((None, dim_h))
generator.summary()
discriminator.build((None, 28, 28, 1))
discriminator.summary()
# sanity check that the two networks' input/output shapes line up
print(generator(tf.random.normal((1, dim_h))))
print(discriminator(tf.random.normal((1, 28, 28, 1))))
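# (Optional sanity check, my addition rather than the original author's: the
#  gradient penalty on random real/fake batches should be a finite,
#  non-negative scalar. Uncomment to verify.)
# _gp = gradient_penalty(discriminator,
#                        tf.random.normal((4, 28, 28)),
#                        tf.random.normal((4, 28, 28)))
# print('gradient penalty sanity check:', float(_gp))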
g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
cross_entropy = losses.BinaryCrossentropy(from_logits=True)
for epoch in range(epochs):
    for step, (true_x, y) in enumerate(db_train):
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            # sample noise and generate a batch of fake images
            random_seek = tf.random.normal((batch_size, dim_h))
            false_x = generator(random_seek, training=True)
            false_y = discriminator(false_x, training=True)
            true_y = discriminator(true_x, training=True)
            # discriminator: cross-entropy on real/fake plus the gradient penalty
            false_loss = cross_entropy(tf.zeros_like(false_y), false_y)
            true_loss = cross_entropy(tf.ones_like(true_y), true_y)
            gp = gradient_penalty(discriminator, true_x, false_x)
            d_loss = false_loss + true_loss + 1. * gp
            # generator: try to make the discriminator label fakes as real
            g_loss = cross_entropy(tf.ones_like(false_y), false_y)
        d_grad = d_tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grad, discriminator.trainable_variables))
        g_grad = g_tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(g_grad, generator.trainable_variables))
    print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
    # save a 10 x 10 grid of generated samples at the end of each epoch
    z = tf.random.normal([100, dim_h])
    fake_image = generator(z, training=False)
    if not os.path.exists('mnist-images'):
        os.mkdir('mnist-images')
    img_path = os.path.join('mnist-images', 'wgan-%d.png' % epoch)
    # [100, 28, 28] => [100, 28, 28, 1] so save_result can tile it
    fake_image = tf.expand_dims(fake_image, axis=3)
    save_result(fake_image.numpy(), 10, img_path, color_mode='P')
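One thing worth noting: although the title says WGAN, the loop above still uses the standard cross-entropy GAN losses and only adds the gradient penalty on top. For comparison, here is a minimal sketch of what the critic and generator losses look like in the usual WGAN-GP formulation. This is my own variation rather than the code above; it reuses the gradient_penalty helper, and the penalty weight gp_lambda = 10 is the commonly used default, not a value from this post.

def wgan_gp_losses(generator, discriminator, true_x, z, gp_lambda=10.0):
    # WGAN-GP style losses: use the critic's raw scores directly,
    # with no sigmoid / cross-entropy on top
    fake_x = generator(z, training=True)
    fake_score = discriminator(fake_x, training=True)
    true_score = discriminator(true_x, training=True)
    gp = gradient_penalty(discriminator, true_x, fake_x)
    # critic: push real scores up, fake scores down, keep gradient norms near 1
    d_loss = tf.reduce_mean(fake_score) - tf.reduce_mean(true_score) + gp_lambda * gp
    # generator: push the critic's score on fake samples up
    g_loss = -tf.reduce_mean(fake_score)
    return d_loss, g_loss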
Here are the results:
[Generated sample grids from epochs 119 to 124: wgan-119.png, wgan-120.png, wgan-121.png, wgan-122.png, wgan-123.png, wgan-124.png]
The results are fairly good, and the generated digits show decent variety.