DCGAN Learning

Author: callme周小伦 | Published 2019-05-15 15:23

Generative adversarial networks (GANs) are one of the hottest recent research directions; here I follow the DCGAN TUTORIAL to implement a DCGAN.

  1. First, parse the required command-line arguments.
# imports needed by the snippets below
import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as utils
from torch.utils.data import DataLoader

parse = argparse.ArgumentParser("GAN learning")
parse.add_argument("--data-root", type=str, default="./celeba",
                   help='data_root')
parse.add_argument('--workers', default=2, type=int,
                   help='number of data loading workers (default: 2)')
parse.add_argument('--batch-size', default=128, type=int,
                   help='batch_size(default:128)')
parse.add_argument('--image-size', default=64, type=int,
                   help='image_size')
parse.add_argument('--nc', default=3, type=int,
                   help='number of channels in the training image')
parse.add_argument('--nz', default=100, type=int,
                   help='size of z latent vector')
parse.add_argument('--ngf', default=64, type=int,
                   help='size of feature maps in generator')
parse.add_argument('--ndf', default=64, type=int,
                   help='size of feature maps in discriminator')
parse.add_argument('--num-epoch', default=5, type=int,
                   help='number of epochs')
parse.add_argument('--lr', default=0.0002, type=float,
                   help='learning rate')
parse.add_argument('--beta1', default=0.5, type=float,
                   help='beta hyperparam for Adam optimizers')
parse.add_argument('--model', default='./model', type=str, help="model_path")

args = parse.parse_args()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# set a random seed so the results are reproducible
manualSeed = 1
random.seed(manualSeed)
torch.manual_seed(manualSeed)
  2. Load and preprocess the data (see: PyTorch data loading and preprocessing).
    There is only one folder here, and we treat everything inside it as a single class, so we load it with torchvision.datasets.ImageFolder().
dataset = dset.ImageFolder(root=args.data_root,
                           transform=transforms.Compose([
                               transforms.Resize(args.image_size),
                               transforms.CenterCrop(args.image_size),
                               transforms.ToTensor(),
                               # maps pixels to [-1, 1], matching the generator's Tanh output
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                           ]))
dataloader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size,
                        num_workers=args.workers)
# Each batch from the DataLoader is a list: batch[0] is the training-image tensor and batch[1] is the label tensor.
# ImageFolder assigns one class per subfolder; since we have a single folder, every label is 0.
real_data = next(iter(dataloader))
# plt.figure(figsize=(8,8))
# plt.title("training image")
# plt.axis("off") #去除坐标尺寸
# plt.imshow(np.transpose(utils.make_grid(real_data[0].to(device)[:64], padding=2, normalize=True).numpy(), (1,2,0)))
# plt.show()
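
A quick sanity check (a minimal sketch reusing the dataloader above) confirms the batch layout and that every label is 0:

images, labels = real_data
print(images.shape)     # torch.Size([128, 3, 64, 64]) with the default batch size
print(labels.unique())  # tensor([0]) -- a single class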

For the plt.imshow(np.transpose(utils.make_grid(real_data[0].to(device)[:64], padding=2, normalize=True).numpy(), (1,2,0))) line: the tutorial wraps the grid with .cpu() inside np.transpose(), which kept raising an error for me; changing it to .numpy() fixed it. Note that .numpy() only works when the tensor already lives on the CPU, since a CUDA tensor cannot be converted to NumPy directly, so the robust form is .cpu().numpy().
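
The portable pattern (a minimal sketch) is to move the grid to the CPU before converting:

grid = utils.make_grid(real_data[0].to(device)[:64], padding=2, normalize=True)
plt.imshow(np.transpose(grid.cpu().numpy(), (1, 2, 0)))  # .cpu() is a no-op for CPU tensors
plt.show()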

  3. Write the weight-initialization function for netG and netD.
    The class name of a module can be read from __class__.__name__.
def weights_init(m):
    classname = m.__class__.__name__
    # matches Conv2d and ConvTranspose2d: weights ~ N(0, 0.02), as in the DCGAN paper
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    # BatchNorm layers: scale ~ N(1, 0.02), bias = 0
    if classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
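
For example (illustrative only, not part of the training script):

print(nn.Conv2d(3, 64, 4).__class__.__name__)           # 'Conv2d' -> matches "Conv"
print(nn.ConvTranspose2d(64, 3, 4).__class__.__name__)  # 'ConvTranspose2d' -> also matches "Conv"
print(nn.BatchNorm2d(64).__class__.__name__)            # 'BatchNorm2d' -> matches "BatchNorm"

This is also why netG.apply(weights_init) covers the generator's transposed convolutions: their class name contains "Conv".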
  4. Write the generator.
    It uses transposed convolutions (also called deconvolutions), the inverse operation of a convolution; see the post 一文搞懂反卷积,转置卷积. The output-size arithmetic is sketched below.
    class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True)
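
With the default dilation of 1 and no output padding, the output size is H_out = (H_in - 1) * stride - 2 * padding + kernel_size. A quick check against the generator's first two layers (a sketch assuming the defaults nz=100, ngf=64):

x = torch.randn(1, 100, 1, 1)
y = nn.ConvTranspose2d(100, 512, 4, 1, 0)(x)  # (1-1)*1 - 2*0 + 4 = 4
print(y.shape)                                # torch.Size([1, 512, 4, 4])
y = nn.ConvTranspose2d(512, 256, 4, 2, 1)(y)  # (4-1)*2 - 2*1 + 4 = 8
print(y.shape)                                # torch.Size([1, 256, 8, 8])
# each remaining stride-2 layer keeps doubling: 8 -> 16 -> 32 -> 64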
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input: (nz) x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(args.nz, args.ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(args.ngf*8),
            nn.ReLU(True),

            # -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(args.ngf*8, args.ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ngf*4),
            nn.ReLU(True),

            # -> (ngf*2) x 16 x 16
            nn.ConvTranspose2d(args.ngf*4, args.ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ngf*2),
            nn.ReLU(True),

            # -> (ngf) x 32 x 32
            nn.ConvTranspose2d(args.ngf*2, args.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ngf),
            nn.ReLU(True),

            # -> (nc) x 64 x 64
            nn.ConvTranspose2d(args.ngf, args.nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

# Create the generator
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)
  5. The discriminator.
    Note that if you change the image size, the discriminator's kernel sizes, padding, and stride must be adjusted to match, because the final output has to be a single value; the shape check after the two networks below verifies this.
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input: (nc) x 64 x 64 -> (ndf) x 32 x 32
            nn.Conv2d(args.nc, args.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ndf*2) x 16 x 16
            nn.Conv2d(args.ndf, args.ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ndf*2),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ndf*4) x 8 x 8
            nn.Conv2d(args.ndf*2, args.ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ndf*4),
            nn.LeakyReLU(0.2, inplace=True),

            # -> (ndf*8) x 4 x 4
            nn.Conv2d(args.ndf*4, args.ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.ndf*8),
            nn.LeakyReLU(0.2, inplace=True),

            # -> 1 x 1 x 1: a single real/fake score
            nn.Conv2d(args.ndf*8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.main(x)

netD = Discriminator().to(device)
netD.apply(weights_init)

print(netD)
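
A quick end-to-end shape check (a sketch assuming the default image_size of 64) confirms the two networks fit together:

z = torch.randn(16, args.nz, 1, 1, device=device)
img = netG(z)
print(img.shape)                 # torch.Size([16, 3, 64, 64])
print(netD(img).view(-1).shape)  # torch.Size([16]) -- one score per image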
  6. Train the model. During training we save a checkpoint every epoch and display the real and fake images.
    torch.full(size, fill_value, ...)  # returns a tensor of shape size filled with fill_value
    labels.fill_(fake_label)  # fills labels in place with fake_label
    One line deserves special attention: output = netD(fake_batch.detach()).view(-1)
    When the discriminator scores the fake data we call .detach(); without it you get the error:
    Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
    The reason: you cannot run backward through the same graph more than once unless the first call sets retain_graph=True, and in a GAN we backpropagate through netG's output twice, once in the discriminator step and once in the generator step. detach() therefore returns a new fake_batch that is cut off from the graph, so the discriminator's backward pass stops there (for more detail on detach, see: pytorch: Variable detach 与 detach_); a minimal reproduction follows this list.
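
Illustrative only, not part of the training script:

a = torch.ones(2, requires_grad=True)
b = (a * a).sum()
b.backward()   # the first backward frees the graph's intermediate buffers
b.backward()   # RuntimeError: Trying to backward through the graph a second time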
fake = torch.randn(64, args.nz, 1, 1, device=device)  # fixed noise for visualizing generator progress
def train():
    criterion = nn.BCELoss()

    real_label = 1
    fake_label = 0

    optimizerG = optim.Adam(netG.parameters(), lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr,
                            betas=(args.beta1, 0.999))
    print("start train step:  ")

    G_loss = []
    D_loss = []
    img_lists = []
    iters = 0
    for epoch in range(args.num_epoch):
        for i, data in enumerate(dataloader):
            # 首先我们进行discriminator的训练
            netD.zero_grad()
            real_batch = data[0].to(device)
            bsize = real_batch.size(0)
            labels = torch.full((bsize,), real_label,
                                dtype=torch.float, device=device)
            output = netD(real_batch).view(-1)
            errD_real = criterion(output, labels)
            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(bsize, args.nz, 1, 1, device=device)
            fake_batch = netG(noise)
            labels.fill_(fake_label)
            output = netD(fake_batch.detach()).view(-1)
            errD_fake = criterion(output, labels)
            errD_fake.backward()
            DG_z1 = output.mean().item()

            errD = errD_real + errD_fake
            optimizerD.step()

            # 接下来进行generator的训练
            netG.zero_grad()
            labels.fill_(real_label)  # the generator wants D to call its output "real"
            output = netD(fake_batch).view(-1)
            errG = criterion(output, labels)
            errG.backward()

            DG_z2 = output.mean().item()
            optimizerG.step()

            if i % 50 == 0:
                print(
                    "[%d/%d][%d/%d]\t Loss_D: %.4f\t Loss_G: %.4f\t D(x): %.4f\t"
                    "D(G(z)):%.4f\t/ %.4f\t" % (
                    epoch, args.num_epoch, i, len(dataloader), errD.item(),
                    errG.item(), D_x, DG_z1, DG_z2))
            G_loss.append(errG.item())
            D_loss.append(errD.item())
            if (iters % 500 == 0) or ((epoch == args.num_epoch - 1) and (
                    i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake_ = netG(fake).detach().cpu()
                img_lists.append(
                    utils.make_grid(fake_, padding=2, normalize=True))
            iters += 1

        torch.save(netD.state_dict(),
                   os.path.join(args.model, "netD_{}.pth".format(epoch)))
        torch.save(netG.state_dict(),
                   os.path.join(args.model, "netG_{}.pth".format(epoch)))

        plt.figure(figsize=(15, 15))
        plt.subplot(1, 2, 1)
        plt.axis("off")
        plt.title("Real Images")
        plt.imshow(np.transpose(
            utils.make_grid(real_data[0].to(device)[:64], padding=5,
                            normalize=True).cpu().numpy(), (1, 2, 0)))
        plt.subplot(1, 2, 2)
        plt.axis("off")
        plt.title("Fake Images")
        plt.imshow(np.transpose(img_lists[-1].numpy(), (1, 2, 0)))
        plt.show()

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_loss, label="G")
    plt.plot(D_loss, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.show()


if __name__ == "__main__":
    train()
