Idea
The core idea of a GAN is to pair a discriminator with a generator. The discriminator must tell real data from fake data: a real sample fed into it should return a probability as close to 1 as possible, and a fake sample a probability as close to 0 as possible. The generator takes random noise as input, produces a batch of fake data, and tries to make the discriminator return a probability close to 1 for that fake data. The two models play an adversarial game, so they have to be trained together; training counts as reasonably successful when the discriminator has become strong and the generator can still fool it. In practice the generator is noticeably harder to train than the discriminator, so the ratio between discriminator updates and generator updates has to be chosen carefully.
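In the standard formulation (Goodfellow et al., 2014) this adversarial game corresponds to the minimax objective below, where D is the discriminator, G the generator, x a real sample and z the input noise:

$$\min_G \max_D \; \mathbb{E}_{x}\big[\log D(x)\big] + \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big]$$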
Importing modules
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt
Initial definitions
First configure the GPU, then set the size of the noise vector that will be fed into the generator and the training batch size:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()
input_size = 100
# size of the random noise vector fed to the generator
batch_size = 200
Data preprocessing
The data is simply the MNIST dataset:
dataset = datasets.MNIST('data/',download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float()
data = data / (255/2) - 1
# scale the data to the range [-1, 1]
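As a quick sanity check (assuming the preprocessing above has just been run), the tensor should now span the same range as the generator's Tanh output:

print(data.shape, data.min().item(), data.max().item())
# torch.Size([60000, 1, 28, 28]) -1.0 1.0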
Defining the network models
The discriminator and the generator are defined separately, as follows.
Discriminator
The discriminator takes an image of shape (batch, 1, 28, 28), passes it through several convolutions, and finally returns, through a fully connected layer, the probability that the image is real. The model code is as follows:
class DNet(nn.Module):
    # Discriminator: takes an image and returns the probability that it is real,
    # which should be close to 1 for real images and close to 0 for fake ones.
    # input:  (batch_size, 1, 28, 28)
    # output: (batch_size, 1)
    def __init__(self):
        super(DNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, 1, padding=1)
        self.batch_norm1 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.leakyrelu = nn.LeakyReLU()
        self.linear = nn.Linear(8192, 1)

    def forward(self, x):
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.batch_norm1(self.conv2(x)))
        x = self.leakyrelu(self.batch_norm2(self.conv3(x)))
        x = self.leakyrelu(self.batch_norm3(self.conv4(x)))
        x = torch.flatten(x, 1)  # (batch_size, 512*4*4) = (batch_size, 8192)
        x = torch.sigmoid(self.linear(x))
        return x
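As a quick shape check (a sketch, assuming only the class above; the model used for training is instantiated later together with the optimizers), a random batch should come out as one probability per image:

d = DNet()
dummy = torch.randn(4, 1, 28, 28)
print(d(dummy).shape)   # torch.Size([4, 1]), values in (0, 1)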
Generator
The generator takes a batch of random noise vectors and, through transposed convolutions, produces images of shape (1, 28, 28). The model code is as follows:
class GNet(nn.Module):
    # Generator: takes random noise and produces images that are as realistic as possible.
    # input:  (batch_size, noise_size)
    # output: (batch_size, 1, 28, 28)
    def __init__(self, input_size):
        super(GNet, self).__init__()
        self.d = 3
        self.linear = nn.Linear(input_size, self.d*self.d*512)
        self.conv_tranpose1 = nn.ConvTranspose2d(512, 256, 5, 2, 1)
        self.conv_tranpose2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.conv_tranpose3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.conv_tranpose4 = nn.ConvTranspose2d(64, 1, 3, 1, 1)
        self.batch_norm1 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm4 = torch.nn.BatchNorm2d(64)#, momentum=0.9)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.linear(x).reshape(-1, 512, self.d, self.d)   # (batch_size, 512, 3, 3)
        x = self.relu(self.batch_norm1(x))
        x = self.conv_tranpose1(x)                             # -> (batch_size, 256, 7, 7)
        x = self.relu(self.batch_norm2(x))
        x = self.conv_tranpose2(x)                             # -> (batch_size, 128, 14, 14)
        x = self.relu(self.batch_norm3(x))
        x = self.conv_tranpose3(x)                             # -> (batch_size, 64, 28, 28)
        x = self.relu(self.batch_norm4(x))
        x = self.tanh(self.conv_tranpose4(x))                  # -> (batch_size, 1, 28, 28)
        return x
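Each transposed convolution grows the feature map according to out = (in - 1) * stride - 2 * padding + kernel_size (with dilation 1 and no output padding), which is how the 3x3 projected noise reaches 28x28: 3 -> 7 -> 14 -> 28 -> 28. A quick check (a sketch, assuming the class and input_size defined above):

g = GNet(input_size)
print(g(torch.randn(4, input_size)).shape)   # torch.Size([4, 1, 28, 28]), values in (-1, 1)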
Defining the loss function and optimizers
Since the discriminator outputs a single probability between 0 and 1, binary cross-entropy is used as the loss. Both models are instantiated and moved to the device first, because the optimizers need their parameters:
dmodel = DNet().to(device)
gmodel = GNet(input_size).to(device)
loss_fun = nn.BCELoss()
# BCE is used for the loss; with MSE training converges poorly and is also slower
goptim = torch.optim.Adam(gmodel.parameters(), lr=0.0001)
doptim = torch.optim.Adam(dmodel.parameters(), lr=0.0001)
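To see why BCE fits here, note that for a target of 1 it reduces to -log(p), so a confidently wrong prediction is punished far more heavily than a slightly imperfect one. A small illustration (not part of the training code):

print(loss_fun(torch.tensor([0.9]), torch.tensor([1.0])))   # ~0.105
print(loss_fun(torch.tensor([0.1]), torch.tensor([1.0])))   # ~2.303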
Training the model
During training, the discriminator loss is the BCE between the probabilities it returns for real images and 1, plus the BCE between the probabilities it returns for fake images and 0 (the smaller both terms, the stronger the discriminator). The generator loss is the BCE between the probability the discriminator returns for generated images and 1 (the closer to 1, the more realistic the fakes). Because the generator is harder to train, it is updated three times for every discriminator update.
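Written in terms of the BCELoss defined above, with x a real batch and z the noise batch, the two objectives are:

$$L_D = \mathrm{BCE}\big(D(x),\,1\big) + \mathrm{BCE}\big(D(G(z)),\,0\big), \qquad L_G = \mathrm{BCE}\big(D(G(z)),\,1\big)$$

The training code is as follows: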
dmodel.train()
gmodel.train()
li_gloss = []
li_dloss = []
d_true = torch.ones(batch_size, 1).to(device)
d_fake = torch.zeros(batch_size, 1).to(device)
for epoch in range(30):
    for batch in range(0, 60000, batch_size):
        batch_data = data[batch:batch+batch_size].to(device)
        # real data taken from the MNIST set
        fake_data = torch.randn(batch_size, input_size).to(device)
        # random noise used to generate fake data
        output_dtrue = dmodel(batch_data)
        # first let the discriminator judge the real data
        loss_dtrue = loss_fun(output_dtrue, d_true)
        # for real data the output should be as close to 1 as possible
        output_dfake = dmodel(gmodel(fake_data))
        # then let the discriminator judge the images generated from the noise
        loss_dfake = loss_fun(output_dfake, d_fake)
        # for the discriminator, the output on generated images should be as close to 0 as possible
        loss_d = loss_dtrue + loss_dfake
        # both terms should be as small as possible
        doptim.zero_grad()
        loss_d.backward()
        doptim.step()
        li_dloss.append(loss_d.item())
        # once the discriminator has some discriminative power, train the generator
        for i in range(3):
            # the generator is harder to train, so it is updated three times per discriminator update
            # fake_data = torch.randn(batch_size, input_size).to(device)
            output_gtrue = dmodel(gmodel(fake_data))
            # judge the images generated from the noise
            loss_g = loss_fun(output_gtrue, d_true)
            # for the generator, the output on generated images should be as close to 1 as possible,
            # i.e. the fakes should look as close to real images as possible
            doptim.zero_grad()
            goptim.zero_grad()
            loss_g.backward()
            goptim.step()
            li_gloss.append(loss_g.item())
        print("epoch:{}, batch:{}, loss_d:{}, loss_g:{}".format(epoch, batch, loss_d, loss_g))
        torch.save(dmodel.state_dict(), "gan_dmodel.mdl")
        torch.save(gmodel.state_dict(), "gan_gmodel.mdl")
        if batch // batch_size % 30 == 0 and batch != 0:
            plt.plot(li_dloss)
            plt.show()
            plt.plot(li_gloss)
            plt.show()
After about 20 to 30 epochs the results already start to look reasonable; if time allows, 50 or more epochs are recommended. Here are the loss curves after 30 epochs of training:
![](https://img.haomeiwen.com/i13183513/04c6fe16b71b50e0.png)
![](https://img.haomeiwen.com/i13183513/1a652bc92330d3b1.png)
Loading the model
To run the trained model locally, load the saved weights like this:
gmodel.load_state_dict(torch.load("gan_gmodel.mdl"))
dmodel.load_state_dict(torch.load("gan_dmodel.mdl"))
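If the checkpoints were saved on a GPU machine and are loaded on a CPU-only machine, passing map_location avoids a device mismatch error (a small sketch using the same file names as above):

gmodel.load_state_dict(torch.load("gan_gmodel.mdl", map_location=device))
dmodel.load_state_dict(torch.load("gan_dmodel.mdl", map_location=device))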
Testing the model
Generate a batch of noise and see what the generator produces:
gmodel.eval()
data_test = torch.randn(100, input_size)
result = gmodel(data_test.to(device))
plt.figure(figsize=(10, 50))
for i in range(len(result)):
    ax = plt.subplot(len(result) // 5, 5, i+1)
    plt.imshow((result[i].cpu().data.reshape(28, 28)+1)*255/2)
    # plt.gray()
plt.show()
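Alternatively, the whole batch can be written to a single grid image with torchvision's utilities (a sketch; torchvision.utils is not used in the original code, and out.png is an arbitrary file name):

from torchvision.utils import save_image
save_image((result.detach() + 1) / 2, "out.png", nrow=10)   # map from [-1, 1] back to [0, 1] and save a 10x10 grid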
After 30 epochs of training, the generator produces images like this:
![](https://img.haomeiwen.com/i13183513/a435c1b497ca56b8.png)
After 50 epochs:
![](https://img.haomeiwen.com/i13183513/162f7a4c23020c4a.png)
Complete code
# -----------------------------
# import modules
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt
# -----------------------------
# set up the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()
# -----------------------------
# basic parameters
input_size = 100
# size of the random noise vector fed to the generator
batch_size = 200
# -----------------------------
# data preprocessing
dataset = datasets.MNIST('data/',download=True)
data = dataset.data.reshape(-1, 1, 28, 28).float()
data = data / (255/2) - 1
# scale the data to the range [-1, 1]
# -----------------------------
# define the networks
class DNet(nn.Module):
    # Discriminator: takes an image and returns the probability that it is real,
    # which should be close to 1 for real images and close to 0 for fake ones.
    # input:  (batch_size, 1, 28, 28)
    # output: (batch_size, 1)
    def __init__(self):
        super(DNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 3, 1, padding=1)
        self.batch_norm1 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.leakyrelu = nn.LeakyReLU()
        self.linear = nn.Linear(8192, 1)

    def forward(self, x):
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.batch_norm1(self.conv2(x)))
        x = self.leakyrelu(self.batch_norm2(self.conv3(x)))
        x = self.leakyrelu(self.batch_norm3(self.conv4(x)))
        x = torch.flatten(x, 1)  # (batch_size, 8192)
        x = torch.sigmoid(self.linear(x))
        return x

class GNet(nn.Module):
    # Generator: takes random noise and produces images that are as realistic as possible.
    # input:  (batch_size, noise_size)
    # output: (batch_size, 1, 28, 28)
    def __init__(self, input_size):
        super(GNet, self).__init__()
        self.d = 3
        self.linear = nn.Linear(input_size, self.d*self.d*512)
        self.conv_tranpose1 = nn.ConvTranspose2d(512, 256, 5, 2, 1)
        self.conv_tranpose2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.conv_tranpose3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.conv_tranpose4 = nn.ConvTranspose2d(64, 1, 3, 1, 1)
        self.batch_norm1 = torch.nn.BatchNorm2d(512)#, momentum=0.9)
        self.batch_norm2 = torch.nn.BatchNorm2d(256)#, momentum=0.9)
        self.batch_norm3 = torch.nn.BatchNorm2d(128)#, momentum=0.9)
        self.batch_norm4 = torch.nn.BatchNorm2d(64)#, momentum=0.9)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.linear(x).reshape(-1, 512, self.d, self.d)
        x = self.relu(self.batch_norm1(x))
        x = self.conv_tranpose1(x)
        x = self.relu(self.batch_norm2(x))
        x = self.conv_tranpose2(x)
        x = self.relu(self.batch_norm3(x))
        x = self.conv_tranpose3(x)
        x = self.relu(self.batch_norm4(x))
        x = self.tanh(self.conv_tranpose4(x))
        return x

dmodel = DNet().to(device)
gmodel = GNet(input_size).to(device)
# -----------------------------
# loss function and optimizers
loss_fun = nn.BCELoss()
# BCE is used for the loss; with MSE training converges poorly and is also slower
goptim = torch.optim.Adam(gmodel.parameters(), lr=0.0001)
doptim = torch.optim.Adam(dmodel.parameters(), lr=0.0001)
# -----------------------------
# training
dmodel.train()
gmodel.train()
li_gloss = []
li_dloss = []
d_true = torch.ones(batch_size, 1).to(device)
d_fake = torch.zeros(batch_size, 1).to(device)
for epoch in range(50):
    for batch in range(0, 60000, batch_size):
        batch_data = data[batch:batch+batch_size].to(device)
        # real data taken from the MNIST set
        fake_data = torch.randn(batch_size, input_size).to(device)
        # random noise used to generate fake data
        output_dtrue = dmodel(batch_data)
        # first let the discriminator judge the real data
        loss_dtrue = loss_fun(output_dtrue, d_true)
        # for real data the output should be as close to 1 as possible
        output_dfake = dmodel(gmodel(fake_data))
        # then let the discriminator judge the images generated from the noise
        loss_dfake = loss_fun(output_dfake, d_fake)
        # for the discriminator, the output on generated images should be as close to 0 as possible
        loss_d = loss_dtrue + loss_dfake
        # both terms should be as small as possible
        doptim.zero_grad()
        loss_d.backward()
        doptim.step()
        li_dloss.append(loss_d.item())
        # once the discriminator has some discriminative power, train the generator
        for i in range(3):
            # the generator is harder to train, so it is updated three times per discriminator update
            # fake_data = torch.randn(batch_size, input_size).to(device)
            output_gtrue = dmodel(gmodel(fake_data))
            # judge the images generated from the noise
            loss_g = loss_fun(output_gtrue, d_true)
            # for the generator, the output on generated images should be as close to 1 as possible,
            # i.e. the fakes should look as close to real images as possible
            doptim.zero_grad()
            goptim.zero_grad()
            loss_g.backward()
            goptim.step()
            li_gloss.append(loss_g.item())
        print("epoch:{}, batch:{}, loss_d:{}, loss_g:{}".format(epoch, batch, loss_d, loss_g))
        torch.save(dmodel.state_dict(), "gan_dmodel.mdl")
        torch.save(gmodel.state_dict(), "gan_gmodel.mdl")
        if batch // batch_size % 30 == 0 and batch != 0:
            plt.plot(li_dloss)
            plt.show()
            plt.plot(li_gloss)
            plt.show()
# -----------------------------
# load the models
# gmodel.load_state_dict(torch.load("gan_gmodel.mdl"))
# dmodel.load_state_dict(torch.load("gan_dmodel.mdl"))
# -----------------------------
# testing
gmodel.eval()
data_test = torch.randn(100, input_size)
result = gmodel(data_test.to(device))
plt.figure(figsize=(10, 50))
for i in range(len(result)):
    ax = plt.subplot(len(result) // 5, 5, i+1)
    plt.imshow((result[i].cpu().data.reshape(28, 28)+1)*255/2)
    # plt.gray()
plt.show()