美文网首页
yolo v3 源码阅读(2):数据格式与加载

yolo v3 源码阅读(2):数据格式与加载

作者: 寒夏凉秋 | 来源:发表于2018-10-19 20:15 被阅读0次

    load_data_in_thread 方法会创建一个线程,将数据加载到 args.d 指针所指向的缓冲区中

    // data.c
    /*
     * Start a thread that runs load_thread() on a heap copy of args.
     *
     * args is passed by value and would go out of scope when this function
     * returns, so it is copied to the heap; load_thread() frees that copy
     * when it finishes. Returns the new thread's handle (presumably joined
     * by the caller -- not visible here).
     *
     * NOTE(review): assumes error() does not return, as the original
     * pthread_create failure path already did.
     */
    pthread_t load_data_in_thread(load_args args)
    {
        pthread_t thread;
        struct load_args *ptr = calloc(1, sizeof *ptr);
        if (!ptr) error("Thread argument allocation failed");
        *ptr = args;
        if (pthread_create(&thread, 0, load_thread, ptr)) {
            // The thread never started, so nobody else will free the copy.
            free(ptr);
            error("Thread creation failed");
        }
        return thread;
    }
    

    新创建的线程执行 load_thread 方法来加载数据

    //data.c
    void *load_thread(void *ptr)
    {
        //printf("Loading data: %d\n", rand());
        load_args a = *(struct load_args*)ptr;
        if(a.exposure == 0) a.exposure = 1;
        if(a.saturation == 0) a.saturation = 1;
        if(a.aspect == 0) a.aspect = 1;
    
        if (a.type == OLD_CLASSIFICATION_DATA){
            *a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
        } else if (a.type == REGRESSION_DATA){
            *a.d = load_data_regression(a.paths, a.n, a.m, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        } else if (a.type == CLASSIFICATION_DATA){
            *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.center);
        } else if (a.type == SUPER_DATA){
            *a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
        } else if (a.type == WRITING_DATA){
            *a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
        } else if (a.type == ISEG_DATA){
            *a.d = load_data_iseg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.scale, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        } else if (a.type == INSTANCE_DATA){
            *a.d = load_data_mask(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.coords, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        } else if (a.type == SEGMENTATION_DATA){
            *a.d = load_data_seg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.scale);
        } else if (a.type == REGION_DATA){
            *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
        } else if (a.type == DETECTION_DATA){
            *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
        } else if (a.type == SWAG_DATA){
            *a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
        } else if (a.type == COMPARE_DATA){
            *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
        } else if (a.type == IMAGE_DATA){
            *(a.im) = load_image_color(a.path, 0, 0);
            *(a.resized) = resize_image(*(a.im), a.w, a.h);
        } else if (a.type == LETTERBOX_DATA){
            *(a.im) = load_image_color(a.path, 0, 0);
            *(a.resized) = letterbox_image(*(a.im), a.w, a.h);
        } else if (a.type == TAG_DATA){
            *a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        }
        free(ptr);
        return 0;
    }
    

    以上代码根据 args 的 type 属性决定调用哪个方法来加载数据;

    通过加断点可以发现,运行 yolo train 时调用的是 load_data_detection:

    // data.c
    /*
     * Build one batch of detection training data with random augmentation.
     *
     * n        number of images in this batch (rows of X and of y)
     * paths    candidate image paths
     * m        passed to get_random_paths() together with n; presumably the
     *          number of entries in paths -- TODO confirm
     * w, h     network input width/height; every image is drawn onto a
     *          w x h canvas, so each row of X holds h*w*3 floats
     * boxes    maximum number of ground-truth boxes per image; each row of y
     *          holds 5*boxes floats (x, y, w, h, class id per box)
     * classes  forwarded to fill_truth_detection()
     * jitter   fraction of the original width/height by which each side may
     *          be randomly perturbed when choosing the new aspect ratio
     * hue, saturation, exposure  forwarded to random_distort_image()
     *
     * Returns a data struct with shallow = 0. Its X rows point at the
     * augmented image buffers, which are not freed here.
     */
    data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure)
    {
        char **random_paths = get_random_paths(paths, n, m);
        int i;
        data d = {0};
        d.shallow = 0;

        d.X.rows = n;
        d.X.vals = calloc(d.X.rows, sizeof *d.X.vals);
        d.X.cols = h*w*3;

        d.y = make_matrix(n, 5*boxes);
        for(i = 0; i < n; ++i){
            image orig = load_image_color(random_paths[i], 0, 0);

            // w x h canvas filled with the constant 0.5, so regions not
            // covered by the placed image keep that value.
            image sized = make_image(w, h, orig.c);
            fill_image(sized, .5);

            // Aspect-ratio jitter: perturb the original width and height
            // independently by up to +/- jitter of their size. The two draws
            // are sequenced explicitly; as operands of one '/' expression
            // their evaluation order (and thus which draw goes to which
            // side) would be unspecified and compiler-dependent.
            float dw = jitter * orig.w;
            float dh = jitter * orig.h;
            float jittered_w = orig.w + rand_uniform(-dw, dw);
            float jittered_h = orig.h + rand_uniform(-dh, dh);
            float new_ar = jittered_w / jittered_h;
            // Random rescaling is disabled; the image always spans the full
            // network width or height.
            float scale = 1;

            float nw, nh;
            if(new_ar < 1){
                // Taller than wide: fit the height, derive the width.
                nh = scale * h;
                nw = nh * new_ar;
            } else {
                // At least as wide as tall: fit the width, derive the height.
                nw = scale * w;
                nh = nw / new_ar;
            }

            // Random offset of the nw x nh image inside the w x h canvas.
            float dx = rand_uniform(0, w - nw);
            float dy = rand_uniform(0, h - nh);
            // Presumably resizes orig to nw x nh and draws it at (dx, dy).
            place_image(orig, nw, nh, dx, dy, sized);

            random_distort_image(sized, hue, saturation, exposure);

            int flip = rand()%2;
            if(flip) flip_image(sized);
            d.X.vals[i] = sized.data;

            // Apply the same offset, scale and flip to the ground-truth boxes.
            fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, -dx/w, -dy/h, nw/w, nh/h);

            free_image(orig);
        }
        free(random_paths);
        return d;
    }
    
    
    /*
     * Read the ground-truth boxes for the image at path, transform them to
     * match the augmentation applied to the image, and write them to truth.
     *
     * The label path is derived from the image path by textual substitution:
     * "images", "JPEGImages" and "raw" become "labels", and the image
     * extension (.jpg, .jpeg, .png, .JPG, .JPEG) becomes ".txt".
     * The boxes are shuffled, corrected with (dx, dy, sx, sy, flip), and
     * truncated to at most num_boxes. Boxes with width or height below 0.001
     * are then dropped. Each kept box occupies 5 consecutive floats in truth:
     * x, y, w, h, class id. Slots after the last kept box are not written, so
     * the caller presumably passes a zeroed buffer -- verify.
     * classes is not used here.
     */
    void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
    {
        char labelpath[4096];
        // NOTE(review): the later calls pass labelpath as both input and
        // output; this relies on find_replace() tolerating aliasing -- confirm.
        find_replace(path, "images", "labels", labelpath);
        find_replace(labelpath, "JPEGImages", "labels", labelpath);

        find_replace(labelpath, "raw", "labels", labelpath);
        find_replace(labelpath, ".jpg", ".txt", labelpath);
        // Lowercase ".jpeg" is not matched by ".jpg"; without this line the
        // label path would still point at the image file itself.
        find_replace(labelpath, ".jpeg", ".txt", labelpath);
        find_replace(labelpath, ".png", ".txt", labelpath);
        find_replace(labelpath, ".JPG", ".txt", labelpath);
        find_replace(labelpath, ".JPEG", ".txt", labelpath);

        int count = 0;
        box_label *boxes = read_boxes(labelpath, &count);
        randomize_boxes(boxes, count);
        // Map the labels into the coordinates of the augmented image.
        correct_boxes(boxes, count, dx, dy, sx, sy, flip);
        // At most num_boxes boxes fit in the caller's truth row.
        if(count > num_boxes) count = num_boxes;

        int kept = 0;
        for (int i = 0; i < count; ++i) {
            float w = boxes[i].w;
            float h = boxes[i].h;
            // Drop degenerate boxes (presumably ones that correct_boxes
            // clipped to almost nothing).
            if (w < .001 || h < .001) continue;

            truth[kept*5+0] = boxes[i].x;
            truth[kept*5+1] = boxes[i].y;
            truth[kept*5+2] = w;
            truth[kept*5+3] = h;
            truth[kept*5+4] = boxes[i].id;
            ++kept;
        }

        free(boxes);
    }
    
    

    到此,我们发现 yolo 的 load_data :

    读取图片,按随机抖动后的宽高比缩放并放置到网络输入尺寸(w×h)的画布上,再通过随机调整色调、饱和度、曝光度以及随机翻转来增广数据,同时将 label box 的坐标变换到增广后的图像中,最后返回。

    附录:

    加载彩色图片:

    /* Load filename as a 3-channel image; w and h are handled by load_image(). */
    image load_image_color(char *filename, int w, int h)
    {
        const int channels = 3;
        return load_image(filename, w, h, channels);
    }
    /*
     * Load filename with c channels. OpenCV is used when built with OPENCV;
     * otherwise stb_image is used. If w and h are both non-zero and differ
     * from the loaded size, the image is resized to w x h and the original
     * buffer is freed. Passing 0 for either keeps the native size.
     */
    image load_image(char *filename, int w, int h, int c)
    {
    #ifdef OPENCV
        image img = load_image_cv(filename, c);
    #else
        image img = load_image_stb(filename, c);
    #endif

        int want_resize = w && h && (w != img.w || h != img.h);
        if (!want_resize) return img;

        image resized = resize_image(img, w, h);
        free_image(img);
        return resized;
    }
    

    darknet 所用的数据结构体:

    // matrix.h
    /*
     * A rows x cols matrix stored as an array of `rows` row pointers.
     * In the data loader each row is one sample: rows is the number of
     * samples loaded at once (reportedly batch * subdivisions -- confirm),
     * and cols is the per-sample size (e.g. h*w*3 for detection images).
     */
    typedef struct matrix{
        int rows;
        int cols;
        float **vals;
    } matrix;
    
    /* Image of width w, height h and c channels; pixels (presumably
     * w*h*c floats) are stored in data. */
    typedef struct {
        int w, h, c;
        float *data;
    } image;
    
    
    /*
     * One batch of loaded samples.
     *   w, h       presumably image width/height (not set by the detection
     *              loader shown in this article)
     *   X          inputs, one row per sample
     *   y          labels/targets, one row per sample
     *   shallow    set to 0 by load_data_detection; presumably tells the
     *              free routine whether the rows own their buffers -- confirm
     *   num_boxes, boxes  not used by the code shown here
     */
    typedef struct{
        int w;
        int h;
        matrix X;
        matrix y;
        int shallow;
        int *num_boxes;
        box **boxes;
    } data;
    
    
    load_args 结构体:
    /*
     * Parameters for one asynchronous load, consumed by load_thread().
     * Each loader reads only the fields it needs; `type` selects the loader.
     */
    typedef struct load_args{
        int threads;            /* not read by the code shown here */
        char **paths;           /* candidate sample paths */
        char *path;             /* single image path (IMAGE_DATA / LETTERBOX_DATA) */
        int n;                  /* number of samples to load */
        int m;                  /* forwarded with n; presumably the size of paths */
        char **labels;
        int h;                  /* target height */
        int w;                  /* target width */
        int out_w;
        int out_h;
        int nh;
        int nw;
        int num_boxes;          /* max boxes per image for detection loaders */
        int min;
        int max;
        int size;
        int classes;
        int background;
        int scale;
        int center;
        int coords;
        float jitter;
        float angle;
        float aspect;           /* 0 is replaced by 1 in load_thread() */
        float saturation;       /* 0 is replaced by 1 in load_thread() */
        float exposure;         /* 0 is replaced by 1 in load_thread() */
        float hue;
        data *d;                /* output of the data loaders */
        image *im;              /* output: original image (IMAGE/LETTERBOX_DATA) */
        image *resized;         /* output: resized image (IMAGE/LETTERBOX_DATA) */
        data_type type;         /* selects the loader in load_thread() */
        tree *hierarchy;        /* forwarded to load_data_augment() */
    } load_args;
    
    
    
    
    

    相关文章

      网友评论

          本文标题:yolo v3 源码阅读(2):数据格式与加载

          本文链接:https://www.haomeiwen.com/subject/pqhfzftx.html