美文网首页
数据集处理工具

数据集处理工具

作者: 张亦风 | 来源:发表于2019-10-18 09:35 被阅读0次

    划分训练测试集

    import os
    import random
    # Shuffle the lines of ``full_path`` and split them 90% / 10% into a
    # training list and a test list.
    full_path="full.txt"
    train_path="save_dir/train.txt"
    test_path="save_dir/test.txt"
    with open(full_path,"r") as f1:
        # Drop the trailing newline of every entry; it is re-added on write.
        line=[i.rstrip("\n") for i in f1]
    random.shuffle(line)
    # Index of the first test entry: everything before it is training data.
    train_num=int(len(line)*0.9)
    # Each output file is opened once instead of once per written line.
    # Append mode is kept, so running the script again adds to existing lists.
    with open(train_path,"a+") as f:
        f.writelines(i+"\n" for i in line[:train_num])
    with open(test_path,"a+") as f:
        f.writelines(i+"\n" for i in line[train_num:])
    

    修改json文件

    import os,sys
    import json
    # Path of the JSON file that is passed to get_new_json() in the main guard.
    m_path='conf.json'
    def get_new_json(filepath):
        """Load a JSON file and return its content with two ``conf`` values replaced.

        The file at *filepath* is parsed, then ``["conf"][1]["value"]`` is set to
        0.99 and ``["conf"][2]["value"]`` to 0.75 on the parsed object.  The file
        on disk is not touched; the modified object is returned.
        """
        with open(filepath, 'rb') as json_file:
            config = json.load(json_file)
        for index, new_value in ((1, 0.99), (2, 0.75)):
            config["conf"][index]["value"] = new_value
        return config
    def rewrite_json_file(filepath,json_data):
        """Serialize *json_data* as JSON into *filepath*.

        The file is opened in ``'w'`` mode, so existing content is replaced.
        """
        with open(filepath, 'w') as target:
            json.dump(json_data, target)
    if __name__ == '__main__':
        # Load the configuration with the overridden values and store the
        # result in a separate file; the input file stays as it was.
        rewrite_json_file("/disk1/1.json", get_new_json(m_path))
    

    dict to json

    import json
    
    # Inner record, stored under the string key "0" of the 'name' entry.
    first_entry = {"name": 'ACME', "age": 16}
    data = {'name': {"0": first_entry}, 'shares': 100, 'price': 542.23}
    # Write the dictionary to 2.json, replacing any previous content.
    with open("2.json", "w") as out_file:
        json.dump(data, out_file)
    

    base.txt2yolo

    import os
    import cv2
    from random import randint
    
    # Convert the box annotations in ``base_labels_dir`` into YOLO label files
    # (class, centre x, centre y, width, height; all normalised by image size)
    # and append every annotated image path to a train or a test list.
    base_labels_dir = '/home/zhangyong/zhangyong/ssd/caffe/data/VOCdevkit/VOC2007/base'
    images_dir = '/home/zhangyong/zhangyong/ssd/caffe/data/VOCdevkit/VOC2007/JPEGImages'
    # Directory that receives the generated label files.
    yolo_labels_dir = '/home/zhangyong/zhangyong/ssd/caffe/data/VOCdevkit/VOC2007/Label'

    alllabels = os.listdir(base_labels_dir)
    allimages = os.listdir(images_dir)
    # Set copy so the per-image membership test does not scan the whole list.
    label_names = set(alllabels)

    # Both list files are opened in append mode and closed by the with-statement
    # even if a conversion step raises.
    with open('train3.txt','a+') as trainf, open('test3.txt','a+') as testf:
        for jpg in allimages:
            # Swap only the extension: str.replace('jpg','txt') would also
            # rewrite a "jpg" that occurs elsewhere in the file name.
            stem, ext = os.path.splitext(jpg)
            if ext != '.jpg':
                continue
            txt = stem + '.txt'
            if txt not in label_names:
                continue

            image_path = os.path.join(images_dir, jpg)
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                print('skip unreadable image: %s' % image_path)
                continue
            h, w = img.shape[:2]

            # randint(0, 10) includes both ends, so roughly 1 image in 11 is
            # sent to the test list.  Compare by value, not identity.
            if randint(0,10) == 1:
                testf.write('%s\n' % image_path)
            else:
                trainf.write('%s\n' % image_path)

            with open(os.path.join(base_labels_dir, txt), 'r') as basef:
                lines = basef.readlines()

            with open(os.path.join(yolo_labels_dir, txt), 'w') as labf:
                for line in lines:
                    fields = line.split()
                    if not fields:
                        # Blank line: nothing to convert.
                        continue
                    box = [int(i) for i in fields]
                    # Field 0 is the class id. Fields 2 and 4 are clamped
                    # against the width and fields 1 and 3 against the height,
                    # i.e. the line is read as "class ymin xmin ymax xmax"
                    # (NOTE(review): confirm against the tool that writes the
                    # base files). The values go into separate names so no
                    # coordinate is overwritten before it has been read.
                    xmin = max(0, box[2])
                    ymin = max(0, box[1])
                    xmax = min(w-1, box[4])
                    ymax = min(h-1, box[3])

                    # Box centre and size as fractions of the image size.
                    cx = (xmin+xmax)/2.0/w
                    cy = (ymin+ymax)/2.0/h
                    cw = float(xmax-xmin)/w
                    ch = float(ymax-ymin)/h

                    # The written class id is the input class id minus one.
                    labf.write('%d %.6f %.6f %.6f %.6f\n' % (box[0]-1,cx,cy,cw,ch))
    
    

    txt2xml

    import os
    from PIL import Image
    import cv2
    
    # Opening part of an annotation XML file.  It is filled with the % operator
    # from a mapping that provides the keys 'name', 'width' and 'height'.
    out0 ='''<?xml version="1.0" encoding="utf-8"?>
    <annotation>
        <folder>None</folder>
        <filename>%(name)s</filename>
        <source>
            <database>None</database>
            <annotation>None</annotation>
            <image>None</image>
            <flickrid>None</flickrid>
        </source>
        <owner>
            <flickrid>None</flickrid>
            <name>None</name>
        </owner>
        <segmented>0</segmented>
        <size>
            <width>%(width)d</width>
            <height>%(height)d</height>
            <depth>3</depth>
        </size>
    '''
    # One <object> element per bounding box.  It is filled with the % operator
    # from a mapping that provides 'class', 'xmin', 'ymin', 'xmax' and 'ymax'.
    out1 = '''  <object>
            <name>%(class)s</name>
            <pose>Unspecified</pose>
            <truncated>0</truncated>
            <difficult>0</difficult>
            <bndbox>
                <xmin>%(xmin)d</xmin>
                <ymin>%(ymin)d</ymin>
                <xmax>%(xmax)d</xmax>
                <ymax>%(ymax)d</ymax>
            </bndbox>
        </object>
    '''
    
    # Closing tag written after the last <object> element of a file.
    out2 = '''</annotation>
    '''
    def translate(lists):
        """Write an XML annotation file for every ``.jpg`` path in *lists*.

        For each path ending in ``.jpg``:

        * ``'darknet'`` in the path is replaced by ``'ssd/caffe'`` and the image
          is read to obtain its width and height;
        * the box file is looked up by replacing ``.jpg`` with ``.txt`` and
          ``'ImageSets'`` with ``'base'``;
        * the XML file is written to the path obtained by replacing
          ``'ImageSets'`` with ``'Annotations'`` and ``.jpg`` with ``.xml``.

        Every box line is split on whitespace.  Field 0 is the class id, and
        fields 1-4 are read as ymin, xmin, ymax, xmax.  Coordinates are clamped
        to the image, and a box that lies completely outside it is left out.

        Entries with another extension are ignored.  An image that cannot be
        read is reported with ``print`` and skipped.  Returns ``None``.
        """
        source = {}
        label = {}
        for jpg in lists:
            if os.path.splitext(jpg)[1] != '.jpg':
                continue
            print(jpg)
            jpg=jpg.replace('darknet','ssd/caffe')
            img=cv2.imread(jpg)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                print('skip unreadable image: %s' % jpg)
                continue
            h,w=img.shape[:2]

            # Read the boxes before the XML file is created, so a missing box
            # file does not leave a half-written annotation behind.
            txt=jpg.replace('.jpg','.txt').replace("ImageSets",'base')
            with open(txt,'r') as f:
                lines = [i.replace('\n','') for i in f.readlines()]
            print(lines)

            source['name'] = jpg.split('/')[-1]
            source['width'] = w
            source['height'] = h

            xml_path = jpg.replace('ImageSets','Annotations').replace('.jpg','.xml')
            # The with-statement closes (and flushes) the XML file; it used to
            # be left open.
            with open(xml_path, 'w') as fxml:
                fxml.write(out0 % source)

                for box in lines:
                    box = box.split()
                    if not box:
                        # Blank line: nothing to convert.
                        continue
                    label['class'] = int(box[0])

                    label['xmin'] = max(int(box[2]),0)
                    label['ymin'] = max(int(box[1]),0)
                    label['xmax'] = min(int(box[4]),w-1)
                    label['ymax'] = min(int(box[3]),h-1)

                    # After clamping, only these four conditions can still be
                    # true: the box starts past the right/bottom edge or ends
                    # before the left/top edge.
                    if label['xmin']>=w or label['ymin']>=h or label['xmax']<0 or label['ymax']<0:
                        continue

                    fxml.write(out1 % label)

                fxml.write(out2)
    
    if __name__ == '__main__':
        # Read the image list (one path per line, newline removed) and create
        # an annotation file for each entry.
        list_file = '/home/zhangyong/newdisk/company/week/codes/tools/train.txt'
        with open(list_file,'r') as f:
            image_paths = [row.replace('\n','') for row in f]

        translate(image_paths)
    
    

    相关文章

      网友评论

          本文标题:数据集处理工具

          本文链接:https://www.haomeiwen.com/subject/hbabmctx.html