Image Preprocessing: A Python Implementation of Random Patch Pasting with Annotation File Generation


Author: 智驱力AI | Published 2023-02-08 17:41

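    1. Objective

    When training an object detection model, the data needs additional processing if the images differ little from one another, the number of samples varies greatly between classes, or samples of some target objects are hard to collect. Taking the fire class as an example, this article augments the data by randomly pasting cropped targets onto background images. It then either generates a new annotation file or appends the new box to an existing one, placing the patch so that it does not cover any existing annotated object.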
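Before pasting, it can help to measure how imbalanced the classes actually are, for example to choose the gen_num argument used in step 2.2. The sketch below is not part of the original scripts; `count_objects` and `shortfall` are hypothetical helpers that count `<object>` entries per class in Pascal VOC XML text and report how many instances the target class lacks compared with the most frequent class.

```python
from collections import Counter
import xml.etree.ElementTree as ET


def count_objects(xml_texts):
    """Count <object> entries per class name across Pascal VOC XML strings."""
    counts = Counter()
    for text in xml_texts:
        root = ET.fromstring(text)
        for obj in root.findall('object'):
            counts[obj.find('name').text.strip()] += 1
    return counts


def shortfall(counts, cls_name):
    """Number of cls_name instances missing to match the most frequent class."""
    if not counts:
        return 0
    # Counter returns 0 for a class that never occurs
    return max(counts.values()) - counts[cls_name]


if __name__ == '__main__':
    docs = [
        '<annotation><object><name>person</name></object>'
        '<object><name>person</name></object></annotation>',
        '<annotation><object><name>person</name></object>'
        '<object><name>fire</name></object></annotation>',
    ]
    print(shortfall(count_objects(docs), 'fire'))  # 2
```

In practice the XML strings would be read from the Annotations folder; the result is only a rough guide, since pasted patches are not as varied as real samples.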

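    2. Python Implementation

    2.1 Cropping and Saving Annotated Targets

    Data layout (keep paths free of Chinese and other non-ASCII characters; OpenCV's cv2.imread/cv2.imwrite may fail on such paths, notably on Windows):

    Input folders: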

    • data/fire_dataset/JPEGImages/***.jpg

    • data/fire_dataset/Annotations/***.xml

    Output folder:

    • data/fire_cut
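The cropping script expects every image in JPEGImages to have an XML with the same stem in Annotations. A quick pre-check of this layout can save a failed run; the sketch below is a hypothetical helper (`pair_dataset` is not from the original code) that pairs images with annotations, lists images without an XML, and flags non-ASCII paths that OpenCV may fail to read.

```python
import os


def pair_dataset(root):
    """Pair JPEGImages/*.jpg with Annotations/*.xml under root.

    Returns (pairs, missing_xml, non_ascii); pairs is a list of
    (image_path, xml_path) tuples.
    """
    img_dir = os.path.join(root, 'JPEGImages')
    xml_dir = os.path.join(root, 'Annotations')
    pairs, missing_xml, non_ascii = [], [], []
    for name in sorted(os.listdir(img_dir)):
        if not name.lower().endswith('.jpg'):
            continue  # skip non-image files
        img_file = os.path.join(img_dir, name)
        xml_file = os.path.join(xml_dir, os.path.splitext(name)[0] + '.xml')
        if not img_file.isascii():
            non_ascii.append(img_file)  # cv2.imread may not open this path
        if os.path.exists(xml_file):
            pairs.append((img_file, xml_file))
        else:
            missing_xml.append(img_file)
    return pairs, missing_xml, non_ascii
```

For example, `pair_dataset('data/fire_dataset')` would list the usable image/annotation pairs before running the cropping script.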

    Code:

    import os
    import cv2
    import argparse
    import xml.etree.ElementTree as ET

    from tqdm import tqdm

    parser = argparse.ArgumentParser(description='Read box from xml and crop from image.')
    parser.add_argument('--dst-label', default='fire', help='label box to cut')
    parser.add_argument('--input-path', default='data/fire_dataset', help='contain Annotations, JPEGImages folder')
    parser.add_argument('--output-path', default='data/fire_cut', help='output path')
    args = parser.parse_args()


    def read_xml_box(xml_file):
        # Return every object in a Pascal VOC XML as [class_name, xmin, xmax, ymin, ymax]
        xml_anno = ET.parse(xml_file)
        result = []
        for obj in xml_anno.findall('object'):
            class_name = obj.find('name').text.strip()
            bndbox = obj.find('bndbox')
            # int(float(...)) also accepts coordinates written as "12.0"
            xmin, xmax, ymin, ymax = (int(float(bndbox.find(tag).text))
                                      for tag in ('xmin', 'xmax', 'ymin', 'ymax'))
            result.append([class_name, xmin, xmax, ymin, ymax])
        return result


    def main():
        xml_path = os.path.join(args.input_path, "Annotations")
        img_path = os.path.join(args.input_path, "JPEGImages")
        save_path = os.path.join(args.output_path, args.dst_label)
        os.makedirs(save_path, exist_ok=True)
        for img_name in tqdm(os.listdir(img_path)):
            stem = img_name.rsplit('.', maxsplit=1)[0]
            xml_file = os.path.join(xml_path, stem + '.xml')
            if not os.path.exists(xml_file):
                print('{}.xml does not exist'.format(stem))
                continue
            img = cv2.imread(os.path.join(img_path, img_name))
            if img is None:  # not an image, or a path OpenCV cannot open
                print('cannot read {}'.format(img_name))
                continue
            h, w = img.shape[:2]
            for idx, (class_name, xmin, xmax, ymin, ymax) in enumerate(read_xml_box(xml_file)):
                if class_name != args.dst_label:
                    continue
                # Clamp the box to the image: negative indices would wrap around in NumPy slicing
                xmin, ymin = max(0, xmin), max(0, ymin)
                xmax, ymax = min(w, xmax), min(h, ymax)
                if xmax <= xmin or ymax <= ymin:
                    continue  # empty box after clamping
                crop_img = img[ymin:ymax, xmin:xmax, :]
                # Image stem + object index gives each crop a unique name (time.time() could collide)
                save_name = os.path.join(save_path, '{}_{}.jpg'.format(stem, idx))
                print(save_name)
                cv2.imwrite(save_name, crop_img)


    if __name__ == '__main__':
        main()
    

    Result: each cropped fire target is saved as a separate image under data/fire_cut/fire.
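The crop depends entirely on the `<bndbox>` coordinates, which are not guaranteed to lie inside the image or to be integers. The sketch below isolates that step; `label_boxes` is a hypothetical helper that parses VOC XML text, keeps only one label, and clamps each box to the image size, dropping boxes that become empty.

```python
import xml.etree.ElementTree as ET


def label_boxes(xml_text, label, img_w, img_h):
    """Return (xmin, ymin, xmax, ymax) boxes of `label`, clamped to the image.

    Boxes that are empty after clamping are dropped.
    """
    boxes = []
    for obj in ET.fromstring(xml_text).findall('object'):
        if obj.find('name').text.strip() != label:
            continue
        bb = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (int(float(bb.find(t).text))
                                  for t in ('xmin', 'ymin', 'xmax', 'ymax'))
        # Clamp to [0, img_w] x [0, img_h]
        xmin, ymin = max(0, xmin), max(0, ymin)
        xmax, ymax = min(img_w, xmax), min(img_h, ymax)
        if xmax > xmin and ymax > ymin:
            boxes.append((xmin, ymin, xmax, ymax))
    return boxes
```

A clamped box can then be used directly as the slice `img[ymin:ymax, xmin:xmax]`.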

    2.2 Augmenting the Data by Random Pasting

    Data layout:

    Input folders:

    • data/fire_cut
    • data/fire_bg/JPEGImages/***.jpg
    • data/fire_bg/Annotations/***.xml

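    Note: the background data may consist of images only, or of images together with annotation files. If a background image has no annotation file, a new one is generated; if it has one, the patch is placed so that it does not overlap any existing box, and the new box is appended to a copy of that annotation.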
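Avoiding existing boxes amounts to rejection sampling: draw a random top-left corner, reject the candidate if it overlaps any existing box, and give up after a fixed number of tries. The sketch below shows that idea in isolation; `find_free_box` and `overlaps` are hypothetical names, and the 50-try limit mirrors the script below. Touching edges do not count as overlap, matching the `compute_IOU(...) > 0` test in the script.

```python
import random


def overlaps(a, b):
    """True if boxes (xmin, ymin, xmax, ymax) share a region of positive area."""
    return min(a[2], b[2]) > max(a[0], b[0]) and min(a[3], b[3]) > max(a[1], b[1])


def find_free_box(bg_w, bg_h, fg_w, fg_h, existing, tries=50, rng=random):
    """Rejection-sample a position for an fg_w x fg_h patch.

    Returns (xmin, ymin, xmax, ymax) inside the background that overlaps none
    of `existing`, or None if the patch does not fit or no free spot is found.
    """
    if fg_w > bg_w or fg_h > bg_h:
        return None
    for _ in range(tries):
        x = rng.randint(0, bg_w - fg_w)  # randint is inclusive on both ends
        y = rng.randint(0, bg_h - fg_h)
        cand = (x, y, x + fg_w, y + fg_h)
        if not any(overlaps(cand, box) for box in existing):
            return cand
    return None
```

Returning None instead of forcing a placement is why fewer than the requested number of images may be produced.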

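    Output folder:

    • data/output

    Arguments:

    • gen_num: number of images to generate; the actual number can be smaller, because a background is skipped when the patch does not fit or no non-overlapping position is found;
    • cls_name: name of the class to augment;
    • tietudir: foreground folder holding the fire target images cropped from the annotated fire data (the output of step 2.1);
    • img_path: background image folder: other images, with or without fire targets;
    • xml_path: annotation folder for the background images;
    • save_path: output folder.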
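gen_num is used in integer arithmetic (`divmod` inside `get_shuffle_list`), so it must arrive as an int. argparse passes command-line values through as strings unless `type=int` is given; a non-string default such as 11 is used as-is, which is why the missing `type` only shows up when `--gen_num` is set explicitly. A minimal sketch of the two options, assuming nothing beyond the standard library (`build_parser` is a hypothetical name):

```python
import argparse


def build_parser():
    # Mirrors the --gen_num / --cls_name options of the script below;
    # type=int makes argparse convert the command-line string to an int.
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen_num', type=int, default=11, help='number of images to generate')
    parser.add_argument('--cls_name', default='fire', help='target class name')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args(['--gen_num', '20'])
    print(args.gen_num + 1)  # 21
```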

    Code:

    import os
    import time
    import random
    import argparse
    import xml.etree.ElementTree as ET
    
    from PIL import Image
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--tietudir', default='data/fire_cut/fire', help='foreground patch folder')
    parser.add_argument('--xml_path', default='data/fire_bg/Annotations', help='background annotations (existing boxes are avoided)')
    parser.add_argument('--img_path', default='data/fire_bg/JPEGImages', help='background image folder')
    parser.add_argument('--save_path', default='data/output', help='output folder')
    parser.add_argument('--gen_num', type=int, default=11, help='number of images to generate')
    parser.add_argument('--cls_name', default='fire', help='target class name')
    args = parser.parse_args()
    
    
    def creat_src_xml(width_ditu, height_ditu, box, save_path, name, cls_name):
        # Write a new Pascal VOC XML that contains only the pasted box.
        # `name` is the XML file name; the image saved next to it has the same stem and a .jpg suffix.
        xml_file = os.path.join(save_path, name)
        img_name = os.path.splitext(name)[0] + '.jpg'
        with open(xml_file, 'w', encoding='UTF-8') as x:
            x.write('<annotation>\n')
            x.write('    <folder>data</folder>\n')
            x.write('    <filename>' + img_name + '</filename>\n')
            x.write('    <path>' + os.path.join(save_path, img_name) + '</path>\n')
            x.write('    <source>\n')
            x.write('        <database>Unknown</database>\n')
            x.write('    </source>\n')
            x.write('    <size>\n')
            x.write('        <width>' + str(width_ditu) + '</width>\n')
            x.write('        <height>' + str(height_ditu) + '</height>\n')
            x.write('        <depth>3</depth>\n')
            x.write('    </size>\n')
            x.write('    <segmented>0</segmented>\n')
            x.write('    <object>\n')
            x.write('        <name>' + cls_name + '</name>\n')
            x.write('        <pose>Unspecified</pose>\n')
            x.write('        <truncated>0</truncated>\n')
            x.write('        <difficult>0</difficult>\n')
            x.write('        <bndbox>\n')
            x.write('            <xmin>' + str(int(box[0])) + '</xmin>\n')
            x.write('            <ymin>' + str(int(box[1])) + '</ymin>\n')
            x.write('            <xmax>' + str(int(box[2])) + '</xmax>\n')
            x.write('            <ymax>' + str(int(box[3])) + '</ymax>\n')
            x.write('        </bndbox>\n')
            x.write('    </object>\n')
            x.write('</annotation>\n')
    
    def creat_xml(box, save_path, copy_path, cls_name):
        # Copy the background's existing annotation (copy_path) and append the pasted box
        # as a new <object>, so the original boxes are preserved. Parsing the XML is more
        # robust than dropping the file's last line and assuming it is </annotation>.
        tree = ET.parse(copy_path)
        root = tree.getroot()
        filename = root.find('filename')
        if filename is not None:
            # Point the annotation at the new image saved next to save_path
            filename.text = os.path.splitext(os.path.basename(save_path))[0] + '.jpg'
        obj = ET.SubElement(root, 'object')
        ET.SubElement(obj, 'name').text = cls_name
        ET.SubElement(obj, 'pose').text = 'Unspecified'
        ET.SubElement(obj, 'truncated').text = '0'
        ET.SubElement(obj, 'difficult').text = '0'
        bndbox = ET.SubElement(obj, 'bndbox')
        for tag, value in zip(('xmin', 'ymin', 'xmax', 'ymax'), box):
            ET.SubElement(bndbox, tag).text = str(int(value))
        if hasattr(ET, 'indent'):  # Python 3.9+: pretty-print the output
            ET.indent(tree, space='    ')
        tree.write(save_path, encoding='UTF-8')
    
    def read_xml_box(xml_file):
        # Return every object as [class_name, xmin, ymin, xmax, ymax]
        # (note: a different order from the cropping script in 2.1)
        xml_anno = ET.parse(xml_file)
        result = []
        for obj in xml_anno.findall('object'):
            class_name = obj.find('name').text.strip()
            bndbox = obj.find('bndbox')
            xmin, ymin, xmax, ymax = (int(float(bndbox.find(tag).text))
                                      for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
            result.append([class_name, xmin, ymin, xmax, ymax])
        return result
    
    def compute_IOU(rec1, rec2):
        # Intersection over union of two boxes given as (xmin, ymin, xmax, ymax); 0 if they do not overlap
        left_column_max = max(rec1[0], rec2[0])
        right_column_min = min(rec1[2], rec2[2])
        up_row_max = max(rec1[1], rec2[1])
        down_row_min = min(rec1[3], rec2[3])
        if left_column_max >= right_column_min or down_row_min <= up_row_max:
            return 0
        else:
            s1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
            s2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
            s_cross = (down_row_min - up_row_max) * (right_column_min - left_column_max)
            return s_cross / (s1 + s2 - s_cross)
    
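    def random_box(max_x, max_y, width, height):
        # Pick a random top-left corner in [0, max_x] x [0, max_y] and return (xmin, ymin, xmax, ymax);
        # random.randint raises ValueError if max_x or max_y is negative (patch larger than background).
        xmin = random.randint(0, max_x)
        ymin = random.randint(0, max_y)
        return xmin, ymin, xmin + width, ymin + height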
    
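    def get_shuffle_list(img_path, gen_num):
        # Return gen_num file names from img_path in random order, repeating the
        # folder contents as needed so that every file is used about equally often.
        imgs = os.listdir(img_path)
        num = len(imgs)
        if not num:
            return None
        random.shuffle(imgs)
        times, remainder = divmod(gen_num, num)
        name_gen = imgs[:remainder]
        for _ in range(times):
            name_gen.extend(imgs)
        random.shuffle(name_gen)
        return name_gen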
    
    def process(gen_num, tietudir, img_path, xml_path, save_path, cls_name):
        os.makedirs(save_path, exist_ok=True)
        # Equally long, shuffled lists of foreground and background file names
        fg_gen = get_shuffle_list(tietudir, gen_num)
        bg_gen = get_shuffle_list(img_path, gen_num)
        if fg_gen is None or bg_gen is None:
            return
        for num, (fg, bg) in enumerate(zip(fg_gen, bg_gen)):
            fg_img = Image.open(os.path.join(tietudir, fg)).convert('RGB')
            bg_img = Image.open(os.path.join(img_path, bg)).convert('RGB')
            save_name = bg.rsplit('.', maxsplit=1)[0]
            bg_xml_path = os.path.join(xml_path, '{}.xml'.format(save_name))
            has_xml = os.path.exists(bg_xml_path)
            # Existing boxes [class_name, xmin, ymin, xmax, ymax] that the patch must avoid
            cls_boxes = read_xml_box(bg_xml_path) if has_xml else []

            fg_width, fg_height = fg_img.size  # patch size
            bg_width, bg_height = bg_img.size  # background size
            # If the patch is larger than the background, halve it once; skip if it still does not fit.
            # (Image.resize returns a new image, so the result must be assigned.)
            if fg_width > bg_width or fg_height > bg_height:
                fg_img = fg_img.resize((max(1, fg_width // 2), max(1, fg_height // 2)))
                fg_width, fg_height = fg_img.size
                if fg_width > bg_width or fg_height > bg_height:
                    continue

            # Try up to 50 random positions and accept the first one that overlaps no existing box
            box2 = None
            for _ in range(50):
                candidate = random_box(bg_width - fg_width, bg_height - fg_height, fg_width, fg_height)
                if all(compute_IOU(box[1:5], candidate) == 0 for box in cls_boxes):
                    box2 = candidate
                    break
            if box2 is None:
                continue  # no free position: fewer than gen_num images are produced

            bg_img.paste(fg_img, (box2[0], box2[1]))
            # Timestamp plus a per-run counter keeps names unique when a background is reused
            new_name = '{}_{}_{}'.format(save_name, str(time.time()).replace('.', ''), num)
            bg_img.save(os.path.join(save_path, new_name + '.jpg'))
            xml_save_path = os.path.join(save_path, new_name + '.xml')
            if has_xml:
                creat_xml(box2, xml_save_path, bg_xml_path, cls_name)
            else:
                creat_src_xml(bg_width, bg_height, box2, save_path, new_name + '.xml', cls_name)

    if __name__ == "__main__":
        process(args.gen_num, 
                args.tietudir, 
                args.img_path, 
                args.xml_path, 
                args.save_path,
                args.cls_name)
    

    Result (viewed in labelImg):
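Viewing samples in labelImg works for spot checks; for a large output folder, a scripted check is a useful complement. The sketch below is a hypothetical helper, not from the original article. It assumes the pasted box is the last `<object>` in each generated XML (both `creat_src_xml` and `creat_xml` add it at the end) and reports boxes outside the image and any overlap between the pasted box and the original ones.

```python
import xml.etree.ElementTree as ET


def check_generated_xml(xml_text):
    """Sanity-check a generated VOC annotation; returns a list of problems."""
    root = ET.fromstring(xml_text)
    w = int(root.find('size/width').text)
    h = int(root.find('size/height').text)
    boxes = []
    for obj in root.findall('object'):
        bb = obj.find('bndbox')
        boxes.append(tuple(int(float(bb.find(t).text))
                           for t in ('xmin', 'ymin', 'xmax', 'ymax')))
    problems = []
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        if not (0 <= x1 < x2 <= w and 0 <= y1 < y2 <= h):
            problems.append('object {} lies outside the {}x{} image'.format(i, w, h))
    if boxes:
        new = boxes[-1]  # assumed to be the pasted box
        for i, old in enumerate(boxes[:-1]):
            if min(new[2], old[2]) > max(new[0], old[0]) and min(new[3], old[3]) > max(new[1], old[1]):
                problems.append('pasted box overlaps object {}'.format(i))
    return problems
```

Running it over every XML in data/output and printing non-empty results would surface any file worth opening in labelImg.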


    Link to this article: https://www.haomeiwen.com/subject/luwqkdtx.html