Generative Adversarial Networks (GANs) have been a hot research direction recently. Here we follow the official DCGAN TUTORIAL to implement a DCGAN.
- First, parse the required arguments
import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

parse = argparse.ArgumentParser("GAN learning")
parse.add_argument("--data-root", type=str, default="./celeba",
                   help='data_root')
parse.add_argument('--workers', default=2, type=int,
                   help='workers (default: 2)')
parse.add_argument('--batch-size', default=128, type=int,
                   help='batch_size (default: 128)')
parse.add_argument('--image-size', default=64, type=int,
                   help='image_size')
parse.add_argument('--nc', default=3, type=int,
                   help='number of channels in the training images')
parse.add_argument('--nz', default=100, type=int,
                   help='size of z latent vector')
parse.add_argument('--ngf', default=64, type=int,
                   help='size of feature maps in generator')
parse.add_argument('--ndf', default=64, type=int,
                   help='size of feature maps in discriminator')
parse.add_argument('--num-epoch', default=5, type=int,
                   help='number of epochs')
parse.add_argument('--lr', default=0.0002, type=float,
                   help='learning rate')
parse.add_argument('--beta1', default=0.5, type=float,
                   help='beta1 hyperparam for Adam optimizers')
parse.add_argument('--model', default='./model', type=str, help="model_path")
args = parse.parse_args()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Set a random seed so the results are reproducible
manualSeed = 1
random.seed(manualSeed)
torch.manual_seed(manualSeed)
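If you run on GPU and want the run to be as repeatable as possible, you can additionally seed CUDA. This is a small optional addition, not part of the original snippet:
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(manualSeed)  # seed all CUDA devices as well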
- Load and preprocess the data (for the general approach, see: pytorch数据加载及预处理). Here there is only a single folder, and we treat everything inside it as one class, so we use torchvision.datasets.ImageFolder() to load it.
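Note that ImageFolder requires at least one subdirectory under data_root (each subdirectory becomes one class); the images cannot sit directly in ./celeba. An assumed example layout:
./celeba/img_align_celeba/000001.jpg
./celeba/img_align_celeba/000002.jpg
...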
dataset = datasets.ImageFolder(root=args.data_root,
                               transform=transforms.Compose([
                                   transforms.Resize(args.image_size),
                                   transforms.CenterCrop(args.image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ]))
dataloader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size,
num_workers=args.workers)
# Each batch yielded by the DataLoader is a list: batch[0] is the tensor of
# training images and batch[1] is the tensor of training labels. ImageFolder
# assigns one class per subfolder; since there is only one folder here, every
# label is 0.
real_data = next(iter(dataloader))
# plt.figure(figsize=(8,8))
# plt.title("training image")
# plt.axis("off")  # hide the axes
# plt.imshow(np.transpose(utils.make_grid(real_data[0].to(device)[:64], padding=2, normalize=True).cpu().numpy(), (1,2,0)))
# plt.show()
For the plt.imshow(...) line above, the tutorial ends the chain with .cpu() only, but that kept raising errors for me; converting to a NumPy array with .numpy() fixed it. Note that a CUDA tensor cannot be converted to NumPy directly, so the safe chain is .cpu().numpy().
- Write the weight-initialization function for netG and netD. The class name of a module can be read via __class__.__name__.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # DCGAN paper: conv weights drawn from N(0, 0.02)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        # BatchNorm: scale from N(1, 0.02), bias set to 0
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
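A throwaway check of what __class__.__name__ actually returns (purely illustrative, not part of the script):
m = nn.Conv2d(3, 64, 4)
print(m.__class__.__name__)   # 'Conv2d'       -> find("Conv") != -1
m = nn.BatchNorm2d(64)
print(m.__class__.__name__)   # 'BatchNorm2d'  -> find("BatchNorm") != -1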
- First, write the Generator. It uses transposed convolutions (sometimes called deconvolutions), the inverse-like operation of convolution; for background see the blog post 一文搞懂反卷积,转置卷积.
class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True)
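Its output size is H_out = (H_in - 1) * stride - 2 * padding + kernel_size (ignoring output_padding and dilation). A quick sanity check of the first two layers used below, with channel counts assuming the default ngf=64:
x = torch.randn(1, 100, 1, 1)
up1 = nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False)
print(up1(x).shape)        # [1, 512, 4, 4]:  (1-1)*1 - 0 + 4 = 4
up2 = nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False)
print(up2(up1(x)).shape)   # [1, 256, 8, 8]:  (4-1)*2 - 2 + 4 = 8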
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input: z of shape (nz) x 1 x 1
            nn.ConvTranspose2d(args.nz, args.ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(args.ngf*8),
            nn.ReLU(True),
            # (ngf*8) x 4 x 4
            nn.ConvTranspose2d(args.ngf*8, args.ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ngf*4),
            nn.ReLU(True),
            # (ngf*4) x 8 x 8
            nn.ConvTranspose2d(args.ngf*4, args.ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ngf*2),
            nn.ReLU(True),
            # (ngf*2) x 16 x 16
            nn.ConvTranspose2d(args.ngf*2, args.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ngf),
            nn.ReLU(True),
            # (ngf) x 32 x 32
            nn.ConvTranspose2d(args.ngf, args.nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # output: (nc) x 64 x 64
        )

    def forward(self, x):
        return self.main(x)
# Create the generator
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)
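As a quick sanity check (with the defaults nz=100, nc=3, image_size=64), a noise batch pushed through netG should come out as 64x64 images:
noise = torch.randn(16, args.nz, 1, 1, device=device)
print(netG(noise).shape)   # expected: torch.Size([16, 3, 64, 64])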
- The Discriminator. Note that if you change the image size, the discriminator's kernel sizes, padding, and stride must be adjusted to match, because the network has to reduce each image to a single value.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input: (nc) x 64 x 64
            nn.Conv2d(args.nc, args.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf) x 32 x 32
            nn.Conv2d(args.ndf, args.ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ndf*2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 16 x 16
            nn.Conv2d(args.ndf*2, args.ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ndf*4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 8 x 8
            nn.Conv2d(args.ndf*4, args.ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ndf*8),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*8) x 4 x 4
            nn.Conv2d(args.ndf*8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # output: a single probability per image
        )

    def forward(self, x):
        return self.main(x)
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)
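Each stride-2 Conv2d halves the spatial size via H_out = (H_in + 2*padding - kernel_size) // stride + 1, so the feature maps shrink 64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 convolution reduces them to a single score; this is the constraint mentioned above. A quick check:
img = torch.randn(16, args.nc, args.image_size, args.image_size, device=device)
print(netD(img).shape)   # expected: torch.Size([16, 1, 1, 1]), one probability per image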
- Train the model, saving a checkpoint every epoch and displaying real and fake images at the end.
torch.full(size, fill_value, ...)  # returns a tensor of the given size filled with fill_value
labels.fill_(fake_label)  # fills labels in place with fake_label
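A two-line illustration of these calls:
labels = torch.full((4,), 1.)   # tensor([1., 1., 1., 1.])
labels.fill_(0.)                # in place -> tensor([0., 0., 0., 0.])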
One point deserves special attention: output = netD(fake_batch.detach()).view(-1). When the discriminator scores the fake data, we call .detach(); without it, training raises:
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
The reason is that a graph cannot be backpropagated through more than once unless the first backward passes retain_graph=True. In a GAN, the generator's output is backpropagated through twice per iteration, once for the discriminator update and once for the generator update, so we detach fake_batch to obtain a new tensor that is cut off from the graph (for a more detailed explanation of detach, see: pytorch: Variable detach 与 detach_).
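To see why this works, here is a minimal standalone sketch (illustrative only, not part of the training loop; it assumes netG and netD already exist):
g_out = netG(torch.randn(2, args.nz, 1, 1, device=device))
netD(g_out.detach()).mean().backward()  # grads stop at g_out: netG's graph is untouched
netD(g_out).mean().backward()           # first (and only) backward through netG -> no error
netD.zero_grad(); netG.zero_grad()      # discard the demo gradients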
fake = torch.randn(64, args.nz, 1, 1, device=device)  # fixed noise used to visualize the generator's progress

def train():
    os.makedirs(args.model, exist_ok=True)  # make sure the checkpoint directory exists
    criterion = nn.BCELoss()
    real_label = 1.
    fake_label = 0.
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr,
                            betas=(args.beta1, 0.999))
    print("start train step: ")
    G_loss = []
    D_loss = []
    img_lists = []
    iters = 0
    for epoch in range(args.num_epoch):
        for i, data in enumerate(dataloader):
            # First train the discriminator: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            real_batch = data[0].to(device)
            bsize = real_batch.size(0)
            labels = torch.full((bsize,), real_label,
                                dtype=torch.float, device=device)
            output = netD(real_batch).view(-1)
            errD_real = criterion(output, labels)
            errD_real.backward()
            D_x = output.mean().item()
            noise = torch.randn(bsize, args.nz, 1, 1, device=device)
            fake_batch = netG(noise)
            labels.fill_(fake_label)
            output = netD(fake_batch.detach()).view(-1)
            errD_fake = criterion(output, labels)
            errD_fake.backward()
            DG_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()
            # Then train the generator: maximize log(D(G(z)))
            netG.zero_grad()
            labels.fill_(real_label)
            output = netD(fake_batch).view(-1)
            errG = criterion(output, labels)
            errG.backward()
            DG_z2 = output.mean().item()
            optimizerG.step()
            if i % 50 == 0:
                print(
                    "[%d/%d][%d/%d]\t Loss_D: %.4f\t Loss_G: %.4f\t D(x): %.4f\t"
                    "D(G(z)): %.4f / %.4f" % (
                        epoch, args.num_epoch, i, len(dataloader), errD.item(),
                        errG.item(), D_x, DG_z1, DG_z2))
            G_loss.append(errG.item())
            D_loss.append(errD.item())
            if (iters % 500 == 0) or ((epoch == args.num_epoch - 1) and (
                    i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake_ = netG(fake).detach().cpu()
                img_lists.append(
                    utils.make_grid(fake_, padding=2, normalize=True))
            iters += 1
        # Save a checkpoint of both networks after every epoch
        torch.save(netD.state_dict(),
                   os.path.join(args.model, "netD_{}.pth".format(epoch)))
        torch.save(netG.state_dict(),
                   os.path.join(args.model, "netG_{}.pth".format(epoch)))
    # After training: show a batch of real images next to the last fake grid
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(np.transpose(
        utils.make_grid(real_data[0].to(device)[:64], padding=5,
                        normalize=True).cpu().numpy(), (1, 2, 0)))
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_lists[-1].numpy(), (1, 2, 0)))
    plt.show()
    # Plot the generator / discriminator loss curves
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_loss, label="G")
    plt.plot(D_loss, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
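The listing above never actually calls train(). A minimal way to run it, and later to sample from a saved checkpoint (the file name netG_4.pth is just an assumed example for the last of 5 epochs):
if __name__ == "__main__":
    train()
    # reload the generator from the last checkpoint and sample from it
    netG.load_state_dict(torch.load(os.path.join(args.model, "netG_4.pth"),
                                    map_location=device))
    netG.eval()
    with torch.no_grad():
        samples = netG(torch.randn(64, args.nz, 1, 1, device=device))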