美文网首页
COCO数据集API(单类map,recall,统计)

COCO数据集API(单类map,recall,统计)

作者: carry_xz | 来源:发表于2020-12-02 14:09 被阅读0次

    安装

    从 https://github.com/cocodataset/cocoapi 下载源码
    进入PythonAPI目录，执行 python setup.py install（也可直接用 pip install pycocotools 安装）

    类别统计：检查COCO格式的数据集

    import pycocotools.coco as COCO
    
    def check_coco_json(annot_path):
        """Print category / image / annotation statistics for a COCO json file.

        Also reports common annotation problems: category ids that are not
        1..N in file order, duplicate image ids and duplicate annotation ids.

        Args:
            annot_path: path to a COCO-format annotation json file.
        """
        coco = COCO.COCO(annot_path)
        cats = coco.loadCats(coco.getCatIds())
        cat_nms = [cat['name'] for cat in cats]
        print('-' * 40)
        print('COCO categories number: {}'.format(len(cats)))
        print('COCO categories: {}'.format(' |'.join(cat_nms)))

        # Category ids are expected to be 1..N in the order they appear in the
        # file (the order returned by getCatIds matters here).
        for expected, cid in enumerate(coco.getCatIds(), start=1):
            if cid != expected:
                print('\nWarning: category id is not right cid:{} != {}'.format(cid, expected))

        # getImgIds() returns the keys of COCO's id -> image index, in which
        # duplicate ids have already been merged, so count the raw list instead.
        imgID = [img['id'] for img in coco.dataset.get('images', [])]
        if len(set(imgID)) != len(imgID):
            print('Error: pic num is not equal to pic id numbers')
        else:
            print('\nAll img number:', len(imgID))

        # getAnnIds() without filters lists every annotation entry, duplicates included.
        annID = coco.getAnnIds()
        if len(set(annID)) != len(annID):
            print('Error: annID repeat!')
        else:
            print('All ann label number:', len(annID))

        # Per-category image and box counts. Iterating the category records
        # (not their names) keeps categories that share a name separate.
        print('-' * 40)
        print("{:<15} {:<6}  {:<10}  {:<8}".format('Categories', 'image_num', 'target_num', 'class_id'))
        for cat in cats:
            catId = [cat['id']]
            imgId = coco.getImgIds(catIds=catId)
            annId = coco.getAnnIds(imgIds=imgId, catIds=catId, iscrowd=None)
            print("{:<15} {:<6d}     {:<10d} {}".format(cat['name'], len(imgId), len(annId), catId))
    
    if __name__ == "__main__":
        import sys

        # The annotation path may be given as the first command-line argument;
        # otherwise the original hard-coded example path is used.
        annot_path = sys.argv[1] if len(sys.argv) > 1 else '/home/kcadmin/user/20200106/aa.json'
        check_coco_json(annot_path)
    
    

    使用cocoAPI计算mAP和recall

    #-*-coding:utf-8-*-
    # NOTE(review): an encoding declaration only takes effect on the first or
    # second line of a file; keep it there if this snippet is saved on its own.

    '''
    COCO API examples:
    1. Get the category ids for the given category names.
    2. Get the image ids that contain a single category.
    3. Get the image ids that contain any of several categories.
    4. Get the file name for an image id.
    5. Compute mAP from a ground-truth json file and a prediction json file.
    6. Compute per-category AP and recall from the ground truth and predictions.
    '''
    
    import pickle, json
    import numpy as np
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    import itertools
    from terminaltables import AsciiTable
    
    def read_pickle(pkl):
        """Return the object stored in the pickle file at path *pkl*.

        Warning: unpickling can execute arbitrary code, so only call this on
        files from a trusted source.
        """
        with open(pkl, mode='rb') as handle:
            return pickle.load(handle)
    
    def read_json(json_pth):
        """Load and return the parsed contents of the JSON file at *json_pth*.

        The file is read as bytes so that the json module detects the encoding
        itself (UTF-8, UTF-16 or UTF-32, with or without a BOM) instead of
        relying on the platform's locale encoding, which breaks on non-ASCII
        category names on e.g. Windows.
        """
        with open(json_pth, 'rb') as f:
            return json.load(f)
    
    det_json = 'data/det.json'
    gt_json = 'data/gt.json'
    CLASSES = ('A','B','C')
    class_num = len(CLASSES)

    cocoGt = COCO(gt_json)

    # Ids of every image in the ground-truth file.
    all_id = cocoGt.getImgIds()

    # Category ids for the names listed in CLASSES.
    catIds = cocoGt.getCatIds(catNms=list(CLASSES))

    # Ids of every image containing at least one of those categories.
    # getImgIds(catIds=[a, b]) would return only images containing *all* of
    # them, so the union is built one category at a time. The result is
    # sorted so that imgid_list[0] below is deterministic.
    imgid_set = set()
    for id_c in catIds:
        imgid_set.update(cocoGt.getImgIds(catIds=id_c))
    imgid_list = sorted(imgid_set)

    # File name of the first such image. pycocotools' COCO class provides
    # loadImgs (load_imgs only exists on mmdetection's wrapper class).
    if imgid_list:
        img_info = cocoGt.loadImgs([imgid_list[0]])[0]
        fname = img_info['file_name']
    else:
        img_info = fname = None
        print('Warning: no image contains any of the categories {}'.format(CLASSES))
    
    
    # Compute mAP from the ground-truth json file and the prediction json file.
    # The settings are repeated so this snippet also runs on its own.
    det_json = 'data/det.json'
    gt_json = 'data/gt.json'
    CLASSES = ('A','B','C')
    class_num = len(CLASSES)
    cocoGt = COCO(gt_json)
    cocoDt = cocoGt.loadRes(det_json)
    cocoEval = COCOeval(cocoGt, cocoDt, "bbox")
    # IoU thresholds 0.50, 0.55, ..., 0.95.
    cocoEval.params.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
    cocoEval.params.maxDets = list((100, 300, 1000))
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()

    # precision: (iou, recall, cls, area range, max dets), TP/(TP+FP)
    # recall:    (iou, cls, area range, max dets),         TP/(TP+FN)
    precisions = cocoEval.eval['precision']
    recalls = cocoEval.eval['recall']


    def _valid_mean(values):
        """Return the mean of the entries of *values* that are > -1, or nan if none are.

        Entries equal to -1 mark missing data (e.g. a category with no ground
        truth) and are excluded, the same filter COCOeval.summarize() applies.
        """
        valid = np.asarray(values).ravel()
        valid = valid[valid > -1]
        return float(np.mean(valid)) if valid.size else float('nan')


    def _print_table(rows, item_headers, max_columns=6):
        """Print *rows* as an AsciiTable, placing several categories side by side.

        Args:
            rows: list of [category, AP, Recall] string lists, one per category.
            item_headers: column titles for one category's cells.
            max_columns: maximum cells per table line; should be a multiple of
                len(item_headers).
        """
        item_num = len(item_headers)
        # max(..., 1) keeps a header row (and a non-zero slice step) when rows is empty.
        num_columns = min(max_columns, max(len(rows), 1) * item_num)
        results_flatten = list(itertools.chain(*rows))
        headers = list(item_headers) * (num_columns // item_num)
        results_2d = itertools.zip_longest(*[
            results_flatten[i::num_columns]
            for i in range(num_columns)
        ])
        table_data = [headers] + list(results_2d)
        print('\n' + AsciiTable(table_data).table)


    iou_ = cocoEval.params.iouThrs[0]
    print('\nIOU:{} MAP:{:.3f} Recall:{:.3f}'.format(
        iou_,
        _valid_mean(precisions[0, :, :, 0, -1]),
        _valid_mean(recalls[0, :, 0, -1])))

    # Per-category AP and recall, adapted from
    # https://github.com/facebookresearch/detectron2/
    # The class axis of precision/recall follows cocoEval.params.catIds, so each
    # name is looked up from the GT file by id instead of assuming the GT
    # categories appear in the same order (and number) as CLASSES.
    results_per_category = []
    results_per_category_iou50 = []
    for idx, catId in enumerate(cocoEval.params.catIds):
        name = cocoGt.loadCats([int(catId)])[0]['name']
        ap = _valid_mean(precisions[:, :, idx, 0, -1])
        rec = _valid_mean(recalls[:, idx, 0, -1])
        ap_50 = _valid_mean(precisions[0, :, idx, 0, -1])
        rec_50 = _valid_mean(recalls[0, idx, 0, -1])
        results_per_category.append([f'{name}', f'{ap:0.3f}', f'{rec:0.3f}'])
        results_per_category_iou50.append([f'{name}', f'{ap_50:0.3f}', f'{rec_50:0.3f}'])

    _print_table(results_per_category, ['category', 'AP', 'Recall'])
    _print_table(results_per_category_iou50,
                 ['category', 'AP{}'.format(iou_), 'Recall{}'.format(iou_)])
    

    相关文章

      网友评论

          本文标题:COCO数据集API(单类map,recall,统计)

          本文链接:https://www.haomeiwen.com/subject/iqblyctx.html