美文网首页Python_图像处理
Python将头像照片转换为漫画,采用GAN深度学习,无噪点

Python将头像照片转换为漫画,采用GAN深度学习,无噪点

作者: 程序员小西 | 来源:发表于2022-04-07 16:56 被阅读0次

    传统的照片转漫画,使用边缘检测、双边滤波器和降采样,得到图像如下,可以看到,噪点很多,有些关键线条也没有展现出来。

    本次采用GAN,GAN网络使用的方法是根据图像对去不断地学习,如输入图像1和对应已有的漫画B,GAN网络从图片1中获取关键特征,不停地生成一张图像C,当C与B的差值很小时停止,当有很多这样地图像对时,我们就有了一个模型。输入一张图像,就可以生成一张对应地漫画图像,我这次使用的GAN(White-box Cartoon)生成。生成效果:

    图片.png 图片.png

    原始图片大小建议为256*256像素

    完整程序代码

    
    import os
    import cv2
    import torch
    import numpy as np
    import torch.nn as nn
    
    class ResBlock(nn.Module):
        def __init__(self, num_channel):
            super(ResBlock, self).__init__()
            self.conv_layer = nn.Sequential(
            nn.Conv2d(num_channel, num_channel, 3, 1, 1),
            nn.BatchNorm2d(num_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_channel, num_channel, 3, 1, 1),
            nn.BatchNorm2d(num_channel))
            self.activation = nn.ReLU(inplace=True)
            
            def forward(self, inputs):
                output = self.conv_layer(inputs)
                output = self.activation(output + inputs)
                return output
                
                
                class DownBlock(nn.Module):
                    def __init__(self, in_channel, out_channel):
                        super(DownBlock, self).__init__()
                        self.conv_layer = nn.Sequential(
                        nn.Conv2d(in_channel, out_channel, 3, 2, 1),
                        nn.BatchNorm2d(out_channel),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(out_channel, out_channel, 3, 1, 1),
                        nn.BatchNorm2d(out_channel),
                        nn.ReLU(inplace=True))
                        
                        
                        def forward(self, inputs):
                            output = self.conv_layer(inputs)
                            return output
                            
                            
                            class UpBlock(nn.Module):
                                def __init__(self, in_channel, out_channel, is_last=False):
                                    super(UpBlock, self).__init__()
                                    self.is_last = is_last
                                    self.conv_layer = nn.Sequential(
                                    nn.Conv2d(in_channel, in_channel, 3, 1, 1),
                                    nn.BatchNorm2d(in_channel),
                                    nn.ReLU(inplace=True),
                                    nn.Upsample(scale_factor=2),
                                    nn.Conv2d(in_channel, out_channel, 3, 1, 1))
                                    self.act = nn.Sequential(
                                    nn.BatchNorm2d(out_channel),
                                    nn.ReLU(inplace=True))
                                    self.last_act = nn.Tanh()
                                    
                                    
                                    def forward(self, inputs):
                                        output = self.conv_layer(inputs)
                                        if self.is_last:
                                            output = self.last_act(output)
                                        else:
                                            output = self.act(output)
                                            return output
                                            
                                            
                                            
                                            class SimpleGenerator(nn.Module):
                                                def __init__(self, num_channel=32, num_blocks=4):
                                                    super(SimpleGenerator, self).__init__()
                                                    self.down1 = DownBlock(3, num_channel)
                                                    self.down2 = DownBlock(num_channel, num_channel*2)
                                                    self.down3 = DownBlock(num_channel*2, num_channel*3)
                                                    self.down4 = DownBlock(num_channel*3, num_channel*4)
                                                    res_blocks = [ResBlock(num_channel*4)]*num_blocks
                                                    self.res_blocks = nn.Sequential(*res_blocks)
                                                    self.up1 = UpBlock(num_channel*4, num_channel*3)
                                                    self.up2 = UpBlock(num_channel*3, num_channel*2)
                                                    self.up3 = UpBlock(num_channel*2, num_channel)
                                                    self.up4 = UpBlock(num_channel, 3, is_last=True)
                                                    
                                                    def forward(self, inputs):
                                                        down1 = self.down1(inputs)
                                                        down2 = self.down2(down1)
                                                        down3 = self.down3(down2)
                                                        down4 = self.down4(down3)
                                                        down4 = self.res_blocks(down4)
                                                        up1 = self.up1(down4)
                                                        up2 = self.up2(up1+down3)
                                                        up3 = self.up3(up2+down2)
                                                        up4 = self.up4(up3+down1)
                                                        return up4
                                                        weight = torch.load('weight.pth', map_location='cpu')
                                                        model = SimpleGenerator()
                                                        model.load_state_dict(weight)
                                                        model.eval()
                                                        
                                                        img = cv2.imread(r'input.jpg')
                                                        
                                                        image = img/127.5 - 1
                                                        image = image.transpose(2, 0, 1)
                                                        image = torch.tensor(image).unsqueeze(0)
                                                        output = model(image.float())
                                                        output = output.squeeze(0).detach().numpy()
                                                        output = output.transpose(1, 2, 0)
                                                        output = (output + 1) * 127.5
                                                        output = np.clip(output, 0, 255).astype(np.uint8)
                                                        cv2.imwrite('output.jpg', output)### unterminated keywords
    

    需要完整源码,可以私信。

    相关文章

      网友评论

        本文标题:Python将头像照片转换为漫画,采用GAN深度学习,无噪点

        本文链接:https://www.haomeiwen.com/subject/oslssrtx.html