美文网首页
生成yolov3-tensorflow版本需要的训练数据格式

生成yolov3-tensorflow版本需要的训练数据格式

作者: chunleiml | 来源:发表于2020-03-02 14:17 被阅读0次

    在已有训练所需的图片数据和含有标注信息的XML文件的情况下,生成yolov3-tensorflow训练所需的标注文件(每行格式:图片路径 xmin,ymin,xmax,ymax,class_id ...)
    https://github.com/YunYang1994/tensorflow-yolov3

    import xml.etree.ElementTree as ET
    import os
    
    # Class names handed to convert_xml_annotation as its `classes` argument.
    CLASSES = [
        'plate',
        'b-plate-number',
        'b-plate',
    ]
    
    def convert_xml_annotation(img_path, xml_path, classes):
        """Write one annotation line per XML file to ``./plate_all.txt``.

        Each output line has the form
        ``<image path> xmin,ymin,xmax,ymax,class_id ...`` with one
        comma-separated group per ``<object>`` element found in the XML file.
        An XML file without objects produces a line holding only the image path.

        Args:
            img_path: Directory holding the images. A trailing path separator
                is optional.
            xml_path: Directory holding the ``.xml`` annotation files. A
                trailing path separator is optional.
            classes: Class names. Currently unused: every box is written with
                class id 0. The parameter is kept so existing callers work.

        Raises:
            xml.etree.ElementTree.ParseError: If an annotation file is not
                well-formed XML.
            AttributeError: If an ``<object>`` has no ``bndbox`` element, or a
                coordinate element is missing or empty.
        """
        # A set makes the per-annotation image lookup O(1).
        image_names = set(os.listdir(img_path))
        xml_names = [entry for entry in os.listdir(xml_path)
                     if entry.endswith('.xml')]
        print("Total xml files : ", len(xml_names))
        with open("./plate_all.txt", 'w') as f:
            for xml_name in xml_names:
                # os.path.join works whether or not the directory argument
                # ends with a separator; plain `+` silently built wrong paths.
                root = ET.parse(os.path.join(xml_path, xml_name)).getroot()
                stem = os.path.splitext(xml_name)[0]

                # Prefer the .jpg image; otherwise fall back to a .png name
                # (the fallback is not checked for existence).
                if stem + '.jpg' in image_names:
                    filename = stem + '.jpg'
                else:
                    filename = stem + '.png'

                # The line starts with the image path, followed by one
                # "xmin,ymin,xmax,ymax,class_id" group per object.
                fields = [os.path.join(img_path, filename)]
                for obj in root.iter('object'):
                    cls_id = 0  # every object is written as class 0
                    bbox = obj.find('bndbox')
                    xmin = bbox.find('xmin').text.strip()
                    xmax = bbox.find('xmax').text.strip()
                    ymin = bbox.find('ymin').text.strip()
                    ymax = bbox.find('ymax').text.strip()
                    fields.append(','.join([xmin, ymin, xmax, ymax, str(cls_id)]))
                annotation = ' '.join(fields)
                print(annotation)
                f.write(annotation + "\n")
    
    # Input directories: the images and the XML annotation files.
    img_path = './plate_train_imgs/'
    xml_path = './plate_train_xmls/'
    convert_xml_annotation(img_path=img_path, xml_path=xml_path, classes=CLASSES)
    

    最后把生成的plate_all.txt随机划分为训练集和测试集(测试集占20%,其余为训练集)

    import random
    import numpy as np
    # Randomly split ./plate_all.txt: 20% of its lines go to the test file and
    # the remaining lines go to the training file.
    path = './plate_all.txt'
    train_path = './plate_train.txt'
    valid_path = './plate_test.txt'
    # Context managers guarantee every handle is closed (and flushed) even if
    # an error occurs part-way through.
    with open(path, 'r') as file:
        lines = file.readlines()
    # Shuffling the lines in place gives the same permutation as shuffling an
    # index array and indexing through it.
    random.shuffle(lines)
    a = int(len(lines) * 0.2)  # number of test lines, rounded down
    print(a)
    # NOTE(review): both outputs are opened in append mode, so re-running the
    # script adds a new split on top of the previous one - confirm intended.
    with open(train_path, 'a') as file_train, open(valid_path, 'a') as file_valid:
        for line in lines[:a]:
            print(line)
            file_valid.write(line)
        file_train.writelines(lines[a:])
    

    相关文章

      网友评论

          本文标题:生成yolov3-tensorflow版本需要的训练数据格式

          本文链接:https://www.haomeiwen.com/subject/uxgmkhtx.html