PyTorch GAN: Generating MNIST Digits

Author: dawsonenjoy | Published 2019-11-13 15:20

Approach

The core idea of a GAN is to pit a discriminator against a generator. The discriminator must tell real data from fake: feed it a real sample and we want its output probability to be close to 1; feed it a fake sample and we want the output close to 0. The generator takes random noise as input and produces fake samples, and its goal is to make the discriminator assign those samples a probability close to 1. The two networks play an adversarial game, so both are trained at the same time; training counts as successful when the discriminator has become a strong judge and the generator can still fool it. In practice the generator is noticeably harder to train than the discriminator, so the ratio of discriminator updates to generator updates has to be chosen carefully.
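
Written as loss terms (a standard GAN formulation summarizing the paragraph above, not a quote from the original post), with D the discriminator, G the generator, x a real sample and z random noise:

loss_D = BCE(D(x), 1) + BCE(D(G(z)), 0)
loss_G = BCE(D(G(z)), 1)

where BCE(p, y) = -[y*log(p) + (1-y)*log(1-p)] is the binary cross-entropy. These are exactly the two losses implemented in the training loop below.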

Importing modules

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

Initial setup

First configure the device (GPU if available), then set the size of the noise vector passed to the generator and the training batch size:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

input_size = 100
# size of the random noise vector fed to the generator
batch_size = 200

Data preprocessing

The data comes straight from the MNIST dataset:

dataset = datasets.MNIST('data/',download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float()
data = data / (255/2) - 1
# scale pixel values to the range [-1, 1]
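
As a quick sanity check (my own snippet, not part of the original post), the transform data / (255/2) - 1 maps pixel values 0..255 to [-1, 1], and (x + 1) * 255/2 inverts it, which is what the plotting code at the end relies on:

print(data.min().item(), data.max().item())
# expected output: -1.0 1.0
restored = (data + 1) * (255/2)
# restored is back in the original 0..255 range, useful for displaying images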

Defining the networks

We define the discriminator and the generator separately, as follows.

Discriminator model

The discriminator takes images of shape (batch, 1, 28, 28), passes them through a stack of convolutions, and finally returns, via a fully connected layer, the probability that each image is real. The model code is as follows:

class DNet(nn.Module):
    # Discriminator: takes an image and returns the probability that it is real (close to 1 for real images, close to 0 otherwise)
    # input:(batch_size, 1, 28, 28)
    # output:(batch_size, 1)
    def __init__(self):
        super(DNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, 1, padding=1)
        self.batch_norm1 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.leakyrelu = nn.LeakyReLU()
        self.linear = nn.Linear(8192, 1)

    def forward(self, x):
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.batch_norm1(self.conv2(x)))
        x = self.leakyrelu(self.batch_norm2(self.conv3(x)))
        x = self.leakyrelu(self.batch_norm3(self.conv4(x)))
        x = torch.flatten(x).reshape(-1, 8192)
        x = torch.sigmoid(self.linear(x))
        return x
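
The 8192 in the final nn.Linear comes from the feature-map sizes: the three stride-2 convolutions shrink the 28×28 input to 14×14, 7×7 and then 4×4, and the last stride-1 convolution keeps it at 4×4 with 512 channels, so the flattened feature vector has 512*4*4 = 8192 entries. A quick shape check (my own snippet, assuming the DNet class above):

dummy = torch.randn(2, 1, 28, 28)
# a dummy batch of two 28x28 grayscale images
print(DNet()(dummy).shape)
# expected output: torch.Size([2, 1])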

Generator model

The generator takes a batch of random noise vectors and, through transposed convolutions, produces images of shape (1, 28, 28). The model code is as follows:

class GNet(nn.Module):
    # Generator: takes random noise and produces images that look as real as possible
    # input:(batch_size, noise_size)
    # output:(batch_size, 1, 28, 28)
    def __init__(self, input_size):
        super(GNet, self).__init__()
        self.d = 3
        self.linear = nn.Linear(input_size, self.d*self.d*512)
        self.conv_tranpose1 = nn.ConvTranspose2d(512, 256, 5, 2, 1)
        self.conv_tranpose2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.conv_tranpose3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.conv_tranpose4 = nn.ConvTranspose2d(64, 1, 3, 1, 1)
        self.batch_norm1 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm4 = torch.nn.BatchNorm2d(64)#, momentum=0.9)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.linear(x).reshape(-1, 512, self.d, self.d)
        x = self.relu(self.batch_norm1(x))
        x = self.conv_tranpose1(x)
        x = self.relu(self.batch_norm2(x))
        x = self.conv_tranpose2(x)
        x = self.relu(self.batch_norm3(x))
        x = self.conv_tranpose3(x)
        x = self.relu(self.batch_norm4(x))
        x = self.tanh(self.conv_tranpose4(x))
        return x 
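
Symmetrically, each transposed convolution grows the feature map according to H_out = (H_in - 1)*stride - 2*padding + kernel_size, so the 3×3 map produced by the linear layer expands as 3 → 7 → 14 → 28 → 28; this is why the first layer uses a 5×5 kernel while the next two use 4×4. A quick shape check (my own snippet, assuming the GNet class above):

noise = torch.randn(2, 100)
# a dummy batch of two noise vectors of length input_size = 100
print(GNet(100)(noise).shape)
# expected output: torch.Size([2, 1, 28, 28])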

Loss function and optimizers

Since the discriminator outputs a probability between 0 and 1, binary cross-entropy is used as the loss. The two models are instantiated first, and each gets its own Adam optimizer:

dmodel = DNet().to(device)
gmodel = GNet(input_size).to(device)

loss_fun = nn.BCELoss()
# BCE loss; with MSE the training struggles to converge and is also slower
goptim = torch.optim.Adam(gmodel.parameters(), lr=0.0001)
doptim = torch.optim.Adam(dmodel.parameters(), lr=0.0001)
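
For reference (a toy example of mine, not from the post), nn.BCELoss expects probabilities in (0, 1) together with 0/1 targets, which is why the discriminator ends with a sigmoid:

p = torch.tensor([[0.9], [0.2]])
# pretend discriminator outputs (probabilities)
y = torch.tensor([[1.0], [0.0]])
# labels: real = 1, fake = 0
print(nn.BCELoss()(p, y).item())
# roughly 0.16, small because the predictions agree with the labels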

Training the models

During training, the discriminator loss is the BCE of the real images' scores against label 1 plus the BCE of the generated images' scores against label 0 (the closer both terms are to 0, the stronger the discriminator). The generator loss is the BCE of the discriminator's score on the generated images against label 1 (the closer that score is to 1, the more convincing the fakes). Since the generator is harder to train, it is updated three times for every discriminator update:

dmodel.train()
gmodel.train()
li_gloss = []
li_dloss = []

d_true = torch.ones(batch_size, 1).to(device)
d_fake = torch.zeros(batch_size, 1).to(device)
for epoch in range(30):
    for batch in range(0, 60000, batch_size):
        batch_data = data[batch:batch+batch_size].to(device)
        # real images taken from the MNIST set
        fake_data = torch.randn(batch_size, input_size).to(device)
        # random noise that the generator will turn into fake images
        output_dtrue = dmodel(batch_data)
        # first, score the real images with the discriminator
        loss_dtrue = loss_fun(output_dtrue, d_true)
        # scores on real images should be close to 1
        output_dfake = dmodel(gmodel(fake_data))
        # then score the images generated from the noise
        loss_dfake = loss_fun(output_dfake, d_fake)
        # for the discriminator, scores on generated images should be close to 0
        loss_d = loss_dtrue + loss_dfake
        # both terms should be as small as possible
        doptim.zero_grad()
        loss_d.backward()
        doptim.step()
        li_dloss.append(loss_d.item())
        # with the discriminator updated, now train the generator
        for i in range(3):
            # the generator is harder to train, so it gets three updates per discriminator update
            # fake_data = torch.randn(batch_size, input_size).to(device)
            output_gtrue = dmodel(gmodel(fake_data))
            # score the generated images again
            loss_g = loss_fun(output_gtrue, d_true)
            # for the generator, these scores should be close to 1, i.e. the fakes should look real
            doptim.zero_grad()
            goptim.zero_grad()
            loss_g.backward()
            goptim.step()
            li_gloss.append(loss_g.item())
        print("epoch:{}, batch:{}, loss_d:{}, loss_g:{}".format(epoch, batch, loss_d, loss_g))
        torch.save(dmodel.state_dict(), "gan_dmodel.mdl")
        torch.save(gmodel.state_dict(), "gan_gmodel.mdl")
        if (batch // batch_size) % 30 == 0 and batch != 0:
            plt.plot(li_dloss)
            plt.show()
            plt.plot(li_gloss)
            plt.show()
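
One detail worth noting: in the loop above, the discriminator's fake-image loss also backpropagates through the generator, because gmodel(fake_data) is not detached. A common refinement (my own suggestion, not part of the original code) is to detach the generated batch during the discriminator update, so that backward pass skips the generator entirely:

output_dfake = dmodel(gmodel(fake_data).detach())
# .detach() stops gradients from flowing back into the generator during the discriminator step
loss_dfake = loss_fun(output_dfake, d_fake)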

After 20 or 30 epochs the results already start to look like digits; with enough time, 50 or more epochs is recommended. Below are the loss curves after 30 epochs of training:


Discriminator loss over training
Generator loss over training

Loading a saved model

To run the trained model locally, load the saved weights:

gmodel.load_state_dict(torch.load("gan_gmodel.mdl"))
dmodel.load_state_dict(torch.load("gan_dmodel.mdl"))
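
If the checkpoints were saved on a GPU machine and are loaded somewhere without a GPU, pass map_location to torch.load (standard PyTorch API) so the weights are mapped to the right device:

gmodel.load_state_dict(torch.load("gan_gmodel.mdl", map_location=device))
dmodel.load_state_dict(torch.load("gan_dmodel.mdl", map_location=device))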

Testing the model

Generate a batch of noise and see what the generator produces:

gmodel.eval()
data_test = torch.randn(100, 100)
result = gmodel(data_test.to(device))
plt.figure(figsize=(10, 50))
for i in range(len(result)):
    ax = plt.subplot(len(result) // 5, 5, i+1)
    plt.imshow((result[i].cpu().data.reshape(28, 28)+1)*255/2)
    # plt.gray()
plt.show()

After 30 epochs of training, the generator produces images like these:


Generated digits after 30 epochs

After 50 epochs the results look like this:

Generated digits after 50 epochs

Full code

# -----------------------------
# import modules
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# GPU setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

# -----------------------------
# basic parameters
input_size = 100
# size of the random noise vector fed to the generator
batch_size = 200

# -----------------------------
# data preprocessing
dataset = datasets.MNIST('data/',download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float()
data = data / (255/2) - 1
# scale pixel values to the range [-1, 1]

# -----------------------------
# network definitions
class DNet(nn.Module):
    # Discriminator: takes an image and returns the probability that it is real (close to 1 for real images, close to 0 otherwise)
    # input:(batch_size, 1, 28, 28)
    # output:(batch_size, 1)
    def __init__(self):
        super(DNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, 1, padding=1)
        self.batch_norm1 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.leakyrelu = nn.LeakyReLU()
        self.linear = nn.Linear(8192, 1)

    def forward(self, x):
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.batch_norm1(self.conv2(x)))
        x = self.leakyrelu(self.batch_norm2(self.conv3(x)))
        x = self.leakyrelu(self.batch_norm3(self.conv4(x)))
        x = torch.flatten(x).reshape(-1, 8192)
        x = torch.sigmoid(self.linear(x))
        return x

class GNet(nn.Module):
    # Generator: takes random noise and produces images that look as real as possible
    # input:(batch_size, noise_size)
    # output:(batch_size, 1, 28, 28)
    def __init__(self, input_size):
        super(GNet, self).__init__()
        self.d = 3
        self.linear = nn.Linear(input_size, self.d*self.d*512)
        self.conv_tranpose1 = nn.ConvTranspose2d(512, 256, 5, 2, 1)
        self.conv_tranpose2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.conv_tranpose3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.conv_tranpose4 = nn.ConvTranspose2d(64, 1, 3, 1, 1)
        self.batch_norm1 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm4 = torch.nn.BatchNorm2d(64)#, momentum=0.9)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.linear(x).reshape(-1, 512, self.d, self.d)
        x = self.relu(self.batch_norm1(x))
        x = self.conv_tranpose1(x)
        x = self.relu(self.batch_norm2(x))
        x = self.conv_tranpose2(x)
        x = self.relu(self.batch_norm3(x))
        x = self.conv_tranpose3(x)
        x = self.relu(self.batch_norm4(x))
        x = self.tanh(self.conv_tranpose4(x))
        return x 

dmodel = DNet().to(device)
gmodel = GNet(input_size).to(device)

# -----------------------------
# loss function and optimizers
loss_fun = nn.BCELoss()
# BCE loss; with MSE the training struggles to converge and is also slower
goptim = torch.optim.Adam(gmodel.parameters(), lr=0.0001)
doptim = torch.optim.Adam(dmodel.parameters(), lr=0.0001)

# -----------------------------
# training
dmodel.train()
gmodel.train()
li_gloss = []
li_dloss = []

d_true = torch.ones(batch_size, 1).to(device)
d_fake = torch.zeros(batch_size, 1).to(device)
for epoch in range(50):
    for batch in range(0, 60000, batch_size):
        batch_data = data[batch:batch+batch_size].to(device)
        # real images taken from the MNIST set
        fake_data = torch.randn(batch_size, input_size).to(device)
        # random noise that the generator will turn into fake images
        output_dtrue = dmodel(batch_data)
        # first, score the real images with the discriminator
        loss_dtrue = loss_fun(output_dtrue, d_true)
        # scores on real images should be close to 1
        output_dfake = dmodel(gmodel(fake_data))
        # then score the images generated from the noise
        loss_dfake = loss_fun(output_dfake, d_fake)
        # for the discriminator, scores on generated images should be close to 0
        loss_d = loss_dtrue + loss_dfake
        # both terms should be as small as possible
        doptim.zero_grad()
        loss_d.backward()
        doptim.step()
        li_dloss.append(loss_d.item())
        # with the discriminator updated, now train the generator
        for i in range(3):
            # the generator is harder to train, so it gets three updates per discriminator update
            # fake_data = torch.randn(batch_size, input_size).to(device)
            output_gtrue = dmodel(gmodel(fake_data))
            # score the generated images again
            loss_g = loss_fun(output_gtrue, d_true)
            # for the generator, these scores should be close to 1, i.e. the fakes should look real
            doptim.zero_grad()
            goptim.zero_grad()
            loss_g.backward()
            goptim.step()
            li_gloss.append(loss_g.item())
        print("epoch:{}, batch:{}, loss_d:{}, loss_g:{}".format(epoch, batch, loss_d, loss_g))
        torch.save(dmodel.state_dict(), "gan_dmodel.mdl")
        torch.save(gmodel.state_dict(), "gan_gmodel.mdl")
        if (batch // batch_size) % 30 == 0 and batch != 0:
            plt.plot(li_dloss)
            plt.show()
            plt.plot(li_gloss)
            plt.show()

# -----------------------------
# load saved models (optional)
# gmodel.load_state_dict(torch.load("gan_gmodel.mdl"))
# dmodel.load_state_dict(torch.load("gan_dmodel.mdl"))

# -----------------------------
# testing
gmodel.eval()
data_test = torch.randn(100, 100)
result = gmodel(data_test.to(device))
plt.figure(figsize=(10, 50))
for i in range(len(result)):
    ax = plt.subplot(len(result) // 5, 5, i+1)
    plt.imshow((result[i].cpu().data.reshape(28, 28)+1)*255/2)
    # plt.gray()
plt.show()
