1. 任务目标
在训练目标检测模型时,若数据存在以下情况:图像之间差异小、不同类别的样本数目差异大、某些目标物体的样本图片难以搜集等,就需要对数据进行增广处理。本文以fire类别为例,实现随机贴图的数据增广:若背景图没有标注文件,则生成新的标注文件;若已有标注文件,则在其中追加新目标,并且贴图位置会避开已有的标注框。
2. Python实现
2.1 将已标注的目标保存
数据存储格式:(路径中不要包含中文)
输入文件夹:
- data/fire_dataset/JPEGImages/***.jpg
- data/fire_dataset/Annotations/***.xml
输出文件夹:
- data/fire_cut
代码:
import os
import cv2
import time
import argparse
import xml.etree.ElementTree as ET
from tqdm import tqdm
def _build_parser():
    """Assemble the command-line interface of the cropping script.

    Three optional string flags are registered: the label to cut out, the
    dataset root (expected to hold ``Annotations`` and ``JPEGImages``) and
    the folder that receives the crops.
    """
    cli = argparse.ArgumentParser(
        description='Read box from xml and crop from image.',
    )
    cli.add_argument(
        '--dst-label',
        default='fire',
        help='label box to cut',
    )
    cli.add_argument(
        '--input-path',
        default='data/fire_dataset',
        help='contain Annotations, JPEGImages folder',
    )
    cli.add_argument(
        '--output-path',
        default='data/fire_cut',
        help='output path',
    )
    return cli


# Parsed once at import time; main() reads its settings from ``args``.
parser = _build_parser()
args = parser.parse_args()
def read_xml_box(xml_file):
    """Read all annotated objects from one annotation XML file.

    Args:
        xml_file: path of the XML file to parse.

    Returns:
        A list holding one ``[class_name, xmin, xmax, ymin, ymax]`` entry per
        ``<object>`` element under the root. Note the x-range comes before
        the y-range. The class name is stripped of surrounding whitespace and
        the coordinates are converted with ``int``.
    """
    corner_tags = ('xmin', 'xmax', 'ymin', 'ymax')
    boxes = []
    for node in ET.parse(xml_file).findall('object'):
        label = node.find('name').text.strip()
        # Fetch all four texts first, then convert them in the same order.
        raw = [node.find('bndbox').find(tag).text for tag in corner_tags]
        boxes.append([label, *map(int, raw)])
    return boxes
def main():
    """Crop every box labelled ``args.dst_label`` out of the dataset images.

    Each file in ``<input-path>/JPEGImages`` is paired with the XML of the
    same stem in ``<input-path>/Annotations``. Every matching box is written
    as a JPEG into ``<output-path>/<label>/``, named after the XML stem plus
    a timestamp. Images without an XML, images OpenCV cannot decode and boxes
    that are empty after clamping are reported and skipped instead of
    aborting the whole run.
    """
    xml_path = os.path.join(args.input_path, "Annotations")
    img_path = os.path.join(args.input_path, "JPEGImages")
    for img_name in tqdm(os.listdir(img_path)):
        xml_name = '{}.xml'.format(img_name.rsplit('.', maxsplit=1)[0])
        xml_file = os.path.join(xml_path, xml_name)
        if not os.path.exists(xml_file):
            print('{} not exists'.format(xml_name))
            continue
        img = cv2.imread(os.path.join(img_path, img_name))
        if img is None:
            # cv2.imread reports an unreadable / non-image file by returning
            # None; slicing it below would raise a TypeError.
            print('{} can not be read as an image'.format(img_name))
            continue
        img_height, img_width = img.shape[:2]
        # read_xml_box yields the x-range before the y-range.
        for class_name, xmin, xmax, ymin, ymax in read_xml_box(xml_file):
            if class_name != args.dst_label:
                continue
            # Clamp to the image: a negative start would be interpreted by
            # numpy as "count from the end" and produce a wrong crop.
            xmin, xmax = max(xmin, 0), min(xmax, img_width)
            ymin, ymax = max(ymin, 0), min(ymax, img_height)
            if xmin >= xmax or ymin >= ymax:
                # An empty crop cannot be encoded by cv2.imwrite.
                print('{}: skip empty box of {}'.format(xml_name, class_name))
                continue
            crop_img = img[ymin:ymax, xmin:xmax, :]
            new_name = '{}_{}.jpg'.format(xml_name[:-4], str(time.time()).replace('.', ''))
            save_path = os.path.join(args.output_path, class_name)
            os.makedirs(save_path, exist_ok=True)
            save_name = os.path.join(save_path, new_name)
            print(save_name)
            cv2.imwrite(save_name, crop_img)


if __name__ == '__main__':
    main()
生成结果:
2.2 随机贴图扩充数据
数据存储格式:
输入文件夹:
- data/fire_cut
- data/fire_bg/JPEGImages/***.jpg
- data/fire_bg/Annotations/***.xml
注意:此处作为背景的数据,可以只有图片,没有标注文件。也可以既有图片又有标注文件。若没有标注文件,则生成;若有标注文件,则在贴图时会避免目标框的遮挡。
输出文件夹:
- data/output
输入参数:
- gen_num:生成的图片数目,实际生成数目小于等于该值;
- cls_name:待扩增的类别名称;
- tietudir:前景图片文件夹,存放从已标注的fire数据中裁剪出的fire目标图;
- img_path:背景图片文件夹,其他含fire目标或不含fire目标的图片;
- xml_path:背景图片标注文件夹;
- save_path:保存结果文件夹。
代码:
import os
import time
import random
import argparse
import xml.etree.ElementTree as ET
from PIL import Image
# Command-line options of the paste-augmentation script. They are parsed once
# at import time and handed to process() by the entry-point guard.
parser = argparse.ArgumentParser()
# Folder with the foreground patches that get pasted.
parser.add_argument('--tietudir', default='data/fire_cut/fire', help='贴图路径')
# Folder with the background annotations whose boxes must not be covered.
parser.add_argument('--xml_path', default='data/fire_bg/Annotations', help='躲避路径')
# Folder with the background images.
parser.add_argument('--img_path', default='data/fire_bg/JPEGImages', help='底图路径')
# Folder that receives the generated images and XML files.
parser.add_argument('--save_path', default='data/output', help='保存路径')
# type=int: without it a value given on the command line stays a str and the
# divmod() arithmetic in get_shuffle_list raises a TypeError.
parser.add_argument('--gen_num', default=11, type=int, help='保存个数')
# Class name written into the new <object> entries.
parser.add_argument('--cls_name', default='fire', help='目标类别名')
args = parser.parse_args()
def creat_src_xml(width_ditu, height_ditu, box, save_path, name, cls_name, img_ext='.jpg'):
    """Write a fresh annotation XML that contains a single object.

    Args:
        width_ditu: width of the background image, written to ``<width>``.
        height_ditu: height of the background image, written to ``<height>``.
        box: ``(xmin, ymin, xmax, ymax)`` of the pasted patch; every value is
            truncated with ``int`` before it is written.
        save_path: existing folder that receives the XML file.
        name: file name of the XML (for example ``foo.xml``).
        cls_name: class name stored in ``<name>``.
        img_ext: extension of the image this XML describes. ``<filename>``
            and ``<path>`` are built from the stem of ``name`` plus this
            extension, so they point at the image rather than at the XML
            file itself.
    """
    xml_file = save_path + '/' + name
    image_name = os.path.splitext(name)[0] + img_ext
    lines = [
        '<annotation>\n',
        ' <folder>data</folder>\n',
        ' <filename>' + image_name + '</filename>\n',
        ' <path>' + save_path + '/' + image_name + '</path>\n',
        ' <source>\n',
        ' <database>Unknown</database>\n',
        ' </source>\n',
        ' <size>\n',
        ' <width>' + str(width_ditu) + '</width>\n',
        ' <height>' + str(height_ditu) + '</height>\n',
        ' <depth>3</depth>\n',
        ' </size>\n',
        ' <segmented>0</segmented>\n',
        ' <object>\n',
        ' <name>' + cls_name + '</name>\n',
        ' <pose>Unspecified</pose>\n',
        ' <truncated>0</truncated>\n',
        ' <difficult>0</difficult>\n',
        ' <bndbox>\n',
        ' <xmin>' + str(int(box[0])) + '</xmin>\n',
        ' <ymin>' + str(int(box[1])) + '</ymin>\n',
        ' <xmax>' + str(int(box[2])) + '</xmax>\n',
        ' <ymax>' + str(int(box[3])) + '</ymax>\n',
        ' </bndbox>\n',
        ' </object>\n',
        '</annotation>\n',
    ]
    # Explicit UTF-8 keeps non-ASCII class names readable on every platform;
    # the with-block closes the handle even if a write fails.
    with open(xml_file, 'w', encoding='UTF-8') as x:
        x.writelines(lines)
def creat_xml(box, save_path, copy_path, cls_name):
    """Copy an annotation XML and append one more ``<object>`` to it.

    Everything before the last ``</annotation>`` of ``copy_path`` is copied
    unchanged, then the new object and a fresh closing tag are written. The
    closing tag is located explicitly instead of assuming it is the final
    line, so trailing blank lines or a single-line XML still give a
    well-formed result.

    Args:
        box: ``(xmin, ymin, xmax, ymax)`` of the new object; every value is
            truncated with ``int`` before it is written.
        save_path: path of the XML file to create.
        copy_path: path of the existing XML file used as the template.
        cls_name: class name stored in the new ``<name>`` element.

    Raises:
        ValueError: if ``copy_path`` contains no ``</annotation>`` tag.
    """
    with open(copy_path, encoding='UTF-8') as src:
        content = src.read()
    cut = content.rfind('</annotation>')
    if cut == -1:
        raise ValueError('{} has no </annotation> closing tag'.format(copy_path))
    # Drop indentation in front of the closing tag, and make sure the new
    # object starts on its own line.
    head = content[:cut].rstrip(' \t')
    if head and not head.endswith('\n'):
        head += '\n'
    new_object = [
        ' <object>\n',
        ' <name>' + cls_name + '</name>\n',
        ' <pose>Unspecified</pose>\n',
        ' <truncated>0</truncated>\n',
        ' <difficult>0</difficult>\n',
        ' <bndbox>\n',
        ' <xmin>' + str(int(box[0])) + '</xmin>\n',
        ' <ymin>' + str(int(box[1])) + '</ymin>\n',
        ' <xmax>' + str(int(box[2])) + '</xmax>\n',
        ' <ymax>' + str(int(box[3])) + '</ymax>\n',
        ' </bndbox>\n',
        ' </object>\n',
        '</annotation>\n',
    ]
    with open(save_path, 'w', encoding='UTF-8') as x:
        x.write(head)
        x.writelines(new_object)
def read_xml_box(xml_file):
    """Read all annotated objects from one annotation XML file.

    Args:
        xml_file: path of the XML file to parse.

    Returns:
        A list holding one ``[class_name, xmin, ymin, xmax, ymax]`` entry per
        ``<object>`` element under the root. The class name is stripped of
        surrounding whitespace and the coordinates are converted with
        ``int``.
    """
    tree = ET.parse(xml_file)
    boxes = []
    for node in tree.findall('object'):
        label = node.find('name').text.strip()
        raw = {
            tag: node.find('bndbox').find(tag).text
            for tag in ('xmin', 'xmax', 'ymin', 'ymax')
        }
        boxes.append([
            label,
            int(raw['xmin']),
            int(raw['ymin']),
            int(raw['xmax']),
            int(raw['ymax']),
        ])
    return boxes
def compute_IOU(rec1, rec2):
    """Return the intersection-over-union of two axis-aligned rectangles.

    Args:
        rec1: first rectangle, indexable as ``(xmin, ymin, xmax, ymax)``.
        rec2: second rectangle in the same layout.

    Returns:
        ``0`` when the rectangles do not overlap (touching edges count as no
        overlap), otherwise the overlap area divided by the union area.
    """
    inter_left = max(rec1[0], rec2[0])
    inter_top = max(rec1[1], rec2[1])
    inter_right = min(rec1[2], rec2[2])
    inter_bottom = min(rec1[3], rec2[3])

    # Guard clause: an empty or degenerate intersection means no overlap.
    if inter_right <= inter_left or inter_bottom <= inter_top:
        return 0

    area_first = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    area_second = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
    overlap = (inter_bottom - inter_top) * (inter_right - inter_left)
    return overlap / (area_first + area_second - overlap)
def random_box(end1, end2, end3, end4):
    """Pick a random top-left corner and return the box anchored there.

    Args:
        end1: largest allowed ``xmin`` (inclusive).
        end2: largest allowed ``ymin`` (inclusive).
        end3: box width, added to ``xmin`` to obtain ``xmax``.
        end4: box height, added to ``ymin`` to obtain ``ymax``.

    Returns:
        The tuple ``(xmin, ymin, xmax, ymax)``.

    Raises:
        ValueError: propagated from ``random.randint`` when ``end1`` or
            ``end2`` is negative.
    """
    left = random.randint(0, end1)
    top = random.randint(0, end2)
    return left, top, left + end3, top + end4
def get_shuffle_list(img_path, gen_num):
    """Build a shuffled list of ``gen_num`` file names taken from a folder.

    The directory listing is repeated as many whole times as fits into
    ``gen_num`` and topped up with a random partial selection, so every entry
    is used about equally often.

    Args:
        img_path: directory whose entries are listed.
        gen_num: number of names wanted.

    Returns:
        The shuffled list of names, or ``None`` when the directory is empty.
    """
    entries = os.listdir(img_path)
    if len(entries) == 0:
        return None
    random.shuffle(entries)
    full_rounds, leftover = divmod(gen_num, len(entries))
    picked = entries[:leftover] + entries * full_rounds
    random.shuffle(picked)
    return picked
def _fit_patch(fg_img, bg_size, scale=0.5):
    """Return a patch that fits into the background, shrinking it once if needed.

    Returns ``fg_img`` itself when it already fits, a copy resized by
    ``scale`` when that copy fits, and ``None`` otherwise.
    """
    fg_width, fg_height = fg_img.size
    bg_width, bg_height = bg_size
    if fg_width <= bg_width and fg_height <= bg_height:
        return fg_img
    new_width, new_height = int(fg_width * scale), int(fg_height * scale)
    if new_width < 1 or new_height < 1:
        return None
    if new_width > bg_width or new_height > bg_height:
        return None
    # Image.resize returns a new image; the result has to be kept.
    return fg_img.resize((new_width, new_height))


def _find_free_box(bg_size, fg_size, cls_boxes, max_tries=50):
    """Draw random positions until the patch overlaps no annotated box.

    Returns ``(xmin, ymin, xmax, ymax)``, or ``None`` when all ``max_tries``
    candidates overlap an entry of ``cls_boxes``.
    """
    bg_width, bg_height = bg_size
    fg_width, fg_height = fg_size
    for _ in range(max_tries):
        candidate = random_box(bg_width - fg_width, bg_height - fg_height,
                               fg_width, fg_height)
        if not any(compute_IOU(box[1:5], candidate) > 0 for box in cls_boxes):
            return candidate
    return None


def process(gen_num, tietudir, img_path, xml_path, save_path, cls_name):
    """Paste foreground patches onto backgrounds and write image + XML pairs.

    Up to ``gen_num`` (foreground, background) pairs are drawn. For each pair
    the patch is pasted at a random position that overlaps none of the boxes
    already annotated for the background. The result is saved as a JPEG plus
    an XML that contains the new box. A pair is skipped when the patch does
    not fit even after being halved once, or when no free position is found.

    Args:
        gen_num: number of images to try to generate.
        tietudir: folder with the foreground patches.
        img_path: folder with the background images.
        xml_path: folder with the optional background annotation XML files.
        save_path: output folder, created if missing.
        cls_name: class name written for the pasted object.
    """
    os.makedirs(save_path, exist_ok=True)
    fg_gen = get_shuffle_list(tietudir, gen_num)
    bg_gen = get_shuffle_list(img_path, gen_num)
    if fg_gen is None or bg_gen is None:
        return
    for fg, bg in zip(fg_gen, bg_gen):
        # The with-block closes both source files once the pair is handled.
        with Image.open(os.path.join(tietudir, fg)) as fg_src, \
                Image.open(os.path.join(img_path, bg)) as bg_img:
            save_name = bg.rsplit('.', maxsplit=1)[0]
            bg_xml_path = os.path.join(xml_path, '{}.xml'.format(save_name))
            cls_boxes = []
            if os.path.exists(bg_xml_path):
                cls_boxes = read_xml_box(bg_xml_path)

            fg_img = _fit_patch(fg_src, bg_img.size)
            if fg_img is None:
                continue
            box2 = _find_free_box(bg_img.size, fg_img.size, cls_boxes)
            if box2 is None:
                continue

            bg_width, bg_height = bg_img.size
            bg_img.paste(fg_img, (box2[0], box2[1]))
            bg_img_add = bg_img.convert('RGB')
            new_name = '{}_{}'.format(save_name, str(time.time()).replace('.', ''))
            bg_img_add.save(os.path.join(save_path, new_name + '.jpg'))
            xml_save_path = os.path.join(save_path, new_name + '.xml')
            if not cls_boxes:
                # No existing boxes to carry over: start a new annotation.
                creat_src_xml(bg_width, bg_height, box2, save_path, new_name + '.xml', cls_name)
            else:
                # Keep the background's boxes and append the pasted one.
                creat_xml(box2, xml_save_path, bg_xml_path, cls_name)


if __name__ == "__main__":
    process(args.gen_num,
            args.tietudir,
            args.img_path,
            args.xml_path,
            args.save_path,
            args.cls_name)
生成结果:(labelimg查看)
网友评论