
Pose-Transfer Code Reading Notes

Author: hdychi | Published 2019-09-26 14:24

    1. Introduction

    The Pose-Transfer code read here is the Pytorch_v1.0 branch of https://github.com/tengteng95/Pose-Transfer, which targets PyTorch 1.x. The walkthrough below follows the flow you get with the run parameters given in the ReadMe. It is the code for the CVPR 2019 (Oral) paper Progressive Pose Attention Transfer for Person Image Generation.

    2. Network structure

    The entry point for the network-related code is class TransferModel in models/PATN.py.

    2.1 The generator netG

    The code that defines the generator netG is:

    
     netG = PATNetwork(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                                               n_blocks=9, gpu_ids=gpu_ids, n_downsampling=n_downsampling)
    

    Parameters: input_nc (number of input channels), output_nc (number of output channels), ngf (base channel count, roughly the number of feature maps; the generator's channel widths are multiples of it), n_blocks (number of PATBs), n_downsampling (number of downsampling convolutions).
    The network-structure code is shown below.
    First, look at PATNetwork's forward pass:

    def forward(self, input): # x from stream 1 and stream 2
            # here x should be a tuple
            x1, x2 = input
            # down_sample
            x1 = self.stream1_down(x1)
            x2 = self.stream2_down(x2)
            # att_block
            for model in self.att:
                x1, x2, _ = model(x1, x2)
    
            # up_sample
            x1 = self.stream1_up(x1)
    
            return x1
    

    So the generator roughly splits into a downsampling part, an att_block (PATB) part, and an upsampling part. The final output of the upper branch of the att blocks is upsampled to produce the generated image.
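    Before diving into the three parts, here is a minimal sketch of how the generator is fed (the shapes assume the DeepFashion 256x176 setting with P_input_nc=3 and BP_input_nc=18; those values come from the repo's default options and are assumptions here, not shown in the snippets above):

    import torch

    P1  = torch.randn(1, 3, 256, 176)    # source person image Pc  -> stream 1
    BP1 = torch.randn(1, 18, 256, 176)   # source pose heatmaps
    BP2 = torch.randn(1, 18, 256, 176)   # target pose heatmaps

    G_input = [P1, torch.cat((BP1, BP2), 1)]   # stream 2 gets the 36-channel pose pair
    # fake_p2 = netG(G_input)                  # -> generated image with the same spatial size as P1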

    2.1.1 Downsampling part

    1. First, a reflection-padding + 7x7 convolution block:

    model_stream1_down = [nn.ReflectionPad2d(3),
                        nn.Conv2d(self.input_nc_s1, ngf, kernel_size=7, padding=0,
                               bias=use_bias),
                        norm_layer(ngf),
                        nn.ReLU(True)]
    
     model_stream2_down = [nn.ReflectionPad2d(3),
                        nn.Conv2d(self.input_nc_s2, ngf, kernel_size=7, padding=0,
                               bias=use_bias),
                        norm_layer(ngf),
                        nn.ReLU(True)]
    

    2. Then n_downsampling strided (downsampling) convolution layers:

            for i in range(n_downsampling):
                mult = 2**i
                model_stream1_down += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                    stride=2, padding=1, bias=use_bias),
                                norm_layer(ngf * mult * 2),
                                nn.ReLU(True)]
                model_stream2_down += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                    stride=2, padding=1, bias=use_bias),
                                norm_layer(ngf * mult * 2),
                                nn.ReLU(True)]
    

    3. Chain the layers together and assign them:

            self.stream1_down = nn.Sequential(*model_stream1_down)
            self.stream2_down = nn.Sequential(*model_stream2_down)
    

    2.1.2 The att block part, i.e., the PATB (Pose-Attentional Transfer Block) from the paper.

    Here is the figure from the paper:


    [Figure: the Pose-Transfer (PATN) network architecture (pose-transfer网络结构.png)]

    In what follows, the first branch of a PATB means the upper branch in the figure, and the second branch means the lower one.

        mult = 2**n_downsampling
        cated_stream2 = [True for i in range(n_blocks)]
        cated_stream2[0] = False
        attBlock = nn.ModuleList()
        for i in range(n_blocks):
            attBlock.append(PATBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                     use_dropout=use_dropout, use_bias=use_bias, cated_stream2=cated_stream2[i]))
    

    So there are n_blocks PATB blocks. A single PATB is built as follows.
    First the variable setup; conv_block is the list that collects the layers:

     conv_block = []
     p = 0
    

    Its forward function is:

        def forward(self, x1, x2):
            x1_out = self.conv_block_stream1(x1)
            x2_out = self.conv_block_stream2(x2)
            # att = F.sigmoid(x2_out)
            att = torch.sigmoid(x2_out)
    
            x1_out = x1_out * att
            out = x1 + x1_out # residual connection
    
            # stream2 receive feedback from stream1
            x2_out = torch.cat((x2_out, out), 1)
            return out, x2_out, x1_out
    

    So a PATB takes two inputs and produces two outputs for the next block; this is easier to follow alongside the figure from the paper. In the generator's overall forward function, out and x2_out are what get passed on to the next block.
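    In equation form (my own notation, matching the forward code above; C1 and C2 stand for conv_block_stream1 and conv_block_stream2, sigma is the sigmoid, \odot is element-wise multiplication, and [.,.] is channel-wise concatenation):

    \begin{aligned}
    A &= \sigma\bigl(C_2(x_2)\bigr) \\
    \mathrm{out} &= x_1 + C_1(x_1) \odot A \\
    x_2^{\mathrm{next}} &= \bigl[\, C_2(x_2),\ \mathrm{out} \,\bigr]
    \end{aligned}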
    conv_block_stream1 and conv_block_stream2 are constructed by:

            self.conv_block_stream1 = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, cal_att=False)
            self.conv_block_stream2 = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, cal_att=True, cated_stream2=cated_stream2)
    

    The concrete structure of conv_block_stream1 and conv_block_stream2 (built in build_conv_block) is:
    1. Padding

            if padding_type == 'reflect':
                conv_block += [nn.ReflectionPad2d(1)]
            elif padding_type == 'replicate':
                conv_block += [nn.ReplicationPad2d(1)]
            elif padding_type == 'zero':
                p = 1
    

    2. Convolution + normalization + ReLU

            if cated_stream2:
                conv_block += [nn.Conv2d(dim*2, dim*2, kernel_size=3, padding=p, bias=use_bias),
                           norm_layer(dim*2),
                           nn.ReLU(True)]
            else:
                conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                               norm_layer(dim),
                               nn.ReLU(True)]
    

    Here cated_stream2 is always False for the first branch; for the second branch it is False in the first PATB and True from the second PATB onward, because the final output of a PATB's second branch is the concatenation of the second branch's features with the first branch's output (x2_out = torch.cat((x2_out, out), 1)), which doubles its channel count.
    3. Dropout layer (optional)

            if use_dropout:
                conv_block += [nn.Dropout(0.5)]
    

    4. Padding again
    The code is the same as the first padding step.
    5. Convolution layer

            if cal_att:
                if cated_stream2:
                    conv_block += [nn.Conv2d(dim*2, dim, kernel_size=3, padding=p, bias=use_bias)]
                else:
                    conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
            else:
                conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                           norm_layer(dim)]
    

    cal_att is False for the first branch and True for the second branch. The first branch keeps both input and output channels at dim, while the second branch (once cated_stream2 is True) has to map its dim*2 input channels down to dim output channels, so that the resulting attention map can be applied element-wise to the first branch's output. A quick shape check follows below.
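    A quick shape check of one of the later PATBs (mine, mirroring the forward code above; dim = 256 and the 64x44 spatial size assume ngf=64, n_downsampling=2 and 256x176 inputs, which are assumptions taken from the default options):

    import torch

    dim = 256
    x1 = torch.randn(1, dim, 64, 44)         # first-branch features
    x2 = torch.randn(1, dim * 2, 64, 44)     # second-branch features (already concatenated)

    # stand-ins for the outputs of conv_block_stream1 / conv_block_stream2:
    x1_out = torch.randn(1, dim, 64, 44)     # dim -> dim
    x2_out = torch.randn(1, dim, 64, 44)     # 2*dim -> dim, so the attention matches x1_out

    att = torch.sigmoid(x2_out)
    out = x1 + x1_out * att                  # residual connection, still dim channels
    x2_next = torch.cat((x2_out, out), 1)    # back to 2*dim channels for the next PATB
    print(out.shape, x2_next.shape)          # torch.Size([1, 256, 64, 44]) torch.Size([1, 512, 64, 44])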

    2.1.3 Upsampling part

            model_stream1_up = []
            for i in range(n_downsampling):
                mult = 2**(n_downsampling - i)
                model_stream1_up += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                             kernel_size=3, stride=2,
                                             padding=1, output_padding=1,
                                             bias=use_bias),
                                norm_layer(int(ngf * mult / 2)),
                                nn.ReLU(True)]
    

    In short, the upsampling consists of n_downsampling transposed-convolution (ConvTranspose2d) layers that halve the channel count at each step.
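    As a quick check of the channel bookkeeping in the three parts (ngf=64 and n_downsampling=2 are assumed defaults here):

    ngf, n_downsampling = 64, 2

    down_channels = [ngf * (2 ** i) * 2 for i in range(n_downsampling)]                   # [128, 256]
    patb_width    = ngf * 2 ** n_downsampling                                             # 256, width inside the PATBs
    up_channels   = [ngf * 2 ** (n_downsampling - i) // 2 for i in range(n_downsampling)] # [128, 64]

    print(down_channels, patb_width, up_channels)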

    2.2 Discriminator networks

    There are two discriminators, netD_PB and netD_PP. netD_PB judges how well the generated image Pg aligns with the target pose St (in the paper's words: "how well Pg align with the target pose St (shape consistency)"), and netD_PP judges how likely Pg contains the same person as the input image Pc ("judge how likely Pg contains the same person in Pc (appearance consistency)").
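    Concretely, each discriminator sees a channel-wise concatenation; the torch.cat calls below are taken from backward_D_PB / backward_D_PP further down, and the channel counts assume P_input_nc=3 and BP_input_nc=18:

    import torch

    # dummy tensors standing in for one training batch
    input_P1  = torch.randn(1, 3, 256, 176)    # source image Pc
    input_P2  = torch.randn(1, 3, 256, 176)    # target image Pt
    input_BP2 = torch.randn(1, 18, 256, 176)   # target pose St
    fake_p2   = torch.randn(1, 3, 256, 176)    # generated image Pg

    real_PB = torch.cat((input_P2, input_BP2), 1)   # 21-channel input for netD_PB
    fake_PB = torch.cat((fake_p2,  input_BP2), 1)
    real_PP = torch.cat((input_P2, input_P1), 1)    # 6-channel input for netD_PP
    fake_PP = torch.cat((fake_p2,  input_P1), 1)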
    netD_PB and netD_PP have the same structure:

                use_sigmoid = opt.no_lsgan
                if opt.with_D_PB:
                    self.netD_PB = networks.define_D(opt.P_input_nc+opt.BP_input_nc, opt.ndf,
                                                opt.which_model_netD,
                                                opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                                                not opt.no_dropout_D,
                                                n_downsampling = opt.D_n_downsampling)
    
                if opt.with_D_PP:
                    self.netD_PP = networks.define_D(opt.P_input_nc+opt.P_input_nc, opt.ndf,
                                                opt.which_model_netD,
                                                opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                                                not opt.no_dropout_D,
                                                n_downsampling = opt.D_n_downsampling)
    

    Parameters: input_nc (number of input channels), ndf (base channel count, analogous to ngf; the discriminator's channel widths are multiples of it), which_model_netD (the discriminator backbone, e.g. resnet), n_layers_D (number of residual blocks in the discriminator), norm (instance normalization or batch normalization), n_downsampling (number of downsampling convolutions).
    With the options used here, define_D returns a ResnetDiscriminator. First, its forward function:

        def forward(self, input):
            if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
                return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
            else:
                return self.model(input)
    

    It simply runs the input through self.model to get the output.
    The structure of self.model is:
    1. Padding

    model = [nn.ReflectionPad2d(3),
                     nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                               bias=use_bias),
                     norm_layer(ngf),
                     nn.ReLU(True)]
    

    2. Downsampling part

            if n_downsampling <= 2:
                for i in range(n_downsampling):
                    mult = 2 ** i
                    model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                        stride=2, padding=1, bias=use_bias),
                              norm_layer(ngf * mult * 2),
                              nn.ReLU(True)]
            elif n_downsampling == 3:
                mult = 2 ** 0
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                    stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True)]
                mult = 2 ** 1
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                    stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True)]
                mult = 2 ** 2
                model += [nn.Conv2d(ngf * mult, ngf * mult, kernel_size=3,
                                    stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult),
                          nn.ReLU(True)]
    
            if n_downsampling <= 2:
                mult = 2 ** n_downsampling
            else:
                mult = 4
    

    In other words, it pieces together n_downsampling downsampling convolutions. Note that in the n_downsampling == 3 case the third convolution keeps the channel count unchanged, which is why mult is then fixed to 4.
    3. Residual block part

            for i in range(n_blocks):
                model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
                                      use_bias=use_bias)]
    

    ResnetBlock is just the identity (residual) block from ResNet, so it is not expanded on here; a minimal sketch is given below.
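    For reference, a minimal sketch of such an identity residual block (the repo's ResnetBlock additionally handles the padding_type and dropout options; this stripped-down version is for illustration only):

    import torch.nn as nn

    class ResnetBlockSketch(nn.Module):
        def __init__(self, dim, norm_layer=nn.BatchNorm2d):
            super().__init__()
            self.conv_block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0),
                norm_layer(dim),
                nn.ReLU(True),
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0),
                norm_layer(dim),
            )

        def forward(self, x):
            return x + self.conv_block(x)   # identity shortcut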
    4. Sigmoid layer

            if use_sigmoid:
                model += [nn.Sigmoid()]
    

    3. Loss computation

    In train.py, model.optimize_parameters() is called to update the network weights; its code is:

     # forward
            self.forward()
    
            self.optimizer_G.zero_grad()
            self.backward_G()
            self.optimizer_G.step()
    
            # D_P
            if self.opt.with_D_PP:
                for i in range(self.opt.DG_ratio):
                    self.optimizer_D_PP.zero_grad()
                    self.backward_D_PP()
                    self.optimizer_D_PP.step()
    
            # D_BP
            if self.opt.with_D_PB:
                for i in range(self.opt.DG_ratio):
                    self.optimizer_D_PB.zero_grad()
                    self.backward_D_PB()
                    self.optimizer_D_PB.step()
    

    The forward function it calls is:

        def forward(self):
            G_input = [self.input_P1,
                       torch.cat((self.input_BP1, self.input_BP2), 1)]
            self.fake_p2 = self.netG(G_input)
    

    In summary, this breaks down into the following steps:

    3.1 Run the generator G forward to obtain the generated image self.fake_p2

    3.2 Update G, i.e., backpropagate through the generator

    The code of backward_G() is:

        def backward_G(self):
            if self.opt.with_D_PB:
                pred_fake_PB = self.netD_PB(torch.cat((self.fake_p2, self.input_BP2), 1))
                self.loss_G_GAN_PB = self.criterionGAN(pred_fake_PB, True)
    
            if self.opt.with_D_PP:
                pred_fake_PP = self.netD_PP(torch.cat((self.fake_p2, self.input_P1), 1))
                self.loss_G_GAN_PP = self.criterionGAN(pred_fake_PP, True)
    
            # L1 loss
            if self.opt.L1_type == 'l1_plus_perL1' :
                losses = self.criterionL1(self.fake_p2, self.input_P2)
                self.loss_G_L1 = losses[0]
                self.loss_originL1 = losses[1].item()
                self.loss_perceptual = losses[2].item()
            else:
                self.loss_G_L1 = self.criterionL1(self.fake_p2, self.input_P2) * self.opt.lambda_A
    
    
            pair_L1loss = self.loss_G_L1
            if self.opt.with_D_PB:
                pair_GANloss = self.loss_G_GAN_PB * self.opt.lambda_GAN
                if self.opt.with_D_PP:
                    pair_GANloss += self.loss_G_GAN_PP * self.opt.lambda_GAN
                    pair_GANloss = pair_GANloss / 2
            else:
                if self.opt.with_D_PP:
                    pair_GANloss = self.loss_G_GAN_PP * self.opt.lambda_GAN
    
            if self.opt.with_D_PB or self.opt.with_D_PP:
                pair_loss = pair_L1loss + pair_GANloss
            else:
                pair_loss = pair_L1loss
    
            pair_loss.backward()
    
            self.pair_L1loss = pair_L1loss.item()
            if self.opt.with_D_PB or self.opt.with_D_PP:
                self.pair_GANloss = pair_GANloss.item()
    

    In words:
    1. Compute the adversarial losses from the discriminators D_PB and D_PP (see Section 2.2 above). Since the generator wants to fool the discriminators, the target label is True. These are stored as loss_G_GAN_PB and loss_G_GAN_PP. (A sketch of what criterionGAN typically looks like is given right below.)
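    criterionGAN itself is not shown above. In pix2pix-derived codebases such as this one it is usually a small wrapper like the following sketch (an assumption for illustration, not copied from networks.py); note that use_sigmoid = opt.no_lsgan above, i.e. LSGAN (MSE) is the default:

    import torch
    import torch.nn as nn

    class GANLossSketch(nn.Module):
        def __init__(self, use_lsgan=True):
            super().__init__()
            # MSE targets for LSGAN, BCE when the discriminator ends with a sigmoid
            self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

        def forward(self, prediction, target_is_real):
            # build a target tensor of ones (real) or zeros (fake) shaped like the prediction
            target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
            return self.loss(prediction, target)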

    2. (l1_plus_perL1) Compute the l1_plus_perL1 loss between the generated image and the target image. Let's look at the code of L1_plus_perceptualLoss.
    First, the forward function of this loss module:

        def forward(self, inputs, targets):
            if self.lambda_L1 == 0 and self.lambda_perceptual == 0:
                return torch.zeros(1).cuda(), torch.zeros(1), torch.zeros(1)
            # normal L1
            loss_l1 = F.l1_loss(inputs, targets) * self.lambda_L1
    
            # perceptual L1
            mean = torch.FloatTensor(3)
            mean[0] = 0.485
            mean[1] = 0.456
            mean[2] = 0.406
            mean = mean.resize(1, 3, 1, 1).cuda()
    
            std = torch.FloatTensor(3)
            std[0] = 0.229
            std[1] = 0.224
            std[2] = 0.225
            std = std.resize(1, 3, 1, 1).cuda()
    
            fake_p2_norm = (inputs + 1)/2 # [-1, 1] => [0, 1]
            fake_p2_norm = (fake_p2_norm - mean)/std
    
            input_p2_norm = (targets + 1)/2 # [-1, 1] => [0, 1]
            input_p2_norm = (input_p2_norm - mean)/std
    
    
            fake_p2_norm = self.vgg_submodel(fake_p2_norm)
            input_p2_norm = self.vgg_submodel(input_p2_norm)
            input_p2_norm_no_grad = input_p2_norm.detach()
    
            if self.percep_is_l1 == 1:
                # use l1 for perceptual loss
                loss_perceptual = F.l1_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual
            else:
                # use l2 for perceptual loss
                loss_perceptual = F.mse_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual
    
            loss = loss_l1 + loss_perceptual
    
            return loss, loss_l1, loss_perceptual
    

    l1_plus_perL1 combines two losses. The first is an ordinary L1 loss computed directly between input and target, stored as loss_l1.
    The second, loss_perceptual, is computed as follows: (1) normalize input and target, i.e., map them from [-1, 1] to [0, 1] and then subtract the mean and divide by the standard deviation std; (2) feed the normalized input and target through a VGG sub-network, giving fake_p2_norm and input_p2_norm (the latter detached as input_p2_norm_no_grad); (3) take the L1 loss between the two feature maps to obtain loss_perceptual. (A sketch of how such a VGG sub-network can be built follows below.)
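    The construction of vgg_submodel is not shown above. A plausible sketch (an assumption for illustration; the actual cut-off layer in the repo may differ and is controlled by an option) is to take the first few feature layers of a pretrained VGG and freeze them:

    import torch.nn as nn
    from torchvision import models

    def make_vgg_submodel(num_layers=3):
        # keep only the first (num_layers + 1) layers of VGG19's feature extractor
        vgg_features = models.vgg19(pretrained=True).features
        submodel = nn.Sequential(*[vgg_features[i] for i in range(num_layers + 1)])
        for p in submodel.parameters():
            p.requires_grad = False      # the perceptual network stays fixed during training
        return submodel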
    loss_perceptual is there to make the generated image smoother and more natural; quoting the paper:


    [Figure: the L1 + perceptual loss formulation from the paper (L1_percetual.png)]

    Adding the two losses gives the final loss.
    Back in backward_G(), the three values are stored as loss_G_L1, loss_originL1, and loss_perceptual.
    3. Compute the total loss.
    The formulas from the paper:


    [Figures: the full objective (full_loss.png) and the combined L1 loss (losscombl1.png) from the paper]
    In the code above, when both discriminators are enabled, their two GAN losses are each scaled by lambda_GAN and then averaged (divided by 2), and the result is added to the combined L1 loss to give the total generator loss (a formula sketch is given below).
    Finally pair_loss.backward() is called, and the optimizer step updates G's parameters.
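    Putting the pieces of backward_G() together, the generator objective actually optimized (notation mine; phi is the VGG sub-network, Pg the generated image, Pt the target image) is:

    \mathcal{L}_G = \lambda_{L1}\,\lVert P_g - P_t \rVert_1
                  + \lambda_{per}\,\lVert \phi(P_g) - \phi(P_t) \rVert_1
                  + \frac{\lambda_{GAN}}{2}\Bigl(\mathcal{L}_{GAN}^{PB} + \mathcal{L}_{GAN}^{PP}\Bigr)

    where the two GAN terms are criterionGAN(netD_PB(cat(Pg, St)), True) and criterionGAN(netD_PP(cat(Pg, Pc)), True).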

    3.3 Update the two discriminators (see Section 2.2 above)

    After each update of G as above, the discriminators D_PP and D_PB are each updated DG_ratio times.

    3.3.1 Updating D_PP

            if self.opt.with_D_PP:
                for i in range(self.opt.DG_ratio):
                    self.optimizer_D_PP.zero_grad()
                    self.backward_D_PP()
                    self.optimizer_D_PP.step()
    
        def backward_D_PP(self):
            real_PP = torch.cat((self.input_P2, self.input_P1), 1)
            # fake_PP = self.fake_PP_pool.query(torch.cat((self.fake_p2, self.input_P1), 1))
            fake_PP = self.fake_PP_pool.query( torch.cat((self.fake_p2, self.input_P1), 1).data )
            loss_D_PP = self.backward_D_basic(self.netD_PP, real_PP, fake_PP)
            self.loss_D_PP = loss_D_PP.item()

    def backward_D_basic(self, netD, real, fake):
            # Real
            pred_real = netD(real)
            loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN
            # Fake
            pred_fake = netD(fake.detach())
            loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN
            # Combined loss
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            # backward
            loss_D.backward()
            return loss_D
    

    The steps, summarized:
    1. Concatenate the generated image fake_p2 with the source image input_P1 and pass the result (its .data) to fake_PP_pool.query. The query function is:

        def query(self, images):
            if self.pool_size == 0:
                return Variable(images)
            return_images = []
            for image in images:
                image = torch.unsqueeze(image, 0)
                if self.num_imgs < self.pool_size:
                    self.num_imgs = self.num_imgs + 1
                    self.images.append(image)
                    return_images.append(image)
                else:
                    p = random.uniform(0, 1)
                    if p > 0.5:
                        random_id = random.randint(0, self.pool_size-1)
                        tmp = self.images[random_id].clone()
                        self.images[random_id] = image
                        return_images.append(tmp)
                    else:
                        return_images.append(image)
            return_images = Variable(torch.cat(return_images, 0))
            return return_images
    

    I was not sure what this was for at first: it is the image-buffer (history pool) trick familiar from pix2pix/CycleGAN training. With the default configuration, while fewer than 50 images have been stored the pool simply keeps the incoming concatenated (fake_p2, input_P1) tensor and returns it; once the pool is full, each incoming image is, with probability 0.5, swapped with a randomly chosen previously stored one, so the discriminator is sometimes trained on older fakes.
    2. Concatenate the target image and the source image to get real_PP. Run D_PP on real_PP and compute the loss against the ideal label (True), stored as loss_D_real; run D_PP on fake_PP and compute the loss against the ideal label (False), stored as loss_D_fake. Remember that D_PP is meant to judge whether two images contain the same person.
    3. The total loss is loss_D = (loss_D_real + loss_D_fake) * 0.5, and loss_D.backward() propagates the gradients so the optimizer step can update the parameters (a formula sketch follows below).
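    In formula form (notation mine), each discriminator update performed by backward_D_basic minimizes

    \mathcal{L}_D = \frac{\lambda_{GAN}}{2}\Bigl(\ell\bigl(D(x_{\mathrm{real}}),\,1\bigr) + \ell\bigl(D(x_{\mathrm{fake}}),\,0\bigr)\Bigr)

    where \ell is the MSE loss under LSGAN (the default here) or BCE otherwise, x_real is real_PP / real_PB, and x_fake is drawn from the image pool.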

    3.3.2 Updating D_PB

    def backward_D_PB(self):
            real_PB = torch.cat((self.input_P2, self.input_BP2), 1)
            # fake_PB = self.fake_PB_pool.query(torch.cat((self.fake_p2, self.input_BP2), 1))
            fake_PB = self.fake_PB_pool.query( torch.cat((self.fake_p2, self.input_BP2), 1).data )
            loss_D_PB = self.backward_D_basic(self.netD_PB, real_PB, fake_PB)
            self.loss_D_PB = loss_D_PB.item()
    

    This mirrors D_PP, except that real_PB concatenates the target image with the target pose, and fake_PB concatenates the generated image with the target pose. Remember that D_PB judges whether the person in the image is in the target pose.

    4. Summary

    This post mainly covered the training logic. Once the training code is understood, the test code is much easier to follow, so it is not analyzed here.
