
Handwritten digit generation with MNIST WGAN

Author: small瓜瓜 | Published 2020-04-15 11:01

    A while ago I tested DCGAN on MNIST and found that the final images were not great: convergence was slow, and the samples had problems such as being dominated by zeros. Here I try WGAN (with a gradient penalty) instead.
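
    For reference, WGAN-GP works by pushing the critic's gradient norm toward 1 at points interpolated between real and fake samples, which softly enforces the 1-Lipschitz constraint of the Wasserstein objective:

        gp = E[ (‖∇x̂ D(x̂)‖₂ − 1)² ],  where x̂ = t·x_real + (1 − t)·x_fake,  t ~ U(0, 1)

    This is exactly what the gradient_penalty function below computes.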
    The code is as follows:

    import tensorflow as tf
    from tensorflow.keras import datasets, Sequential, layers, losses
    from PIL import Image
    import numpy as np
    import os
    
    
    def save_result(val_out, val_block_size, image_path, color_mode):
        def preprocess(img):
            img = ((img + 1.0) * 127.5).astype(np.uint8)
            # img = img.astype(np.uint8)
            return img
    
        preprocessed = preprocess(val_out)
        final_image = np.array([])
        single_row = np.array([])
        for b in range(val_out.shape[0]):
            # concat images into a row
            if single_row.size == 0:
                single_row = preprocessed[b, :, :, :]
            else:
                single_row = np.concatenate((single_row, preprocessed[b, :, :, :]), axis=1)
    
            # concat image row to final_image
            if (b + 1) % val_block_size == 0:
                if final_image.size == 0:
                    final_image = single_row
                else:
                    final_image = np.concatenate((final_image, single_row), axis=0)
    
                # reset single row
                single_row = np.array([])
    
        if final_image.shape[2] == 1:
            final_image = np.squeeze(final_image, axis=2)
        Image.fromarray(final_image).save(image_path)
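
    # Illustrative usage (hypothetical variable names): tile a batch of 100 images
    # shaped [100, 28, 28, 1] with values in [-1, 1] into a 10x10 grid PNG:
    #   save_result(fake_batch.numpy(), 10, 'grid.png', color_mode='P')
    # Note: the color_mode argument is accepted but never used by this helper.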
    
    
    def gradient_penalty(discriminator, batch_x, fake_image):
        batchsz = batch_x.shape[0]
    
        # batch_x is [b, h, w] here (28x28 images with no channel axis),
        # so sample one interpolation weight t per image
        t = tf.random.uniform([batchsz, 1, 1])
        # [b, 1, 1] => [b, h, w]
        t = tf.broadcast_to(t, batch_x.shape)

        # a random point on the segment between each real and fake image
        interpolate = t * batch_x + (1 - t) * fake_image
    
        with tf.GradientTape() as tape:
            tape.watch([interpolate])
            d_interpolate_logits = discriminator(interpolate, training=True)
        grads = tape.gradient(d_interpolate_logits, interpolate)
    
        # grads: [b, h, w] => [b, -1]
        grads = tf.reshape(grads, [grads.shape[0], -1])
        gp = tf.norm(grads, axis=1)  # [b]
        gp = tf.reduce_mean((gp - 1) ** 2)
    
        return gp
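
    # Assumed NHWC variant (not used in this script): with inputs shaped
    # [b, h, w, c], the usual formulation samples t as [b, 1, 1, 1] instead:
    #   t = tf.random.uniform([batchsz, 1, 1, 1])
    #   t = tf.broadcast_to(t, batch_x.shape)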
    
    
    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
    
    # inspect the MNIST data: shapes and dtypes
    print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
    
    # 4 * 4 * 7 => 28 * 28 * 1
    # (60000, 28, 28) - (10000, 28, 28)
    generator = Sequential([
        layers.Dense(4 * 4 * 7, activation=tf.nn.leaky_relu),
        layers.Reshape(target_shape=(4, 4, 7)),
        layers.Conv2DTranspose(14, 5, 2, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2DTranspose(5, 3, 1, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2DTranspose(1, 4, 2, activation=tf.nn.tanh),
        layers.Reshape(target_shape=(28, 28)),
    ])
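    # Shape trace through the generator (valid padding;
    # Conv2DTranspose output size = (in - 1) * stride + kernel):
    # (4, 4, 7) -> (11, 11, 14) -> (13, 13, 5) -> (28, 28, 1) -> (28, 28)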
    
    discriminator = Sequential([
        layers.Reshape((28, 28, 1)),
        layers.Conv2D(3, 4, 2, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2D(12, 3, 1, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2D(28, 5, 2, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Flatten(),
        layers.Dense(1)
    ])
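    # Shape trace through the discriminator (valid padding;
    # Conv2D output size = (in - kernel) // stride + 1):
    # (28, 28, 1) -> (13, 13, 3) -> (11, 11, 12) -> (4, 4, 28) -> 448 -> 1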
    
    # hyperparameters
    dim_h = 100
    epochs = int(9e+7)  # effectively: run until interrupted
    batch_size = 128
    learning_rate = 2e-3
    
    
    def preprocess(pre_x, pre_y):
        # scale images to [-1, 1] so real samples match the generator's tanh range
        pre_x = tf.cast(pre_x, dtype=tf.float32) / 127.5 - 1.
        pre_y = tf.cast(pre_y, dtype=tf.int32)
        return pre_x, pre_y
    
    
    db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
        .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
    
    db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
        .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
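    
    # Note: drop_remainder=True keeps every batch at exactly batch_size, which the
    # fixed-size noise tensor in the training step relies on (a smaller final
    # batch would break the real/fake pairing inside gradient_penalty).
    # db_test is built for symmetry but is not used below.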
    
    generator.build((None, dim_h))
    generator.summary()
    
    discriminator.build((None, 28, 28, 1))
    discriminator.summary()
    
    # sanity check: do the generator and discriminator shapes line up?
    print(generator(tf.random.normal((1, dim_h))).shape)
    print(discriminator(tf.random.normal((1, 28, 28, 1))).shape)
    
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    cross_entropy = losses.BinaryCrossentropy(from_logits=True)
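    
    # Note: the losses below keep the standard GAN cross-entropy terms and add the
    # gradient penalty on top. The formulation from the WGAN-GP paper would use
    # the raw critic scores instead, sketched here with the paper's penalty weight:
    #   d_loss = tf.reduce_mean(false_y) - tf.reduce_mean(true_y) + 10. * gp
    #   g_loss = -tf.reduce_mean(false_y)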
    
    for epoch in range(epochs):
        for step, (true_x, y) in enumerate(db_train):
            with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
                noise = tf.random.normal((batch_size, dim_h))
                # training=True so the BatchNormalization layers run in training mode
                false_x = generator(noise, training=True)
                false_y = discriminator(false_x, training=True)
                true_y = discriminator(true_x, training=True)
                false_loss = cross_entropy(tf.zeros_like(false_y), false_y)
                true_loss = cross_entropy(tf.ones_like(true_y), true_y)
    
                gp = gradient_penalty(discriminator, true_x, false_x)
    
                d_loss = false_loss + true_loss + 1. * gp  # gradient-penalty weight of 1.0
                g_loss = cross_entropy(tf.ones_like(false_y), false_y)
            d_grad = d_tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(d_grad, discriminator.trainable_variables))
            g_grad = g_tape.gradient(g_loss, generator.trainable_variables)
            g_optimizer.apply_gradients(zip(g_grad, generator.trainable_variables))
    
        print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
        # save a 10x10 grid of generated digits after each epoch
        z = tf.random.normal([100, dim_h])
        fake_image = generator(z, training=False)
        if not os.path.exists('mnist-images'):
            os.mkdir('mnist-images')
        img_path = os.path.join('mnist-images', 'wgan-%d.png' % epoch)
        fake_image = tf.expand_dims(fake_image, axis=3)
        save_result(fake_image.numpy(), 10, img_path, color_mode='P')
    

    The results:


    [Sample grids from epochs 119 to 124: wgan-119.png, wgan-120.png, wgan-121.png, wgan-122.png, wgan-123.png, wgan-124.png]

    The results are fairly good, and the generated digits show decent variety.
