美文网首页工具癖Pytorch与深度学习个人专题
Pytorch实践(二)——老旧照片恢复器——图片AI自动上色(

Pytorch实践(二)——老旧照片恢复器——图片AI自动上色(

作者: dalalaa | 来源:发表于2018-07-29 20:39 被阅读16次

    图片自动上色的原理很简单,下面我们边做边讲

    首先导入必要的工具

    import torch as t
    from PIL import Image
    import numpy as np 
    import matplotlib.pyplot as plt 
    from skimage import color
    from skimage.io import imshow
    from tqdm import tqdm_notebook
    from torchvision import transforms
    %matplotlib inline
    

    加载彩色图片,这里选择了Dota2里面的圣堂刺客TA的一张宣传画,先看看彩色模式

    img_rgb = Image.open("H:/COLOURING/TA.jpg").resize((256,256))
    img_rgb = np.array(img_rgb)
    plt.imshow(img_rgb),img_rgb.shape
    
    彩色TA

    然后是黑白模式

    img_gray = np.array(Image.open("H:/COLOURING/TA.jpg").convert('L').resize((256,256)))
    plt.imshow(img_gray,cmap = 'gray')
    
    黑白TA

    在计算机中,彩色图片通常以RGB模式显示(在opencv中是BGR形式),有三个通道,即图片是由三个像素矩阵叠加而成。

    而黑白模式的图片只有一个通道,即只有一个像素矩阵。

    出RGB模式外,还有很多图片显示模式,比如本文中要用到的lab模式,lab模式中同样有三个通道,第一个通道l是亮度通道,用来表示图片亮度,其效果与黑白图片非常相似。下面是以灰度模式展示的l通道

    img_lab = color.rgb2lab(img_rgb/255)
    img_lab_l = img_lab[:,:,0]
    plt.imshow(img_lab_l,cmap = 'gray')
    
    亮度TA

    l通道展示效果与灰度图像无异,另外两个通道a和b是两个色彩通道,下面我们把两个色彩通道单独拎出来看看:

    首先是a通道:

    img_lab_a = img_lab[:,:,1]
    plt.imshow(img_lab_a,cmap = 'gray') # matplotlib没有专门绘制ab通道的cmap,所以这里只是个示意图,真实色彩不是这样的。
    
    a通道TA

    很明显,色彩通道里面看不到图像的线条信息,下面再看一下b通道:

    img_lab_b = img_lab[:,:,2]
    plt.imshow(img_lab_b,cmap = 'gray') # matplotlib没有专门绘制ab通道的cmap,所以这里只是个示意图,真实色彩不是这样的。
    
    b通道TA

    b通道里面没了眼影的TA是真的丑~~

    上色原理

    介绍到这里,自动上色的原理已经很明朗了,就是以亮度层为data,ab层作为target,建立一个从亮度图像到色彩层的映射。

    注意,这里的黑白图片指的是亮度层,而不是灰度图片

    搭建神经网络

    首先需要建立一个神经网络,这个网络的输入是图片的l层,输出是图片的ab层。

    class Net(t.nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            self.conv1 = t.nn.Sequential(
                t.nn.Conv2d(1,16,3,stride=2,padding=1),
                t.nn.BatchNorm2d(16),
                t.nn.ReLU(),
                t.nn.Upsample(scale_factor=2)
            )
            self.conv2 = t.nn.Sequential(
                t.nn.Conv2d(16,32,3,2,1),
                t.nn.BatchNorm2d(32),
                t.nn.ReLU(),
                t.nn.Upsample(scale_factor=2)
            )
            self.conv3 = t.nn.Sequential(
                t.nn.Conv2d(32,16,3,2,1),
                t.nn.BatchNorm2d(16),
                t.nn.ReLU(),
                t.nn.Upsample(scale_factor=2)
            )
            self.conv4 = t.nn.Sequential(
                t.nn.Conv2d(16,2,3,2,1),
                t.nn.BatchNorm2d(2),
                t.nn.ReLU(),
                t.nn.Upsample(scale_factor=2)
            )
            
        def forward(self,x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            return x
    

    处理数据

    img_gray = img_gray[:,:,np.newaxis]
    img_lab_l = img_lab_l[:,:,np.newaxis]
    img_gray.shape,img_lab_l.shape
    
    ((256, 256, 1), (256, 256, 1))
    
    x_train = img_lab_l
    
    y_train = img_lab[:,:,1:3]
    y_train /= 128
    
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    

    PIL中image对象是(H,W,C)形状,而Pytorch中的图像tensor是(C,H,W)形状,需要进行转换

    x_train,y_train = transform(x_train),transform(y_train)
    x_train,y_train = x_train.float(),y_train.float()
    
    x_train,y_train = x_train.view(-1,1,256,256),y_train.view(-1,2,256,256)
    
    x_train.shape,y_train.shape
    
    (torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 256, 256]))
    

    训练模型

    net = Net()
    net
    
    Net(
      (conv1): Sequential(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Upsample(scale_factor=2, mode=nearest)
      )
      (conv2): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Upsample(scale_factor=2, mode=nearest)
      )
      (conv3): Sequential(
        (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Upsample(scale_factor=2, mode=nearest)
      )
      (conv4): Sequential(
        (0): Conv2d(16, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Upsample(scale_factor=2, mode=nearest)
      )
    )
    
    EPOCHS = 500
    LR = 0.01
    criterion = t.nn.MSELoss()
    optimizer = t.optim.Adam(net.parameters(),lr=LR,weight_decay=0.0)
    
    for epoch in tqdm_notebook(range(EPOCHS)):
        index=0
        if epoch % 100 == 0:
            for param_group in optimizer.param_groups:
                LR = LR * 0.9
                param_group['lr'] = LR
        prediction = net.forward(x_train)
        loss = criterion(prediction,y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss
    

    还原并显示图像

    net.eval()
    prediction = net.forward(x_train)
    prediction *= 128
    
    prediction = prediction[0].data.numpy()
    x_train = x_train[0].data.numpy()
    x_train.shape,prediction.shape
    
    result = np.zeros((256,256,3))
    result[:,:,0] = x_train[0]
    result[:,:,1] = prediction[0]
    result[:,:,2] = prediction[1]
    
    result_rgb = color.lab2rgb(result)
    
    plt.imshow(np.array(result_rgb))
    
    上色TA

    500个EPOCHS后,图片已经有点样子了,迭代更多次数之后就能够达到原图的效果了。

    需要源代码的可以私信我~

    对机器学习感兴趣的朋友可以加群:

    机器学习-菜鸡互啄

    相关文章

      网友评论

        本文标题:Pytorch实践(二)——老旧照片恢复器——图片AI自动上色(

        本文链接:https://www.haomeiwen.com/subject/gretvftx.html