2020机器学习GAN(3)

作者: zidea | 来源:发表于2020-02-26 21:25 被阅读0次

代码讲解

naruto_vs_sasuke.jpg

今天 tensorflow2.0 实现一个简单全连接实现 GAN。我们先回忆在做机器学习一般流程,然后我们根据流程步骤按填空。

  1. 加载数据集
  2. 定义模型
  3. 定义损失函数
  4. 定义优化函数
  5. 训练模型
  6. 预测
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

import numpy as np
import glob
import os

%matplotlib inline

这里使用

tf.__version__
'2.0.0'

加载数据集

考虑我的小笔记本性能,我只能在比较简单数据集上进行训练模型,这里使用即使 mnist 手写数字数据,这么经典我就不多说了。

(train_data,train_labels),(_,_) = tf.keras.datasets.mnist.load_data()
print(train_data.shape)
print(train_labels.shape)
(60000, 28, 28)
(60000,)
plt.imshow(train_data[0])
<matplotlib.image.AxesImage at 0x14322cd50>
output_7_1.png

数据处理

  • 改变数据类型
  • 数据进行归一化
    这里我们需要把图片进行变形,数据形状从(60000,28,28)变为(60000,28,28,1),数据类型从 int 变为 float32 便于计算机进行计算。
# 因为计算机对于浮点数据进行计算比较舒服
train_data = train_data.reshape(train_data.shape[0],28,28,1).astype('float32')
# 通过对数据进行缩放来实现归一化
train_data = (train_data - 127.5)/127.5
# plt.imshow(train_data[0])

批次也就是每一个训练迭代输入 256 张图片来完成一次训练

BATCH_SIZE = 256
BUFFER_SIZE = 60000

dataset 是 tensorflow 新的数据 api,很好用,也是tensorflow 2.0 让人眼前一亮新功能

dataset = tf.data.Dataset.from_tensor_slices(train_data)
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
dataset
<BatchDataset shapes: (None, 28, 28, 1), types: tf.float32>

定义生成器(generator)

naruto_generator.jpeg

生成器就是要骗过判别器,生成图片让判别器无法识别出是生成器伪造的,而误认为是真实图片。我们这里用了两层全连接将 100 维度向量转换 784 维,然后通过 Reshape 转换图片格式(28,28,1)

def build_generator():
    model = tf.keras.Sequential()
    # 随机向量使用 100 维度向量     
    model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))
    model.add(layers.BatchNormalization())
    
    model.add(layers.Reshape((28,28,1)))
    
    return model
    
    

定义判别模型

<img src="images/sasuke_discriminator.png" width="50%"/>


sakuke_discriminator.png

判别器模型主要目的就是从真实图片辨别出生成器图片,也可以看为教练,不断督促判别器做的更好,开始我们用 Flatten 将图片展平,最终输出 1 表示 0 到 1。

def build_discriminator():
    model = tf.keras.Sequential()
    # 首先我们将图片进行 flatten
    model.add(layers.Flatten())
    
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(256,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # 规范到 0 - 1 ,小于 0.5 我们,     
    model.add(layers.Dense(1))
    
    return model
    

超参数

EPOCHS = 100
noise_dim = 100
learning_rate = 1e-4
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate,noise_dim])
# 生成图片是否为真实图片
# 因为没有激活所以将 from_logits True
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

定损失函数

在判别器损失函数,当输入真实数据集图片时,我们给出 1 ,而输入为生成器的生成图时给出 0

def discriminator_loss(real_out,fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    
    return real_loss + fake_loss
在生成器损失函数
def generator_loss(fake_out):
    return cross_entropy(tf.ones_like(fake_out),fake_out)

优化器

evaluation.jpeg

优化器对于判别器和生成器都使用 Adam 。

generator_optimizer = tf.keras.optimizers.Adam(learning_rate)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate)
discriminator = build_discriminator()
generator = build_generator()

相关文章

  • 2020机器学习GAN(3)

    代码讲解 今天 tensorflow2.0 实现一个简单全连接实现 GAN。我们先回忆在做机器学习一般流程,然后我...

  • 2020机器学习GAN(5)

    参考李宏毅老师分享 先解释一下根据样本进行生成图片,然后把图片每一个像素都看成其一个特征,那么图片就是存在一个高维...

  • 2020机器学习GAN(8)

    为什么不用 Discriminator 生成图片 其实 Discriminator 会从整体来考量图片好坏假设我们...

  • 2020机器学习GAN(1)

    在 2018 年,深度学习掀起 GAN 的潮流,在许多领域都已经有了应用,而且基于 GAN 项目五花八门。早在**...

  • 2020机器学习GAN(F)

    训练 好现在定义好判别器模型和生成模型,而且定义好优化器这里梯度下降使用 Adam 给以较小学习率。在开始训练前我...

  • 2020机器学习GAN(6)

    我们就是找到一个 G ,这个 G 可以让和之间的 JSDiv 距离尽量小。 回忆一下我们在 GAN 中任务,就是找...

  • 2020机器学习GAN(2)

    今天目标是介绍一下 GAN 是如何做到输出图片的这个样任务。在 GAN 中分别有两个阶段,在生成阶段就是固定生成器...

  • 2020机器学习GAN(CGAN) (9)

    CGAN(Conditonal Generation GAN) 什么是 CGAN 呢?之前我们 GAN 生成图片是...

  • 2020机器学习GAN(WGAN) (9)

    我们之前都用 JS 散度来衡量两笔数据之间差异,今天我们就来先讨论一下 JS 散度作为衡量工具是否合适的问题。我们...

  • 深入浅出对抗生成网络(GAN)

    GAN是机器学习的大牛Ian Goodfellow于2014年提出,是一种无监督学习的实现方案。GAN一个比较接地...

网友评论

    本文标题:2020机器学习GAN(3)

    本文链接:https://www.haomeiwen.com/subject/ypxgchtx.html