
Learning DCGAN

By callme周小伦 | Published 2019-05-15 15:23

    Generative adversarial networks (GANs) have been a hot research topic recently. Here I follow the official DCGAN TUTORIAL and write a DCGAN step by step.
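
    For completeness, the snippets below assume roughly the following imports (a sketch inferred from how the names are used later, e.g. dataset.ImageFolder and utils.make_grid; they are not shown in the original post):

    import argparse
    import os
    import random

    import numpy as np
    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    import torchvision.datasets as dataset
    import torchvision.transforms as transforms
    import torchvision.utils as utils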

    1. First, parse the required arguments
    parse = argparse.ArgumentParser(description="GAN learning")
    parse.add_argument("--data-root", type=str, default="./celeba",
                       help='data_root')
    parse.add_argument('--workers', default=2, type=int,
                       help='worker(default:2)')
    parse.add_argument('--batch-size', default=128, type=int,
                       help='batch_size(default:128)')
    parse.add_argument('--image-size', default=64, type=int,
                       help='image_size')
    parse.add_argument('--nc', default=3, type=int,
                       help='number of channels in the training images')
    parse.add_argument('--nz', default=100, type=int,
                       help='size of z latent vector')
    parse.add_argument('--ngf', default=64, type=int,
                       help='size of feature maps in generator')
    parse.add_argument('--ndf', default=64, type=int,
                       help='size of feature maps in discriminator')
    parse.add_argument('--num-epoch', default=5, type=int,
                       help='number of epoch')
    parse.add_argument('--lr', default=0.0002, type=float,
                       help='learning rate')
    parse.add_argument('--beta1', default=0.5, type=float,
                       help='beta hyperparam for Adam optimizers')
    parse.add_argument('--model', default='./model',type=str, help="model_path")
    
    args = parse.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Set a random seed so that results are reproducible
    manualSeed = 1
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
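    # Optional (not in the original post): also seed CUDA when running on a GPU
    # torch.cuda.manual_seed_all(manualSeed)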
    
    2. Load and preprocess the data (see: PyTorch data loading and preprocessing)
      There is only one folder here, so we treat everything inside it as a single class and load it with torchvision.datasets.ImageFolder().
    dataset = dataset.ImageFolder(root=args.data_root,
                                  transform=transforms.Compose([
                                      transforms.Resize(args.image_size),
                                      transforms.CenterCrop(args.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                                  ]))
    dataloader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size,
                            num_workers=args.workers)
    # Each batch yielded by the DataLoader is a list: batch[0] is the tensor of training images and batch[1] is the tensor of training labels.
    # ImageFolder assigns one class per subfolder; since there is only one folder here, all labels are 0.
    real_data = next(iter(dataloader))
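    # Hypothetical check (not in the original post): with batch_size=128 and image_size=64,
    # real_data[0] is the image batch of shape [128, 3, 64, 64] and real_data[1] the labels (all 0).
    # print(real_data[0].shape, real_data[1].unique())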
    # plt.figure(figsize=(8,8))
    # plt.title("training image")
    # plt.axis("off") #去除坐标尺寸
    # plt.imshow(np.transpose(utils.make_grid(real_data[0].to(device)[:64], padding=2, normalize=True).numpy(), (1,2,0)))
    # plt.show()
    

    About the line plt.imshow(np.transpose(utils.make_grid(real_data[0].to(device)[:64], padding=2, normalize=True).numpy(), (1,2,0))): the tutorial uses .cpu() inside np.transpose(), but that kept raising an error for me, so I changed it to .numpy() and it works.
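
    A variant that should work on both CPU and GPU (a sketch of the same line, assuming the imports above) is to move the grid to the CPU first and only then convert it:

    grid = utils.make_grid(real_data[0].to(device)[:64], padding=2, normalize=True)
    plt.imshow(np.transpose(grid.cpu().numpy(), (1, 2, 0)))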

    3. Write the weight-initialization function for netG and netD
      The class name of a module can be obtained via __class__.__name__
    def weights_init(m):
        # DCGAN initialization: conv weights ~ N(0, 0.02); batchnorm weights ~ N(1, 0.02), biases 0
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
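
    As a quick illustration of the class-name matching (the example layers are arbitrary, not from the post); the function is later passed to Module.apply(), which calls it on every submodule:

    print(nn.ConvTranspose2d(100, 512, 4).__class__.__name__)  # 'ConvTranspose2d' -> contains 'Conv'
    print(nn.BatchNorm2d(512).__class__.__name__)              # 'BatchNorm2d' -> contains 'BatchNorm'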
    
    4. Write the generator (Generator)
      The generator uses transposed convolutions (sometimes called deconvolutions) to invert the downsampling of a convolution; see the blog post 一文搞懂反卷积,转置卷积 for details. A quick output-shape check follows the generator code below.
      class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True)
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.main = nn.Sequential(
                nn.ConvTranspose2d(args.nz, args.ngf*8, 4, 1, 0, bias=False),  # 1x1 -> 4x4
                nn.BatchNorm2d(args.ngf*8),
                nn.ReLU(True),

                nn.ConvTranspose2d(args.ngf*8, args.ngf*4, 4, 2, 1, bias=False),  # 4x4 -> 8x8
                nn.BatchNorm2d(args.ngf*4),
                nn.ReLU(True),

                nn.ConvTranspose2d(args.ngf*4, args.ngf*2, 4, 2, 1, bias=False),  # 8x8 -> 16x16
                nn.BatchNorm2d(args.ngf*2),
                nn.ReLU(True),

                nn.ConvTranspose2d(args.ngf*2, args.ngf, 4, 2, 1, bias=False),  # 16x16 -> 32x32
                nn.BatchNorm2d(args.ngf),
                nn.ReLU(True),

                nn.ConvTranspose2d(args.ngf, args.nc, 4, 2, 1, bias=False),  # 32x32 -> 64x64
                nn.Tanh()
            )
    
        def forward(self, x):
            return self.main(x)
    
    # Create the generator
    netG = Generator().to(device)
    netG.apply(weights_init)
    print(netG)
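
    As a sanity check on the transposed-convolution arithmetic: each layer grows the spatial size by H_out = (H_in - 1)*stride - 2*padding + kernel_size, i.e. 1 -> 4 -> 8 -> 16 -> 32 -> 64 with the arguments above. A minimal check (a sketch, not in the original post):

    with torch.no_grad():
        print(netG(torch.randn(1, args.nz, 1, 1, device=device)).shape)
        # expected: torch.Size([1, 3, 64, 64]) with the default arguments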
    
    5. Discriminator
      Note that if you change the image size, the discriminator's kernel sizes, paddings, and strides must be adjusted accordingly, because the network has to reduce each image to a single output value. A quick shape check follows the code below.
    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.main = nn.Sequential(
                nn.Conv2d(args.nc, args.ndf, 4, 2, 1, bias=False),  # 64x64 -> 32x32
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(args.ndf, args.ndf*2, 4, 2, 1, bias=False),  # 32x32 -> 16x16
                nn.BatchNorm2d(args.ndf*2),
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(args.ndf*2, args.ndf*4, 4, 2, 1, bias=False),  # 16x16 -> 8x8
                nn.BatchNorm2d(args.ndf*4),
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(args.ndf*4, args.ndf*8, 4, 2, 1, bias=False),  # 8x8 -> 4x4
                nn.BatchNorm2d(args.ndf*8),
                nn.LeakyReLU(0.2, inplace=True),

                nn.Conv2d(args.ndf*8, 1, 4, 1, 0, bias=False),  # 4x4 -> single score
                nn.Sigmoid()
            )
        def forward(self, x):
            return self.main(x)
    
    netD = Discriminator().to(device)
    netD.apply(weights_init)
    
    print(netD)
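
    And the corresponding check for the discriminator (again a sketch, not in the original post): a 64x64 input is reduced to one value per image, which the training loop then flattens with .view(-1).

    with torch.no_grad():
        print(netD(torch.randn(1, args.nc, args.image_size, args.image_size, device=device)).shape)
        # expected: torch.Size([1, 1, 1, 1])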
    
    6. Train the model; save the networks once per epoch and display real and fake images
      torch.full(size, fill_value, ...) returns a tensor of the given size filled with fill_value
      labels.fill_(fake_label) fills labels in place with fake_label
      One point needs special attention: output = netD(fake_batch.detach()).view(-1)
      When the discriminator evaluates the fake data, .detach() is applied; without it you get the error:
      Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
      The reason is that a graph cannot be backpropagated through more than once unless retain_graph=True is set on the first backward. In a GAN we would otherwise backpropagate through netG's graph twice, once in the discriminator step via fake_batch and once again in the generator step, so we call detach() to get a new tensor that is cut off from the graph for the discriminator step (for a more detailed discussion of detach see: pytorch: Variable detach 与 detach_).
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)  # fixed noise for visualizing generator progress
    def train():
        criterion = nn.BCELoss()
    
        real_label = 1
        fake_label = 0
    
        optimizerG = optim.Adam(netG.parameters(), lr=args.lr,
                                betas=(args.beta1, 0.999))
        optimizerD = optim.Adam(netD.parameters(), lr=args.lr,
                                betas=(args.beta1, 0.999))
        print("start train step:  ")
    
        G_loss = []
        D_loss = []
        img_lists = []
        iters = 0
        for epoch in range(args.num_epoch):
            for i, data in enumerate(dataloader):
                # Train the discriminator first: maximize log(D(x)) + log(1 - D(G(z)))
                netD.zero_grad()
                real_batch = data[0].to(device)
                bsize = real_batch.size(0)
                labels = torch.full((bsize,), real_label, dtype=torch.float, device=device)
                output = netD(real_batch).view(-1)
                errD_real = criterion(output, labels)
                errD_real.backward()
                D_x = output.mean().item()
    
                noise = torch.randn(bsize, args.nz, 1, 1, device=device)
                fake_batch = netG(noise)
                labels.fill_(fake_label)
                output = netD(fake_batch.detach()).view(-1)
                errD_fake = criterion(output, labels)
                errD_fake.backward()
                DG_z1 = output.mean().item()
    
                errD = errD_real + errD_fake
                optimizerD.step()
    
                # Then train the generator: maximize log(D(G(z))), i.e. label the fakes as real
                netG.zero_grad()
                labels.fill_(real_label)
                output = netD(fake_batch).view(-1)
                errG = criterion(output, labels)
                errG.backward()
    
                DG_z2 = output.mean().item()
                optimizerG.step()
    
                if i % 50 == 0:
                    print(
                        "[%d/%d][%d/%d]\t Loss_D: %.4f\t Loss_G: %.4f\t D(x): %.4f\t"
                        "D(G(z)):%.4f\t/ %.4f\t" % (
                        epoch, args.num_epoch, i, len(dataloader), errD.item(),
                        errG.item(), D_x, DG_z1, DG_z2))
                G_loss.append(errG.item())
                D_loss.append(errD.item())
                if (iters % 500 == 0) or ((epoch == args.num_epoch - 1) and (
                        i == len(dataloader) - 1)):
                    with torch.no_grad():
                        fake_ = netG(fixed_noise).detach().cpu()
                    img_lists.append(
                        utils.make_grid(fake_, padding=2, normalize=True))
                iters += 1
    
            torch.save(netD.state_dict(),
                       os.path.join(args.model, "netD_{}.pth".format(epoch)))
            torch.save(netG.state_dict(),
                       os.path.join(args.model, "netG_{}.pth".format(epoch)))
    
            plt.figure(figsize=(15, 15))
            plt.subplot(1, 2, 1)
            plt.axis("off")
            plt.title("Real Images")
            plt.imshow(np.transpose(
                utils.make_grid(real_data[0].to(device)[:64], padding=5,
                                normalize=True).cpu().numpy(), (1, 2, 0)))
            plt.subplot(1, 2, 2)
            plt.axis("off")
            plt.title("Fake Images")
            plt.imshow(np.transpose(img_lists[-1].numpy(), (1, 2, 0)))
            plt.show()
    
        plt.figure(figsize=(10, 5))
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_loss, label="G")
        plt.plot(D_loss, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.show()
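
    To actually run the script, a minimal entry point can be added at the end (a sketch; it also creates the --model directory that the torch.save calls above assume exists):

    if __name__ == "__main__":
        os.makedirs(args.model, exist_ok=True)
        train()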
    
