TL;DR:用multiprocessing库解决python单线程处理大量图片缓慢的问题。
最近想试试HSI色彩空间的图片对卷积网络有没有帮助,就在每次加载数据的时候对每张图片做RGB到HSI的色彩空间变换。跑了几个epoch之后寻思着不对头,网络训练速度比原来慢了不少。这应该是因为数据预处理太占用CPU,感觉很不爽,于是想把整个ImageNet数据集提前处理好存下来,一劳永逸。
于是简单用python写了个脚本,遍历数据集,然后每张图片做好变换后按照原来的目录结构保存到新的根目录下。
实现很简单,但是跑了下一看,整个数据集跑完一遍竟然要17个小时......原因是单进程脚本只能用到一个CPU核心;而由于python的GIL,即使改用多线程也无法让这种计算密集型任务真正并行。解决思路自然是改用多进程,让程序充分利用多核。
python的thread是假线程,适合用在IO密集型的场景,对这种计算密集型的任务毫无帮助,而另一个multiprocessing自然就是解决方案了。运行机制简单地说就是产生一个进程池pool,pool提供一个map接口,把处理数据的函数接口和待处理的数据迭代器丢进去,进程池会自动分配多个进程执行,达到多进程的目的。当然multiprocessing库不止这么简单,还有更复杂的用法,这里并不需要所以不再深入。
最后附上代码:
import os
import tqdm
import itertools
import numpy as np
import multiprocessing as mp
from PIL import Image
def rgb2hsi(rgb):
    """Convert an RGB image array to the HSI color space.

    Parameters
    ----------
    rgb : np.ndarray
        (H, W, 3) array of RGB values in the range [0, 255].

    Returns
    -------
    np.ndarray
        (H, W, 3) float32 array with hue, saturation and intensity
        channels all rescaled to [0, 255].
    """
    # Normalize into [0, 1] on a fresh array so the caller's buffer is
    # never mutated in place (the original `rgb /= 255.` clobbered it).
    rgb = np.asarray(rgb, np.float32) / 255.
    r, g, b = (np.squeeze(c) for c in np.split(rgb, 3, 2))
    hsi = np.zeros_like(rgb)
    # Classic HSI hue: theta = arccos(0.5*((r-g)+(r-b)) / sqrt((r-g)^2 + (r-b)(g-b))).
    num = (r - g) + (r - b)
    denom = 2 * np.sqrt((r - g) ** 2 + (r - b) * (g - b))
    # Gray pixels (r == g == b) make the denominator 0: substitute 1 to avoid
    # a divide-by-zero NaN, then force their hue to 0. np.clip absorbs tiny
    # floating-point overshoot outside arccos' domain [-1, 1].
    theta = np.arccos(np.clip(num / np.where(denom == 0, 1., denom), -1., 1.))
    theta = np.where(denom == 0, 0., theta)
    pi_2 = 2 * np.pi
    hsi[:, :, 0] = np.where(g >= b, theta, pi_2 - theta) / pi_2
    total = np.sum(rgb, 2)
    # Saturation of a pure-black pixel (total == 0) is undefined; define it as 0.
    sat = 1 - 3 * np.min(rgb, 2) / np.where(total == 0, 1., total)
    hsi[:, :, 1] = np.where(total == 0, 0., sat)
    hsi[:, :, 2] = total / 3.
    return hsi * 255.
def resize_create_hsi_img(dir_pair):
    """Resize one image to 70% and save an RGB copy plus an HSI copy.

    Parameters
    ----------
    dir_pair : tuple[str, str, str]
        (source path, RGB target path, HSI target path).

    Returns
    -------
    str | None
        The source path when processing failed, else None; the caller
        collects the returned paths as the list of bad images.
    """
    src_path, target_path_rgb, target_path_hsi = dir_pair
    rgb_missing = not os.path.exists(target_path_rgb)
    hsi_missing = not os.path.exists(target_path_hsi)
    if not (rgb_missing or hsi_missing):
        return None  # both outputs already exist; nothing to do
    try:
        # Context manager closes the source file handle even on error.
        with Image.open(src_path) as org_pic:
            new_size = int(org_pic.size[0] * 0.7), int(org_pic.size[1] * 0.7)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            pic = org_pic.resize(new_size, Image.LANCZOS)
        if rgb_missing:
            pic.save(target_path_rgb, quality=75)
        if hsi_missing and pic.mode == 'RGB':
            # Non-RGB images (grayscale, CMYK, palette, ...) are skipped for HSI.
            rgb_img = np.asarray(pic, np.float32)
            hsi_img = rgb2hsi(rgb_img).astype(np.uint8)
            Image.fromarray(hsi_img).save(target_path_hsi, quality=75)
        return None
    except Exception as exc:
        # Best-effort batch job: log the failure and report the bad file.
        print(exc)
        return src_path
def walk_all_pic():
    """Yield (src, rgb_target, hsi_target) path triples for every image file.

    Walks the two-level ImageNet layout (<split>/<class>/<image>) under the
    source root and mirrors the directory tree under both target roots,
    creating target directories along the way.
    """
    # Raw strings: sequences like '\I' in a plain literal are invalid escape
    # sequences, which are deprecated and will eventually be a SyntaxError.
    root = r'D:\Datasets\ImageNet\ILSVRC2017_CLS-LOC\ILSVRC\Data\CLS-LOC'
    targets = ['val', 'train', ]
    root_new_rgb = r'E:\Imagenet\cls_rgb'
    root_new_hsi = r'E:\Imagenet\cls_hsi'
    # makedirs(exist_ok=True) replaces each exists()/mkdir() pair: it creates
    # the whole chain and never races with an already-existing directory.
    os.makedirs(root_new_rgb, exist_ok=True)
    os.makedirs(root_new_hsi, exist_ok=True)
    for t in targets:
        src_split = os.path.join(root, t)
        rgb_split = os.path.join(root_new_rgb, t)
        hsi_split = os.path.join(root_new_hsi, t)
        os.makedirs(rgb_split, exist_ok=True)
        os.makedirs(hsi_split, exist_ok=True)
        for subfolder in os.listdir(src_split):
            src_class = os.path.join(src_split, subfolder)
            if not os.path.isdir(src_class):
                continue  # ignore stray files at the class level
            rgb_class = os.path.join(rgb_split, subfolder)
            hsi_class = os.path.join(hsi_split, subfolder)
            os.makedirs(rgb_class, exist_ok=True)
            os.makedirs(hsi_class, exist_ok=True)
            for fname in os.listdir(src_class):
                fpath = os.path.join(src_class, fname)
                if os.path.isfile(fpath):
                    yield (fpath,
                           os.path.join(rgb_class, fname),
                           os.path.join(hsi_class, fname))
# Multiple process version
def run_multiprocess():
    """Process the whole dataset with a pool of worker processes.

    Paths of images that failed to process are written, one per line,
    to './error_pics.log'.
    """
    print('Processing pictures with multiple processors...')
    error_pics = []
    with mp.Pool(processes=mp.cpu_count()) as pool:
        # imap_unordered streams results as workers finish; total= lets tqdm
        # show a progress bar over a generator of known length (ImageNet).
        for ep in pool.imap_unordered(resize_create_hsi_img,
                                      tqdm.tqdm(walk_all_pic(), total=1331167, ncols=65)):
            # Workers return the source path on failure and None on success;
            # collecting the None entries would make writelines() crash below.
            if ep is not None:
                error_pics.append(ep)
    with open('./error_pics.log', mode='w') as f:
        # writelines() adds no separators itself, so append the newline here.
        f.writelines(p + '\n' for p in error_pics)
    print("All pictures that could not be processed have been written to 'error_pics.log'")
# Single process version
def run():
    """Single-process fallback: process every image sequentially."""
    for dir_pair in tqdm.tqdm(walk_all_pic()):
        resize_create_hsi_img(dir_pair)
# Entry point: multiprocess version by default; switch to run() to debug serially.
if __name__ == '__main__':
    run_multiprocess()
网友评论