美文网首页
深度学习(五):生成对抗网络

深度学习(五):生成对抗网络

作者: fromeast | 来源:发表于2019-09-26 10:40 被阅读0次

一、基本原理

生成对抗网络(generative adversarial network, GAN)是通过对抗训练的方式来使得生成网络产生的样本服从真实数据分布。在生成对抗网络中,有 两个网络进行对抗训练。一个是判别网络,目标是尽量准确地判断一个样本是 来自于真实数据还是生成网络产生的;另一个是生成网络,目标是尽量生成判别 网络无法区分来源的样本。这两个目标相反的网络不断地进行交替训练。当最后收敛时,如果判别网络再也无法判断出一个样本的来源,那么也就等价于生成 网络可以生成符合真实数据分布的样本。

生成对抗网络的结构

判别网络的目标即区分样本\mathbf{x}是来自于真实分布p_{r}(\mathbf{x})还是生成模型p_{\theta}(\mathbf{x}),分别用y=1y=0表示,其实质是二分类器,有p(y=1 | \mathbf{x})=D(\mathbf{x}, \phi) p(y=0 | \mathbf{x})=1-D(\mathbf{x}, \phi)。对于给定样本(\mathbf{x},y),目标函数为最小化交叉熵,即最大似然函数。
\begin{aligned} & \min _{\phi}-\left(\mathbb{E}_{\mathbf{x}}[y \log p(y=1 | \mathbf{x})+(1-y) \log p(y=0 | \mathbf{x})]\right) \\=& \max _{\phi}\left(\mathbb{E}_{\mathbf{x} \sim p_{r}(\mathbf{x})}[\log D(\mathbf{x}, \phi)]+\mathbb{E}_{\mathbf{x}^{\prime} \sim p_{\theta}\left(\mathbf{x}^{\prime}\right)}\left[\log \left(1-D\left(\mathbf{x}^{\prime}, \phi\right)\right)\right]\right) \\=& \max _{\phi}\left(\mathbb{E}_{\mathbf{x} \sim p_{r}(\mathbf{x})}[\log D(\mathbf{x}, \phi)]+\mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[\log (1-D(G(\mathbf{z}, \theta), \phi))]\right) \end{aligned} 其中\phi\theta分别为判别网络和生成网络的参数。

生成网络的目标刚好和判别网络相反,即让判别网络将自己生成的样本判别为真实样本。
\begin{aligned} & \max _{\theta}\left(\mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[\log D(G(\mathbf{z}, \theta), \phi)]\right) \\=& \min _{\theta}\left(\mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[\log (1-D(G(\mathbf{z}, \theta), \phi))]\right) \end{aligned} 以上两个目标函数等价,但在实际过程中,由于前者梯度更好,所以更常用。
生成对抗网络的算法过程如下所示:

二、DCGAN

深度卷积生成对抗网络(deep convolutional generative adversarial networks, DCGAN)是一个经典的深度生成模型。其主要优点有:(1)使用代步长的卷积(判别网络)和微步卷积(生成网络)代替池化操作,以免信息损失;(2)使用批量归一化;(3)去处全连接层,加快训练速度;(4)生成网络中,除最后一层使用tanh激活函数外,其余均采用relu;(5)判别网络中,都采用LeakyReLU激活函数。其结构如下所示:

生成网络结构 判别网络结构

假设真实分布p_{r}(\mathbf{x})和生成模型p_{\theta}(\mathbf{x})已知,则最优判别器为D^{\star}(\mathbf{x})=\frac{p_{r}(\mathbf{x})}{p_{r}(\mathbf{x})+p_{\theta}(\mathbf{x})}
代入目标函数,则有:
\begin{aligned} \mathcal{L}\left(G | D^{\star}\right) &=\mathbb{E}_{\mathbf{x} \sim p_{r}(\mathbf{x})}\left[\log D^{\star}(\mathbf{x})\right]+\mathbb{E}_{\mathbf{x} \sim p_{\theta}(\mathbf{x})}\left[\log \left(1-D^{\star}(\mathbf{x})\right)\right] \\ &=\mathbb{E}_{\mathbf{x} \sim p_{r}(\mathbf{x})}\left[\log \frac{p_{r}(\mathbf{x})}{p_{r}(\mathbf{x})+p_{\theta}(\mathbf{x})}\right]+\mathbb{E}_{\mathbf{x} \sim p_{\theta}(\mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x})}{p_{r}(\mathbf{x})+p_{\theta}(\mathbf{x})}\right] \\ &=D_{\mathrm{KL}}\left(p_{r} \| p_{a}\right)+D_{\mathrm{KL}}\left(p_{\theta} \| p_{a}\right)-2 \log 2 \\ &=2 D_{\mathrm{JS}}\left(p_{r} \| p_{\theta}\right)-2 \log 2 \end{aligned} 其中D_{\mathrm{JS}}JS散度,p_{a}(\mathbf{x})p_{r}(\mathbf{x})p_{\theta}(\mathbf{x})的平均。
在生成对抗网络中,当判断网络为最优时,生成网络的优化目标是最小化真实分布p_{r}(\mathbf{x})和生成模型p_{\theta}(\mathbf{x})之间的JS散度。然而,JS散度的一个问题是:当两个分布没有重叠时,它们之间的JS散度恒等于常数log2。对生成网络来说,目标函数关于参数的梯度为0,即梯度消失。

生成对抗网络中的梯度消失
因此,在实际训练生成对抗网络时,我们一般不会将判别网络训练到最优,只进行一步或多步梯度下降,使得生成网络的梯度依然存在。然而,判别网络 也不能太差,否则生成网络的梯度为错误的梯度。如何使得判别网络在梯度消 失和梯度错误之间取得平衡并不是一件容易的事。

三、算法实现

本例利用DCGAN算法根据CIFAR-10数据集生成相似图片。
1、加载数据集。CIFAR-10数据集之前已介绍过,在此不做详述。

import torch
import torchvision
import torchvision.transforms as transforms

dataset = torchvision.datasets.CIFAR10(root='Data', train=True, transform = transforms.ToTensor(), download = True)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size = 64, shuffle= True)

for idx,data in enumerate(dataloader):
    imgs,_ = data
    batch_size = imgs.size(0)
    print('#{} has {} images'.format(idx,batch_size))
    
    if idx % 100 == 0:
        path = 'Data/CIFAR10_batch{:03d}.png'.format(idx)
        torchvision.utils.save_image(imgs,path,normalize=True)
CIFAR-10原始图片

2、生成网络和判别网络的搭建。利用pytorch框架根据前述网络结构搭建出两个网络。

import torch.nn as nn
import torch.nn.init as init
import torch.optim

latent_size = 64
n_channel = 3
n_g_feature = 64

gnet = nn.Sequential(
        nn.ConvTranspose2d(latent_size,4*n_g_feature,kernel_size=4,bias=False),
        nn.BatchNorm2d(4*n_g_feature),
        nn.ReLU(),
        nn.ConvTranspose2d(4*n_g_feature,2*n_g_feature,kernel_size=4,stride=2,padding=1,bias=False),
        nn.BatchNorm2d(2*n_g_feature),
        nn.ReLU(),
        nn.ConvTranspose2d(2*n_g_feature,n_g_feature,kernel_size=4,stride=2,padding=1,bias=False),
        nn.BatchNorm2d(n_g_feature),
        nn.ReLU(),
        nn.ConvTranspose2d(n_g_feature,n_channel,kernel_size=4,stride=2,padding=1),
        nn.Tanh(),
        )

n_d_feature = 64
dnet = nn.Sequential(
        nn.Conv2d(n_channel,n_d_feature,kernel_size=4,stride=2,padding=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(n_d_feature,2*n_d_feature,kernel_size=4,stride=2,padding=1,bias=False),
        nn.BatchNorm2d(2*n_d_feature),
        nn.LeakyReLU(0.2),
        nn.Conv2d(2*n_d_feature,4*n_d_feature,kernel_size=4,stride=2,padding=1,bias=False),
        nn.BatchNorm2d(4*n_d_feature),
        nn.LeakyReLU(0.2),
        nn.Conv2d(4*n_d_feature,1,kernel_size=4),
        )

3、网络及训练参数初始化。

def weight_init(m):
    if type(m) in [nn.ConvTranspose2d,nn.Conv2d]:
        init.xavier_normal_(m.weight)
    elif type(m) == nn.BatchNorm2d:
        init.normal_(m.weight,1.0,0.02)
        init.constant_(m.bias,0)

gnet.apply(weight_init)
dnet.apply(weight_init)


criterion = nn.BCEWithLogitsLoss()
goptimizer = torch.optim.Adam(gnet.parameters(),lr=0.0002,betas=(0.5,0.999))
doptimizer = torch.optim.Adam(dnet.parameters(),lr=0.0002,betas=(0.5,0.999))

batch_size = 64
noises = torch.randn(batch_size,latent_size,1,1)
epoch_num = 10

4、判别网络与生成网络的训练。对于每一批数据,先对判别网络进行训练,再对生成网络进行训练。训练判别网络时,首先考虑真实数据带来的损失,再考虑虚假数据带来的损失,然后根据这两者带来的损失,一次性对判别网络进行更新。值得一提的是,计算虚假数据损失时采用了张量成员方法detach(),终止梯度反向传播,加快训练速度,这是因为在训练判别网络时,没有必要求解损失对于生成函数权值的梯度。

for epoch in range(epoch_num):
    for idx,data in enumerate(dataloader):
        imgs,_ = data
        batch_size = imgs.size(0)
        
        #train discriminative network
        labels = torch.ones(batch_size)
        preds = dnet(imgs)
        outputs = preds.reshape(-1)
        d_loss_real = criterion(outputs,labels)
        d_mean_real = outputs.sigmoid().mean()
        
        noises = torch.randn(batch_size,latent_size,1,1)
        fake_imgs = gnet(noises)
        labels = torch.zeros(batch_size)
        fake = fake_imgs.detach()
        preds = dnet(fake)
        outputs = preds.view(-1)
        d_loss_fake = criterion(outputs,labels)
        d_mean_fake = outputs.sigmoid().mean()
        
        d_loss = d_loss_real + d_loss_fake
        dnet.zero_grad()
        d_loss.backward()
        doptimizer.step()
        
        
        #train generative network
        labels = torch.ones(batch_size)
        preds = dnet(fake_imgs)
        outputs = preds.view(-1)
        g_loss = criterion(outputs,labels)
        g_mean = outputs.sigmoid().mean()
        
        gnet.zero_grad()
        g_loss.backward()
        goptimizer.step()
        
        print('[{}/{}]'.format(epoch,epoch_num)+'[{}/{}]'.format(idx,len(dataloader)))
        print('d_loss:{:g},g_loss:{:g}'.format(d_loss,g_loss))
        print('TPR:{:g},FPR:{:g}/{:g}'.format(d_mean_real,d_mean_fake,g_mean))
        
        if idx % 100 == 0:
            fake = gnet(noises)
            path = 'Data/images_epoch{:02d}_batch{:03d}.png'.format(epoch,idx)
            torchvision.utils.save_image(fake,path,normalize=True)

以下为不同阶段的图片生成情况,可见随着循环次数增加,生成图片不断接近于真实图片。


原始噪音 循环1次第300张图片时 循环4次第200张图片时 循环7次第700张图片时 循环14次第500张图片时 循环19次第300张图片时

参考资料

[1] Vishnu Subramanian. Deep Learning with PyTorch. Packet Publishing. 2018.
[2] 邱锡鹏 著,神经网络与深度学习. https://nndl.github.io/ 2019.
[3] 肖智清 著,神经网络与PyTorch实战. 北京:机械工业出版社. 2018.
[4] 唐进民 编著,深度学习之PyTorch实战计算机视觉. 北京:电子工业出版社. 2018.
[5] Ian Goodfellow 等 著, 赵申剑等 译, 深度学习. 北京:人民邮电出版社, 2017.

洛阳城里见秋风,欲作家书意万重。——张籍《秋思》

相关文章

网友评论

      本文标题:深度学习(五):生成对抗网络

      本文链接:https://www.haomeiwen.com/subject/rpduyctx.html