美文网首页
手写数字生成mnist wgan

手写数字生成mnist wgan

作者: small瓜瓜 | 来源:发表于2020-04-15 11:01 被阅读0次

前段时间测试了用 DCGAN 生成 MNIST 手写数字，结果发现最终生成的图片效果并不理想：收敛较慢，生成的图片中出现很多 0 等缺点。这里测试一下 WGAN 的效果。
代码如下:

import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers, losses
from PIL import Image
import numpy as np
import os


def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of generated images into a grid and save it as one image file.

    Args:
        val_out: array of shape [n, h, w, c] with pixel values in [-1, 1]
            (e.g. the tanh output of the generator).
        val_block_size: number of images per grid row.
        image_path: destination path passed to PIL's ``Image.save``.
        color_mode: unused; kept so existing callers keep working.

    If n is not a multiple of val_block_size, the last row is padded with black
    tiles instead of being silently dropped.

    Raises:
        ValueError: if val_out is not a non-empty 4-D array, or val_block_size < 1.
    """
    images = np.asarray(val_out)
    if images.ndim != 4 or images.shape[0] == 0:
        raise ValueError('val_out must be a non-empty [n, h, w, c] array, got shape %s'
                         % (images.shape,))
    if val_block_size < 1:
        raise ValueError('val_block_size must be >= 1, got %r' % (val_block_size,))

    # [-1, 1] -> [0, 255]. Clip before the uint8 cast so values slightly outside
    # the range saturate instead of wrapping around.
    pixels = np.clip((images + 1.0) * 127.5, 0, 255).astype(np.uint8)

    n, h, w, c = pixels.shape
    rows = -(-n // val_block_size)  # ceiling division
    pad = rows * val_block_size - n
    if pad:
        # Complete the last row with black tiles so the grid stays rectangular.
        pixels = np.concatenate([pixels, np.zeros((pad, h, w, c), dtype=np.uint8)], axis=0)

    # [rows * block, h, w, c] -> [rows, h, block, w, c] -> [rows * h, block * w, c]:
    # images fill the grid left-to-right, then top-to-bottom.
    grid = (pixels.reshape(rows, val_block_size, h, w, c)
            .transpose(0, 2, 1, 3, 4)
            .reshape(rows * h, val_block_size * w, c))

    if c == 1:
        # Single-channel images are handed to PIL as a 2-D array.
        grid = np.squeeze(grid, axis=2)
    Image.fromarray(grid).save(image_path)


def gradient_penalty(discriminator, batch_x, fake_image):
    """Return the WGAN-GP gradient penalty for one batch.

    Samples random points on straight lines between real and generated samples,
    takes the critic's gradient with respect to those points, and penalises the
    squared deviation of each per-sample gradient norm from 1.

    Args:
        discriminator: critic callable, invoked as discriminator(x, training=True).
        batch_x: real samples, shape [b, ...] (any rank >= 1).
        fake_image: generated samples, same shape as batch_x.

    Returns:
        Scalar tensor: mean over the batch of (||grad||_2 - 1) ** 2.
    """
    batchsz = batch_x.shape[0]

    # One mixing coefficient per sample, shaped [b, 1, ..., 1] so it broadcasts over
    # every non-batch axis whatever the sample rank ([b, h, w] or [b, h, w, c]).
    t = tf.random.uniform([batchsz] + [1] * (len(batch_x.shape) - 1), dtype=batch_x.dtype)

    interplate = t * batch_x + (1 - t) * fake_image

    with tf.GradientTape() as tape:
        tape.watch([interplate])
        d_interplote_logits = discriminator(interplate, training=True)
    grads = tape.gradient(d_interplote_logits, interplate)

    # grads: [b, ...] => [b, -1]
    grads = tf.reshape(grads, [grads.shape[0], -1])
    # Per-sample L2 norm. The epsilon keeps the gradient of sqrt finite when a
    # sample's gradient is exactly zero (tf.norm's gradient is NaN there), which
    # matters because this penalty is itself differentiated by the outer tape.
    gp = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)  # [b]
    gp = tf.reduce_mean((gp - 1) ** 2)

    return gp


(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

# Inspect the raw MNIST arrays. Shape and dtype show the data layout directly;
# printing the arrays themselves shows neither.
print('x_train:', x_train.shape, x_train.dtype, '| y_train:', y_train.shape, y_train.dtype)
print('x_test:', x_test.shape, x_test.dtype, '| y_test:', y_test.shape, y_test.dtype)

# Generator: latent vector -> 4 * 4 * 7 feature map -> 28 * 28 * 1 image.
# Raw MNIST arrays are (60000, 28, 28) for training and (10000, 28, 28) for test.
# Assuming Keras' default padding='valid', the transposed convolutions grow the
# spatial size 4 -> 11 -> 13 -> 28 (output = (input - 1) * stride + kernel).
generator = Sequential([
    layers.Dense(4 * 4 * 7, activation=tf.nn.leaky_relu),
    layers.Reshape(target_shape=(4, 4, 7)),
    layers.Conv2DTranspose(14, 5, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2DTranspose(5, 3, 1, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    # tanh keeps generated pixels in [-1, 1], the range save_result maps back to [0, 255].
    layers.Conv2DTranspose(1, 4, 2, activation=tf.nn.tanh),
    # Drop the channel axis so generated images are (28, 28) like the raw MNIST data.
    layers.Reshape(target_shape=(28, 28)),
])

# Critic: one image -> one unbounded score (the final Dense has no activation).
# The leading Reshape accepts both (28, 28) and (28, 28, 1) inputs.
# Assuming padding='valid', the spatial size goes 28 -> 13 -> 11 -> 4 before Flatten.
# NOTE(review): the WGAN-GP paper advises against BatchNormalization in the critic,
# because the gradient penalty is defined per sample; consider LayerNormalization
# or no normalization here.
discriminator = Sequential([
    layers.Reshape((28, 28, 1)),
    layers.Conv2D(3, 4, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(12, 3, 1, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(28, 5, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(1)
])
# Hyperparameters.

# Length of the random latent vector fed to the generator.
dim_h = 100
# 9e7 passes over the training set: effectively unbounded.
epochs = 9 * 10 ** 7
batch_size = 128
# Shared by the generator and discriminator Adam optimizers.
learning_rate = 0.002


def preprocess(pre_x, pre_y):
    """Convert one MNIST (image, label) pair into training inputs.

    The image is cast to float32 and scaled from [0, 255] to [-1, 1]. This matches
    the range of the generator's tanh output and the [-1, 1] -> [0, 255] mapping
    used in save_result, so real and generated samples share the same scale.
    The label is cast to int32.
    """
    pre_x = tf.cast(pre_x, dtype=tf.float32) / 127.5 - 1.
    pre_y = tf.cast(pre_y, dtype=tf.int32)
    return pre_x, pre_y


# Batched input pipelines. The shuffle buffer holds five batches, and the last
# partial batch is dropped so every step sees exactly batch_size samples.
db_train = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(preprocess)
    .shuffle(buffer_size=batch_size * 5)
    .batch(batch_size=batch_size, drop_remainder=True)
)

db_test = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess)
    .shuffle(buffer_size=batch_size * 5)
    .batch(batch_size=batch_size, drop_remainder=True)
)

# Create the weights up front and print each network's architecture.
generator.build(input_shape=(None, dim_h))
generator.summary()

discriminator.build(input_shape=(None, 28, 28, 1))
discriminator.summary()

# Smoke test: push one random latent vector / one random image through each network.
print(generator(tf.random.normal(shape=(1, dim_h))))
print(discriminator(tf.random.normal(shape=(1, 28, 28, 1))))

# Adam with beta_1 = 0.5 for both networks, sharing the same learning rate.
g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
cross_entropy = losses.BinaryCrossentropy(from_logits=True)

# WGAN-GP training loop: each step updates the critic and the generator once,
# using gradients taken from the same forward pass.
os.makedirs('mnist-images', exist_ok=True)
for epoch in range(epochs):
    for true_x, _ in db_train:
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            random_seek = tf.random.normal((batch_size, dim_h))
            # training=True makes BatchNormalization use batch statistics and update
            # its moving averages; when the argument is omitted, Keras runs these layers
            # in inference mode. gradient_penalty already calls the critic this way.
            false_x = generator(random_seek, training=True)
            false_y = discriminator(false_x, training=True)
            true_y = discriminator(true_x, training=True)

            gp = gradient_penalty(discriminator, true_x, false_x)

            # Wasserstein critic loss: push real scores up and fake scores down,
            # plus the gradient penalty (weight 10, as in the WGAN-GP paper).
            d_loss = tf.reduce_mean(false_y) - tf.reduce_mean(true_y) + 10. * gp
            # The generator tries to raise the critic's score of its samples.
            g_loss = -tf.reduce_mean(false_y)
        d_grad = d_tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grad, discriminator.trainable_variables))
        g_grad = g_tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(g_grad, generator.trainable_variables))

    # Losses from the last step of the epoch.
    print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
    # Save a 10 x 10 grid of samples for this epoch.
    z = tf.random.normal([100, dim_h])
    fake_image = generator(z, training=False)
    img_path = os.path.join('mnist-images', 'wgan-%d.png' % epoch)
    # Add the channel axis that save_result indexes: [100, 28, 28] -> [100, 28, 28, 1].
    fake_image = tf.expand_dims(fake_image, axis=3)
    save_result(fake_image.numpy(), 10, img_path, color_mode='P')

下面是结果:


wgan-119.png wgan-120.png wgan-121.png wgan-122.png wgan-123.png wgan-124.png

效果还是比较好的，生成的数字样式也比较多样。

相关文章

网友评论

      本文标题:手写数字生成mnist wgan

      本文链接:https://www.haomeiwen.com/subject/sziqvhtx.html