2020机器学习GAN(3)

作者: zidea | 来源:发表于2020-02-26 21:25 被阅读0次

    代码讲解

    naruto_vs_sasuke.jpg

    今天 tensorflow2.0 实现一个简单全连接实现 GAN。我们先回忆在做机器学习一般流程,然后我们根据流程步骤按填空。

    1. 加载数据集
    2. 定义模型
    3. 定义损失函数
    4. 定义优化函数
    5. 训练模型
    6. 预测
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    import matplotlib.pyplot as plt
    
    import numpy as np
    import glob
    import os
    
    %matplotlib inline
    

    这里使用

    tf.__version__
    
    '2.0.0'
    

    加载数据集

    考虑我的小笔记本性能,我只能在比较简单数据集上进行训练模型,这里使用即使 mnist 手写数字数据,这么经典我就不多说了。

    (train_data,train_labels),(_,_) = tf.keras.datasets.mnist.load_data()
    
    print(train_data.shape)
    print(train_labels.shape)
    
    (60000, 28, 28)
    (60000,)
    
    plt.imshow(train_data[0])
    
    <matplotlib.image.AxesImage at 0x14322cd50>
    
    output_7_1.png

    数据处理

    • 改变数据类型
    • 数据进行归一化
      这里我们需要把图片进行变形,数据形状从(60000,28,28)变为(60000,28,28,1),数据类型从 int 变为 float32 便于计算机进行计算。
    # 因为计算机对于浮点数据进行计算比较舒服
    train_data = train_data.reshape(train_data.shape[0],28,28,1).astype('float32')
    
    # 通过对数据进行缩放来实现归一化
    train_data = (train_data - 127.5)/127.5
    
    # plt.imshow(train_data[0])
    

    批次也就是每一个训练迭代输入 256 张图片来完成一次训练

    BATCH_SIZE = 256
    BUFFER_SIZE = 60000
    

    dataset 是 tensorflow 新的数据 api,很好用,也是tensorflow 2.0 让人眼前一亮新功能

    dataset = tf.data.Dataset.from_tensor_slices(train_data)
    
    dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    
    dataset
    
    <BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>
    

    定义生成器(generator)

    naruto_generator.jpeg

    生成器就是要骗过判别器,生成图片让判别器无法识别出是生成器伪造的,而误认为是真实图片。我们这里用了两层全连接将 100 维度向量转换 784 维,然后通过 Reshape 转换图片格式(28,28,1)

    def build_generator():
        model = tf.keras.Sequential()
        # 随机向量使用 100 维度向量     
        model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        
        model.add(layers.Dense(512,use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        
        model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))
        model.add(layers.BatchNormalization())
        
        model.add(layers.Reshape((28,28,1)))
        
        return model
        
        
    

    定义判别模型

    <img src="images/sasuke_discriminator.png" width="50%"/>


    sakuke_discriminator.png

    判别器模型主要目的就是从真实图片辨别出生成器图片,也可以看为教练,不断督促判别器做的更好,开始我们用 Flatten 将图片展平,最终输出 1 表示 0 到 1。

    def build_discriminator():
        model = tf.keras.Sequential()
        # 首先我们将图片进行 flatten
        model.add(layers.Flatten())
        
        model.add(layers.Dense(512,use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        
        model.add(layers.Dense(256,use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        # 规范到 0 - 1 ,小于 0.5 我们,     
        model.add(layers.Dense(1))
        
        return model
        
    

    超参数

    EPOCHS = 100
    noise_dim = 100
    learning_rate = 1e-4
    num_exp_to_generate = 16
    seed = tf.random.normal([num_exp_to_generate,noise_dim])
    
    # 生成图片是否为真实图片
    # 因为没有激活所以将 from_logits True
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    

    定损失函数

    在判别器损失函数,当输入真实数据集图片时,我们给出 1 ,而输入为生成器的生成图时给出 0

    def discriminator_loss(real_out,fake_out):
        real_loss = cross_entropy(tf.ones_like(real_out),real_out)
        fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
        
        return real_loss + fake_loss
    
    在生成器损失函数
    
    def generator_loss(fake_out):
        return cross_entropy(tf.ones_like(fake_out),fake_out)
    

    优化器

    evaluation.jpeg

    优化器对于判别器和生成器都使用 Adam 。

    generator_optimizer = tf.keras.optimizers.Adam(learning_rate)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate)
    
    discriminator = build_discriminator()
    
    generator = build_generator()
    

    相关文章

      网友评论

        本文标题:2020机器学习GAN(3)

        本文链接:https://www.haomeiwen.com/subject/ypxgchtx.html