The principle behind automatic image colorization is quite simple. Let's work through it step by step.
First, import the necessary tools:
import torch as t
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from skimage import color
from skimage.io import imshow
from tqdm import tqdm_notebook
from torchvision import transforms
%matplotlib inline
Load a color image. I picked a promotional image of Templar Assassin (TA) from Dota 2. First, a look at it in color:
img_rgb = Image.open("H:/COLOURING/TA.jpg").resize((256,256))
img_rgb = np.array(img_rgb)
plt.imshow(img_rgb),img_rgb.shape
[Image: TA in color]
Then the black-and-white (grayscale) version:
img_gray = np.array(Image.open("H:/COLOURING/TA.jpg").convert('L').resize((256,256)))
plt.imshow(img_gray,cmap = 'gray')
[Image: TA in grayscale]
In computing, color images are usually displayed in RGB mode (BGR in OpenCV). They have three channels, meaning the image is a stack of three pixel matrices.
A grayscale image has only one channel, i.e. a single pixel matrix.
Besides RGB there are many other color modes, including the Lab mode used in this article. Lab also has three channels: the first, L, is the lightness channel, which encodes the image's brightness and looks very much like a grayscale image. Here is the L channel displayed in grayscale:
img_lab = color.rgb2lab(img_rgb / 255)   # rgb2lab wants float RGB scaled to [0, 1]
img_lab_l = img_lab[:,:,0]
plt.imshow(img_lab_l,cmap = 'gray')
[Image: the L (lightness) channel of TA]
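Before moving on, it is worth checking the value ranges of the three Lab channels, since we will need to normalize them later (an illustrative check, not in the original notebook):
# L lies in [0, 100]; a and b lie roughly in [-128, 127]
for i, name in enumerate('Lab'):
    print(name, img_lab[:, :, i].min(), img_lab[:, :, i].max())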
Visually, the L channel is indistinguishable from the grayscale image. The other two channels, a and b, are color channels; let's pull each of them out on its own.
First, the a channel:
img_lab_a = img_lab[:,:,1]
plt.imshow(img_lab_a, cmap='gray') # matplotlib has no cmap made for the a/b channels, so this is only a rough sketch; the real colors look different
[Image: the a channel of TA]
Clearly, the color channels carry little of the image's line detail. Next, the b channel:
img_lab_b = img_lab[:,:,2]
plt.imshow(img_lab_b, cmap='gray') # matplotlib has no cmap made for the a/b channels, so this is only a rough sketch; the real colors look different
[Image: the b channel of TA]
Without her eye shadow in the b channel, TA is honestly not looking great~~
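As the comments above note, matplotlib has no faithful cmap for the a/b channels. If you want to see roughly what a channel encodes, one trick (my own illustration, not part of the original) is to hold L fixed and the other channel at zero, then convert back to RGB:
# Render the a channel in its actual colors: fix L = 50 and b = 0
vis = np.zeros((256, 256, 3))
vis[:, :, 0] = 50            # constant mid lightness
vis[:, :, 1] = img_lab_a     # the channel under inspection
plt.imshow(color.lab2rgb(vis))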
The Colorization Principle
With all of the above in place, the principle of automatic colorization is clear: use the lightness layer as the data and the ab layers as the target, and learn a mapping from the lightness image to the color layers.
Note that the "black-and-white image" here means the lightness layer, not the grayscale image.
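In code, this split is just array slicing (a minimal sketch using the img_lab array from above):
L  = img_lab[:, :, :1]   # lightness layer: the input,  shape (256, 256, 1)
ab = img_lab[:, :, 1:]   # two color layers: the target, shape (256, 256, 2)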
Building the Neural Network
First we need to build a neural network whose input is the image's L layer and whose output is the ab layers.
class Net(t.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Each block halves the spatial size with a stride-2 convolution,
        # then doubles it back with Upsample, so 256x256 in -> 256x256 out.
        self.conv1 = t.nn.Sequential(
            t.nn.Conv2d(1, 16, 3, stride=2, padding=1),
            t.nn.BatchNorm2d(16),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )
        self.conv2 = t.nn.Sequential(
            t.nn.Conv2d(16, 32, 3, 2, 1),
            t.nn.BatchNorm2d(32),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )
        self.conv3 = t.nn.Sequential(
            t.nn.Conv2d(32, 16, 3, 2, 1),
            t.nn.BatchNorm2d(16),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )
        # Note: the trailing ReLU clamps this block's output to be
        # non-negative, so the network can only predict a/b values >= 0.
        self.conv4 = t.nn.Sequential(
            t.nn.Conv2d(16, 2, 3, 2, 1),
            t.nn.BatchNorm2d(2),
            t.nn.ReLU(),
            t.nn.Upsample(scale_factor=2)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x
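Before wiring up the training loop, a quick sanity check that the output shape matches the ab target (illustrative, using a dummy input):
net_check = Net()
dummy = t.zeros(1, 1, 256, 256)   # (batch, channel, height, width)
net_check(dummy).shape            # expected: torch.Size([1, 2, 256, 256])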
Preparing the Data
img_gray = img_gray[:, :, np.newaxis]      # add an explicit channel axis
img_lab_l = img_lab_l[:, :, np.newaxis]
img_gray.shape, img_lab_l.shape
((256, 256, 1), (256, 256, 1))
x_train = img_lab_l                   # L channel; raw values lie in [0, 100]
# Divide into a new array rather than in place: img_lab[:, :, 1:3] is a view,
# and an in-place /= would silently rescale img_lab itself.
y_train = img_lab[:, :, 1:3] / 128    # map a/b from roughly [-128, 127] into [-1, 1]
transform = transforms.Compose([
transforms.ToTensor(),
])
A PIL image (and the numpy array made from it) has shape (H, W, C), while PyTorch image tensors are (C, H, W), so a conversion is needed; transforms.ToTensor handles it.
x_train,y_train = transform(x_train),transform(y_train)
x_train,y_train = x_train.float(),y_train.float()
x_train,y_train = x_train.view(-1,1,256,256),y_train.view(-1,2,256,256)
x_train.shape,y_train.shape
(torch.Size([1, 1, 256, 256]), torch.Size([1, 2, 256, 256]))
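As an aside, the view calls above simply prepend a batch dimension; unsqueeze(0) does the same without hard-coding the image size (an equivalent sketch, using _alt names so the originals stay intact):
x_train_alt = transform(img_lab_l).float().unsqueeze(0)                 # (1, 1, 256, 256)
y_train_alt = transform(img_lab[:, :, 1:3] / 128).float().unsqueeze(0)  # (1, 2, 256, 256)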
Training the Model
net = Net()
net
Net(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
  (conv3): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
  (conv4): Sequential(
    (0): Conv2d(16, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
  )
)
EPOCHS = 500
LR = 0.01
criterion = t.nn.MSELoss()
optimizer = t.optim.Adam(net.parameters(), lr=LR, weight_decay=0.0)
for epoch in tqdm_notebook(range(EPOCHS)):
    # Decay the learning rate by 10% every 100 epochs
    if epoch % 100 == 0:
        for param_group in optimizer.param_groups:
            LR = LR * 0.9
            param_group['lr'] = LR
    prediction = net(x_train)
    loss = criterion(prediction, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
loss
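The manual decay above shrinks the learning rate by 10% every 100 epochs. PyTorch's built-in StepLR scheduler does the same bookkeeping; an equivalent loop might look like this (a sketch, not the original code; note the original also decays once at epoch 0, so the two schedules differ slightly):
scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.9)
for epoch in tqdm_notebook(range(EPOCHS)):
    prediction = net(x_train)
    loss = criterion(prediction, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()   # multiplies the LR by gamma every step_size epochs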
Restoring and Displaying the Image
net.eval()
with t.no_grad():                      # no gradients needed at inference time
    prediction = net(x_train) * 128    # undo the /128 scaling applied to the targets
prediction = prediction[0].numpy()
x_train = x_train[0].numpy()
x_train.shape, prediction.shape
# Reassemble a Lab image: L from the input, a/b from the prediction
result = np.zeros((256, 256, 3))
result[:, :, 0] = x_train[0]
result[:, :, 1] = prediction[0]
result[:, :, 2] = prediction[1]
result_rgb = color.lab2rgb(result)
plt.imshow(result_rgb)
[Image: the colorized TA]
After 500 epochs the picture is already taking shape, and with more iterations it gets close to the original. Bear in mind that we train and predict on the same single image, so the network is essentially memorizing this one image's colors rather than learning to colorize in general.
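To recap the whole pipeline, the reconstruction step can be wrapped into a small helper (a sketch; colorize is a name I'm introducing, not from the original code):
def colorize(net, l_layer):
    # Colorize a single (256, 256) L-channel array with the trained net
    x = t.from_numpy(l_layer).float().reshape(1, 1, 256, 256)
    net.eval()
    with t.no_grad():
        ab = net(x)[0].numpy() * 128       # undo the /128 target scaling
    lab = np.zeros((256, 256, 3))
    lab[:, :, 0] = l_layer                 # lightness straight from the input
    lab[:, :, 1:] = ab.transpose(1, 2, 0)  # predicted a/b channels
    return color.lab2rgb(lab)

plt.imshow(colorize(net, img_lab[:, :, 0]))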
If you need the source code, feel free to DM me~
Anyone interested in machine learning is welcome to join the group:
机器学习-菜鸡互啄