1. Introduction
https://github.com/naiq/PN_GAN is the implementation code for the paper Pose-Normalized Image Generation for Person Re-identification.
2. Code Walkthrough
2.1 Network construction
2.1.1 Generator
Parameters: ngf (base channel width, i.e. the number of feature maps) and num_resblock (number of residual blocks, 9 by default).
Concatenate the source image and the pose map along the channel dimension (3 + 3 = 6 channels, matching conv1's in_channels):
x = torch.cat((im, pose), dim=1)
Downsampling path (convolutions):
self.conv1 = nn.Sequential(OrderedDict([
('pad', nn.ReflectionPad2d(3)),
('conv', nn.Conv2d(6, ngf, kernel_size=7, stride=1, padding=0, bias=True)),
('bn', nn.InstanceNorm2d(ngf)),
('relu', nn.ReLU(inplace=True)),
]))
self.conv2 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1, bias=True)),
('bn', nn.InstanceNorm2d(ngf*2)),
('relu', nn.ReLU(inplace=True)),
]))
self.conv3 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1, bias=True)),
('bn', nn.InstanceNorm2d(ngf*4)),
('relu', nn.ReLU(inplace=True)),
]))
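Each stride-2 convolution halves the spatial resolution. As a quick sanity check, a pure-Python sketch of the shapes through conv1–conv3 (the 256×128 input size is an assumption for illustration; any H and W divisible by 4 work):

```python
def conv_out(size, k, s, p):
    # standard conv output-size formula: floor((size + 2*p - k) / s) + 1
    return (size + 2 * p - k) // s + 1

h, w = 256, 128  # assumed input size (hypothetical)
# conv1: ReflectionPad2d(3) then 7x7 stride-1 conv -> resolution preserved
h1, w1 = conv_out(h + 6, 7, 1, 0), conv_out(w + 6, 7, 1, 0)
# conv2 and conv3: 3x3, stride 2, padding 1 -> each halves H and W
h2, w2 = conv_out(h1, 3, 2, 1), conv_out(w1, 3, 2, 1)
h3, w3 = conv_out(h2, 3, 2, 1), conv_out(w2, 3, 2, 1)
print((h1, w1), (h2, w2), (h3, w3))  # (256, 128) (128, 64) (64, 32)
```

So after the downsampling path the feature map has ngf*4 channels at 1/4 resolution, which is what the residual blocks operate on.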
Residual part: num_resblock ResBlocks stacked in sequence; each ResBlock looks like:
# ncf is ngf*4
...
self.conv1 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ncf, ncf, kernel_size=3, stride=1, padding=1, bias=use_bias)),
('bn', nn.InstanceNorm2d(ncf)),
('relu', nn.ReLU(inplace=True)),
]))
self.conv2 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ncf, ncf, kernel_size=3, stride=1, padding=1, bias=use_bias)),
('bn', nn.InstanceNorm2d(ncf)),
]))
....
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = out + x
out = self.relu(out)
return out
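The forward pass adds the input back onto the two-convolution branch before the final ReLU. A minimal numeric sketch of this out = relu(F(x) + x) pattern (the scalar function f below is a hypothetical stand-in for the conv1→conv2 branch):

```python
def resblock(x, f):
    # residual pattern: branch output plus skip connection, then ReLU
    return [max(0.0, f(v) + v) for v in x]

x = [1.0, -2.0, 0.5]
out = resblock(x, lambda v: 0.1 * v)
```

Because every convolution here is 3x3 with stride 1 and padding 1, the block preserves both resolution and channel count, so any number of blocks can be stacked.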
Upsampling path (transposed convolutions):
self.deconv3 = nn.Sequential(OrderedDict([
('deconv', nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)),
('bn', nn.InstanceNorm2d(ngf*2)),
('relu', nn.ReLU(True))
]))
self.deconv2 = nn.Sequential(OrderedDict([
('deconv', nn.ConvTranspose2d(ngf*2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)),
('bn', nn.InstanceNorm2d(ngf)),
('relu', nn.ReLU(True))
]))
self.deconv1 = nn.Sequential(OrderedDict([
('pad', nn.ReflectionPad2d(3)),
('conv', nn.Conv2d(ngf, 3, kernel_size=7, stride=1, padding=0, bias=False)),
('tanh', nn.Tanh())
]))
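The two transposed convolutions exactly undo the two stride-2 convolutions of the downsampling path. A sketch of the transposed-conv output-size formula confirms this (starting from the assumed 64×32 bottleneck of the earlier example):

```python
def deconv_out(size, k, s, p, op):
    # transposed-conv output-size formula: (size - 1)*s - 2*p + k + output_padding
    return (size - 1) * s - 2 * p + k + op

h, w = 64, 32  # assumed bottleneck resolution (hypothetical)
h, w = deconv_out(h, 3, 2, 1, 1), deconv_out(w, 3, 2, 1, 1)  # deconv3
h, w = deconv_out(h, 3, 2, 1, 1), deconv_out(w, 3, 2, 1, 1)  # deconv2
print(h, w)  # 256 128
```

deconv1 then maps back to 3 channels, and the tanh squashes the output image into [-1, 1].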
2.1.2 Discriminator (Patch_Discriminator)
Parameter: ndf (base channel width, i.e. the number of feature maps).
Downsampling path:
self.conv1 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False)),
('relu', nn.LeakyReLU(0.2, True))
]))
self.conv2 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1, bias=True)),
('bn', nn.InstanceNorm2d(ndf*2)),
('relu', nn.LeakyReLU(0.2, True))
]))
self.conv3 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1, bias=True)),
('bn', nn.InstanceNorm2d(ndf*4)),
('relu', nn.LeakyReLU(0.2, True))
]))
self.conv4 = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=1, padding=0, bias=True)),
('bn', nn.InstanceNorm2d(ndf*8)),
('relu', nn.LeakyReLU(0.2, True))
]))
The dis layer, i.e. the final convolution:
self.dis = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=0, bias=False)),
]))
dis.squeeze() is the network's final output: dis has a single channel, so after squeeze the output has shape batch_size × h × w.
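Because the output is an h×w map of scores rather than a single scalar, each score judges one patch of the input, hence the "Patch" in Patch_Discriminator. A sketch of the output resolution through all five convolutions (the 256×128 input size is an assumption carried over for illustration):

```python
def conv_out(size, k, s, p):
    # standard conv output-size formula: floor((size + 2*p - k) / s) + 1
    return (size + 2 * p - k) // s + 1

h, w = 256, 128  # assumed input size (hypothetical)
# (k, s, p) for conv1..conv4 and the final dis layer
for k, s, p in [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0), (4, 1, 0)]:
    h, w = conv_out(h, k, s, p), conv_out(w, k, s, p)
print(h, w)  # 26 10
```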
2.2 Optimizers
The generator's optimizer is Adam with lr=cfg.TRAIN.LR (0.0002 by default) and betas=(0.5, 0.999).
The discriminator uses the same optimizer settings.
Both the generator and the discriminator use the same learning-rate schedule:
lr_policy = lambda epoch: (1 - 1 * max(0, epoch-cfg.TRAIN.LR_DECAY) / cfg.TRAIN.LR_DECAY)
i.e. the learning rate at a given epoch is the base lr * (1 - 1 * max(0, epoch - cfg.TRAIN.LR_DECAY) / cfg.TRAIN.LR_DECAY): constant for the first LR_DECAY epochs, then decaying linearly to 0 over the next LR_DECAY epochs.
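A pure-Python sketch of the schedule (the LR_DECAY value of 100 is hypothetical):

```python
LR_DECAY = 100                      # stands in for cfg.TRAIN.LR_DECAY (hypothetical)
base_lr = 2e-4                      # default cfg.TRAIN.LR
lr_policy = lambda epoch: 1 - 1 * max(0, epoch - LR_DECAY) / LR_DECAY

# constant for the first LR_DECAY epochs, then linear decay to 0 at 2*LR_DECAY
lrs = [base_lr * lr_policy(e) for e in (0, 100, 150, 200)]
print(lrs)  # [0.0002, 0.0002, 0.0001, 0.0]
```

In PyTorch such a lambda is typically attached with torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_policy).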
2.3 Training
Two loss functions are used: torch.nn.MSELoss(), denoted criterionGAN (the adversarial loss, used for both the generator G and the discriminator D), and torch.nn.L1Loss(), denoted criterionIdt (the identity loss on the generator's output).
2.3.1 Data processing
After a series of preprocessing steps we obtain src_img (the source image), tgt_img (the target image), and pose (the pose map).
2.3.2 Image generation
From the source image src_img and the pose map pose, generate fake_img, a new image of the person in src_img with their pose changed to pose:
fake_img = netG(src_img, pose)
2.3.3 Updating the generator
D_fake_img = netD(fake_img)
G_loss = criterionGAN(D_fake_img, torch.ones_like(D_fake_img))
idt_loss = criterionIdt(fake_img, tgt_img) * cfg.TRAIN.lambda_idt
fake_img is passed through the discriminator D to get D_fake_img, and then Lgan and Ll1 (matching the paper) are computed.
G_loss corresponds to Lgan:
[figure: the paper's L_gan adversarial-loss formula]
G_loss is computed with MSE, which differs slightly from the paper's formulation, but the intent is the same: the generator's ideal outcome is for the generated image to be classified as real, i.e. the discriminator output should be torch.ones_like(D_fake_img).
idt_loss corresponds to:
[figure: the paper's L_1 identity-loss formula]
This is computed with an L1 loss, consistent with the paper.
The generator's overall objective is:
[figure: the paper's generator objective, L_gan + lambda_1 * L_1]
The corresponding code is:
loss_G = G_loss + idt_loss
Note that lambda_1 in the formula is cfg.TRAIN.lambda_idt.
Finally, update the network:
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
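A minimal numeric sketch of how loss_G is assembled, assuming criterionGAN is MSE and criterionIdt is L1 as above (all values below, including lambda_idt, are hypothetical):

```python
def mse(pred, target):
    # mean squared error, as in nn.MSELoss with default mean reduction
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def l1(pred, target):
    # mean absolute error, as in nn.L1Loss with default mean reduction
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)

lambda_idt = 10.0                           # stands in for cfg.TRAIN.lambda_idt
D_fake_img = [0.3, 0.1]                     # hypothetical patch scores for fake_img
fake_px, tgt_px = [0.2, 0.4], [0.25, 0.35]  # hypothetical pixel values

G_loss = mse(D_fake_img, [1.0, 1.0])        # push fake patch scores toward "real" (1)
idt_loss = l1(fake_px, tgt_px) * lambda_idt
loss_G = G_loss + idt_loss
```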
2.3.4 Updating the discriminator
First, compute the loss:
D_fake_img = netD(fake_img.detach())
D_real_img = netD(src_img)
D_fake_loss = criterionGAN(D_fake_img, torch.zeros_like(D_fake_img))
D_real_loss = criterionGAN(D_real_img, torch.ones_like(D_real_img))
loss_D = D_fake_loss + D_real_loss
D_fake_img is the discriminator's output for the generated image, and D_real_img is its output for the source image. In the paper, the discriminator's loss is:
[figure: the paper's discriminator loss L_d]
In the implementation this is simply the classification error: the mean squared error of D_fake_img against its ideal value (all zeros) plus that of D_real_img against its ideal value (all ones).
Finally, update the network:
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
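The same MSE criterion, sketched numerically for the discriminator side (the patch scores below are hypothetical):

```python
def mse(pred, target):
    # mean squared error, as in nn.MSELoss with default mean reduction
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

D_fake = [0.3, 0.1]   # hypothetical patch scores for the generated image
D_real = [0.9, 0.8]   # hypothetical patch scores for the source image

D_fake_loss = mse(D_fake, [0.0, 0.0])  # fakes should score 0
D_real_loss = mse(D_real, [1.0, 1.0])  # reals should score 1
loss_D = D_fake_loss + D_real_loss
```

Note the fake_img.detach() in the real code: it stops gradients from the discriminator update from flowing back into the generator.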