美文网首页人工不智能
用Pytorch实现WGAN

用Pytorch实现WGAN

作者: A君来了 | 来源:发表于2018-08-29 15:10 被阅读424次

    本文是解读WGAN的实践篇,目标是用pytorch实现能生成人脸图像的WGAN。如果对WGAN、DCGANs和GANs还不熟悉的话,可以先阅读解读WGAN这篇理论博文,本文不再详解其原理架构,完整的源代码请查看https://github.com/alexshuang/wgan-pytorch

    Looking at the code

    Notebook / Paper

    我们知道GANs可以用来做很多事情,像生成音频、文字、图像、视频等,本文目标就是用GANs来生成人脸图像,我们将会用到CelebA,一个提供名人头像的数据集。

    Dataset

    一个神经网络模型主要由Dataset、Model和Loss function这三部分构成,首先我们先看Dataset:

    你可以通过CelebA官网上的指引来下载数据集,当然也可以像我这样使用kaggle的数据集:
    kaggle datasets download -d jessicali9530/celeba-dataset

    我是基于fast.ai深度学习框架来编写程序的,它为加载dataset提供了从csv文件加载数据的API,所以我们需要先生成dataset的metadata--files.csv:

    files = PATH.glob('**/*.jpg')
    with CSV_PATH.open('w') as fp:
      for f in files: fp.write(f'{f.relative_to(IMG_PATH)},0\n')
    

    因为GANs是不需要指定label的,所以filename后面跟着的是0。

    bs = 64
    sz = 64
    
    tfms = tfms_from_stats(inception_stats, sz)
    md = ImageClassifierData.from_csv(PATH, TRN_DN, CSV_PATH, bs, tfms=tfms, skip_header=False, continuous=True)
    

    ImageClassifierData.from_csv()就是上文所提到的“从csv文件加载数据的API”。ImageClassifierData是继承于pytorch dataset的子类,通过trn_dl、val_dl、test_dl这三个迭代器就可以获取相关数据集的样本数据。

    x, _ = next(iter(md.trn_dl))
    x_ims = md.trn_ds.denorm(x)
    
    fig, axes = plt.subplots(3, 4, figsize=(6, 4))
    for i, ax in enumerate(axes.flat):
      ax.imshow(x_ims[i])
    plt.tight_layout()
    
    tou.png

    可以看出数据集是64x64大小的头像,WGAN适用此类低分辨率的图像,而高分辨率的图像则更需要像SRWGAN这类模型,后续有机会我也会继续分享相关模型在不同数据集上的应用。


    Model

    WGAN的作者实现了分别基于GANsDCGANs的两个版本,我这里采用DCGANs的网络架构。

    DCGANs的实现细节:
    image.png

    除此之外,paper里还提到几个细节:

    • generator的输入是长度为100的向量,输出的是rank 4,64x64大小的matrix。
    • 不要将batchnorm部署在generator的输出层和discriminator的输入层。
    • LeakyRelu的"slope of the leak"系数设定为0.2。
    Discriminator
    class ConvBlock(nn.Module):
      def __init__(self, in_c, out_c, ks, stride, bn=True):
        super().__init__()
        padding = ks // 2 // stride
        self.conv = nn.Conv2d(in_c, out_c, ks, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_c) if bn else None
        self.relu = nn.LeakyReLU(0.2, inplace=True)
      
      def forward(self, x):
        x = self.relu(self.conv(x))
        return self.bn(x) if self.bn else x
    
    
    class DCGAN_D(nn.Module):
      def __init__(self, sz, in_c, nf, num_extra_layers=0):
        super().__init__()
        assert sz % 16 == 0, 'sz must be a multipe by 16.'
       
        self.initial = ConvBlock(in_c, nf, 4, 2, bn=False)
        grid_sz = sz / 2
        self.extra = nn.Sequential(*[ConvBlock(nf, nf, 3, 1) for o in range(num_extra_layers)])
        
        layers = []
        while grid_sz > 4:
          layers.append(ConvBlock(nf, nf * 2, 4, 2))
          grid_sz /= 2; nf *= 2
        self.pyramid = nn.Sequential(*layers)
        
        self.final = nn.Conv2d(nf, 1, 4, 1, 0, bias=False)
      
      def forward(self, x):
        x = self.initial(x)
        x = self.extra(x)
        x = self.pyramid(x)
        return self.final(x).mean().view(1)
    

    discriminator是CNN classifier,最终输出一个scalar,该值越小代表fake data的真实系数越高。这里的代码很多都是从原作者的源码中直接借鉴过来的,它通过多个stride convolution操作直至input的grid size等于4,最终取其均值作为output。通常情况下,CNN classifier会以adaptive average pooling生成rank 4, 1x1大小的matrix,将其做flatten操作后作为输入传入全连接层,或许adaptive average pooling也能应用到discriminator中,而且达到更好的效果。

    Generator
    class DeconvBlock(nn.Module):
      def __init__(self, in_c, out_c, ks, stride, pad, bn=True):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_c, out_c, ks, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU(inplace=True)
        
      def forward(self, x):
        x = self.relu(self.conv(x))
        return self.bn(x) if bn else x
    
    
    class DCGAN_G(nn.Module):
      def __init__(self, sz, in_c, nf, num_extra_layers=0):
        super().__init__()
        assert sz % 16 == 0, 'sz must be a multipe by 16.'
        i = 4; nf //= 2
        while i != sz:
          nf *= 2; i *= 2
        
        self.initial = DeconvBlock(in_c, nf, 4, 1, 0)
        
        layers = []
        curr_sz = 4
        while curr_sz < (sz // 2):
          layers.append(DeconvBlock(nf, nf // 2, 4, 2, 1))
          curr_sz *= 2; nf //= 2
        self.pyramid = nn.Sequential(*layers)
        self.extra = nn.Sequential(*[DeconvBlock(nf, nf, 3, 1, 1) for i in range(num_extra_layers)])
        self.final = nn.ConvTranspose2d(nf, 3, 4, 2, 1, bias=False)
      
      def forward(self, x):
        x = self.initial(x)
        x = self.pyramid(x)
        x = self.extra(x)
        return F.tanh(self.final(x))
    

    generator和discriminator相反,通过deconvolution将rank 1的向量转化为rank 4, 64x64的矩阵。rank 1向量就是长度为100的噪声向量,为deconvolution()的参数格式为rank 4,因此需要gen_noise()将rank 1的噪声转化为rank 4矩阵。

    Noise Generator

    def gen_noise(bs):
      return V(torch.zeros((bs, nz, 1, 1)).normal_(0, 1))
    
    image.png

    WGAN / paper

    WGAN.png
    ncritic = 5
    
    def WGAN_fit(md, niter):
      net_G.train(True)
      net_D.train(True)
      gen_iterations = 0
      
      for epoch in range(niter):
        data_iter = iter(md.trn_dl)
        i, n = 0, len(md.trn_dl)
        with tqdm(total=n) as pbar:
          while i < n:
            set_trainable(net_D, True)
            set_trainable(net_G, False)
            j = 0
            while i < n and j < ncritic:
              for p in net_D.parameters(): p.data.clamp_(-0.01, 0.01)
              real = V(next(data_iter)[0])
              real_loss = net_D(real)
              fake = net_G(gen_noise(real.size(0)))
              fake_loss = net_D(V(fake.data))
              net_D.zero_grad()
              d_loss = real_loss - fake_loss
              d_loss.backward()
              opt_D.step()
              i += 1; j += 1
              pbar.update()
            
            set_trainable(net_D, False)
            set_trainable(net_G, True)
            net_G.zero_grad()
            g_loss = net_D(net_G(gen_noise(bs)))
            g_loss.backward()
            opt_G.step()
            gen_iterations += 1
          
        print(f'd_loss: {to_np(d_loss)}; g_loss: {to_np(g_loss)}; '
              f'real_loss: {to_np(real_loss)}; fake_loss: {to_np(fake_loss)}')
    

    paper里已经给出了明确的实现方法以及各超参系数,这里不需要额外多说什么了,如果你对其训练方法或涉及到的公式有所不解,请点击解读WGAN。有经验的朋友看到这会觉得很奇怪,为什么这里没有metric方法?GANs是无监督学习的模型,所以大家都会调侃GANs是一种不用看metric指标的模型哈哈。


    模型训练

    files = PATH.glob('**/*.jpg')
    with CSV_SMALL_PATH.open('w') as fp:
      for f in files:
        if np.random.random() < 0.1: fp.write(f'{f.relative_to(IMG_PATH)},0\n')
    

    为了快速验证模型的有效性,先用10%的数据集来训练。

    lr = 5e-5
    opt_G = optim.RMSprop(net_G.parameters(), lr=lr)
    opt_D = optim.RMSprop(net_D.parameters(), lr=lr)
    WGAN_fit(md, 4)
    
    100%|██████████| 256/256 [01:12<00:00,  1.43it/s]
    d_loss: [-1.06658]; g_loss: [0.81221]; real_loss: [-0.27118]; fake_loss: [0.7954]
    100%|██████████| 256/256 [01:09<00:00,  3.91it/s]
    d_loss: [-1.09587]; g_loss: [0.84473]; real_loss: [-0.26713]; fake_loss: [0.82875]
    100%|██████████| 256/256 [01:10<00:00,  3.89it/s]
    d_loss: [-1.04561]; g_loss: [0.82418]; real_loss: [-0.23558]; fake_loss: [0.81003]
    100%|██████████| 256/256 [01:09<00:00,  3.92it/s]
    d_loss: [-1.06523]; g_loss: [0.8402]; real_loss: [-0.24339]; fake_loss: [0.82184]
    100%|██████████| 256/256 [01:10<00:00,  3.89it/s]
    d_loss: [-1.01388]; g_loss: [0.82767]; real_loss: [-0.22356]; fake_loss: [0.79032]
    

    衡量GANs能力强弱,看fake_loss即可,假设discriminator的鉴别能力会随着generator的造假能力增强而增强,那discriminator的输出--fake_loss,其值越小证明真实程度越高,反之亦然。我们将输出显示出来:

    def gallery(x, nc=3):
        n,h,w,c = x.shape
        nr = n//nc
        assert n == nr*nc
        return (x.reshape(nr, nc, h, w, c)
                  .swapaxes(1,2)
                  .reshape(h*nr, w*nc, c))
    
    fixed_noise = gen_noise(bs)
    
    net_D.eval(); net_G.eval();
    fake = net_G(fixed_noise).data.cpu()
    faked = np.clip(md.trn_ds.denorm(fake),0,1)
    
    plt.figure(figsize=(9,9))
    plt.imshow(gallery(faked, 8))
    
    image.png

    我们看到输出结果很模糊,只有一些轮廓。接着我们再用完整的数据集来训练。

    files = PATH.glob('**/*.jpg')
    with CSV_PATH.open('w') as fp:
      for f in files: fp.write(f'{f.relative_to(IMG_PATH)},0\n')
    
    100%|██████████| 2533/2533 [07:16<00:00,  7.29it/s]
    d_loss: [-0.71795]; g_loss: [0.25149]; real_loss: [-0.2049]; fake_loss: [0.51305]
    100%|██████████| 2533/2533 [07:17<00:00,  7.24it/s]
    d_loss: [-0.66687]; g_loss: [0.22367]; real_loss: [-0.27482]; fake_loss: [0.39205]
    100%|██████████| 2533/2533 [07:15<00:00,  7.34it/s]
    d_loss: [-0.62798]; g_loss: [0.41359]; real_loss: [-0.35229]; fake_loss: [0.27569]
    100%|██████████| 2533/2533 [07:15<00:00,  7.24it/s]
    d_loss: [-0.43533]; g_loss: [0.1597]; real_loss: [-0.01642]; fake_loss: [0.41891]
    100%|██████████| 2533/2533 [07:16<00:00,  7.28it/s]
    d_loss: [-0.4477]; g_loss: [0.42561]; real_loss: [-0.48342]; fake_loss: [-0.03572]
    100%|██████████| 2533/2533 [07:17<00:00,  7.26it/s]
    

    可以看到现在的训练样本是之前的10倍,fake_loss也相比之前要低很多,我们也将输出结果显示出来:


    image.png

    可以看到,虽然只训练了5个epoch,但已经可以看出清晰的人脸,虽然有些人脸很诡异,但鉴于训练时间较短,模型表现上还是OK的。


    引用

    DCGANs
    WGAN: https://arxiv.org/abs/1701.07875
    fasiai: https://github.com/fastai/fastai/blob/master/courses/dl2/wgan.ipynb

    相关文章

      网友评论

        本文标题:用Pytorch实现WGAN

        本文链接:https://www.haomeiwen.com/subject/mbshwftx.html