美文网首页
语义分割-筛选ADE数据并重新赋值标签

语义分割-筛选ADE数据并重新赋值标签

作者: su945 | 来源:发表于2020-05-17 23:17 被阅读0次

    筛选ADE数据并重新赋值标签

    ADE20K 中共有 150 个标签类别,而目前项目主要针对室内场景的语义分割,因此需要筛选出特定的类别,并对图像中的标签重新编号。原始 label 图像中,像素值即为类别序号,因此需要按照所选类别的顺序重新赋值(第 k 个选中的类别赋值为 k),其余类别的像素值设为 0,即背景类。

    #-*- coding:utf-8 -*-
    # author:suyuan
    # datetime:2020/5/6 上午10:18
    # software: PyCharm
    
    import glob
    import os
    from PIL import Image
    import cv2 as cv
    import numpy as np
    import json
    
    ####解析odgt数据
    def parse_input_list(odgt, max_sample=-1, start_idx=-1, end_idx=-1):
        """Load a sample list from an .odgt file, or pass an existing list through.

        Args:
            odgt: Either a list of already-parsed samples, or the path of an
                .odgt file containing one JSON object per line.
            max_sample: If > 0, keep only the first ``max_sample`` samples.
            start_idx, end_idx: If both are >= 0, additionally keep only
                ``list_sample[start_idx:end_idx]`` (used to split the list).

        Returns:
            The (possibly truncated / sliced) list of samples.

        Raises:
            TypeError: if ``odgt`` is neither a list nor a str path.
            AssertionError: if no samples remain after truncation/slicing.
        """
        if isinstance(odgt, list):
            list_sample = odgt
        elif isinstance(odgt, str):
            # `with` closes the file deterministically; blank lines (e.g. a
            # trailing newline at EOF) are skipped because json.loads rejects them.
            with open(odgt, 'r', encoding='utf-8') as f:
                list_sample = [json.loads(line) for line in f if line.strip()]
        else:
            # Previously this fell through and raised UnboundLocalError below.
            raise TypeError(
                'odgt must be a list or a path string, got {}'.format(type(odgt).__name__))

        if max_sample > 0:
            list_sample = list_sample[0:max_sample]
        if start_idx >= 0 and end_idx >= 0:  # divide file list
            list_sample = list_sample[start_idx:end_idx]
        # Number of samples that remain.
        num_sample = len(list_sample)
        assert num_sample > 0, 'no samples left after loading/slicing'
        print('# samples: {}'.format(num_sample))
        return list_sample
    
    ###筛选出一些特定的类别
    def select_ADEcategory(lable_list, src_lable_path, dst_lable_path):
        """Reset the listed class ids to 0 in every label PNG of a directory.

        Despite the name, this *removes* the classes in ``lable_list``. Every
        pixel whose value is one of those ids becomes 0 (background), and all
        other pixel values are left unchanged. Each result is written to
        ``dst_lable_path`` as ``<stem>.png``, where ``<stem>`` is the source
        file name up to its first '.'.

        Args:
            lable_list: Iterable of class ids (pixel values) to reset to 0.
            src_lable_path: Directory containing the source ``*.png`` label maps.
            dst_lable_path: Output directory; created if it does not exist.

        Raises:
            OSError: if a label image cannot be read or written.
        """
        os.makedirs(dst_lable_path, exist_ok=True)
        ids_to_clear = np.asarray(list(lable_list))

        src_lable_list = glob.glob(os.path.join(src_lable_path, '*.png'))
        for label_file in src_lable_list:
            # os.path.basename is portable; split('/') broke on Windows paths.
            img_id = os.path.basename(label_file).split('.')[0]

            label = cv.imread(label_file, cv.IMREAD_GRAYSCALE)
            if label is None:  # cv.imread signals failure by returning None
                raise OSError('failed to read label image: {}'.format(label_file))

            # Vectorised equivalent of the former per-pixel / per-id triple loop.
            label[np.isin(label, ids_to_clear)] = 0

            # Save the label image with the updated class ids.
            rewrite_lable_file_path = os.path.join(dst_lable_path, img_id + '.png')
            if not cv.imwrite(rewrite_lable_file_path, label):
                raise OSError('failed to write label image: {}'.format(rewrite_lable_file_path))
            print(rewrite_lable_file_path)
    
    ###统计类别图像中类别的个数
    #根据odgt文件进行读取
    def count_category_odgt(root_dataset, odgt_file, num_classes=151):
        """Count, for each class id, how many label images listed in an .odgt contain it.

        Each sample's ``fpath_segm`` (joined onto ``root_dataset``) is read as a
        grayscale label image. A class is counted once per image in which it
        appears, however many pixels it covers. The function then prints, for
        each class 1..num_classes-1, its image count and its share of the total
        over those classes. Class 0 is excluded from the printout and from the
        total.

        Args:
            root_dataset: Directory that ``fpath_segm`` paths are relative to.
            odgt_file: .odgt path or list of samples (passed to parse_input_list).
            num_classes: Length of the count table; pixel values must be smaller.

        Returns:
            np.ndarray (int32, length ``num_classes``) of per-class image counts.

        Raises:
            OSError: if a label image cannot be read.
            IndexError: if an image contains a pixel value >= ``num_classes``.
        """
        list_sample = parse_input_list(odgt_file)
        count_list = np.zeros([num_classes], dtype=np.int32)
        for sample in list_sample:
            segm_path = os.path.join(root_dataset, sample['fpath_segm'])
            label = cv.imread(segm_path, cv.IMREAD_GRAYSCALE)
            if label is None:  # cv.imread signals failure by returning None
                raise OSError('failed to read label image: {}'.format(segm_path))
            # np.unique lists each class present in this image exactly once,
            # replacing the per-pixel presence loop.
            count_list[np.unique(label)] += 1

        # Print the count and share of every non-background class.
        count = np.sum(count_list[1:])
        for i in range(1, num_classes):
            # Guard against 0/0 (nan + RuntimeWarning) when no class was found.
            ratio = count_list[i] / count if count else 0.0
            print('类别:', i, ',数量:', count_list[i], ',占比:', ratio)
        return count_list
    
    
    
    
    #直接读取路径下的图片
    def count_category(src_lable_path, num_classes=151):
        """Count, for each class id, how many ``*.png`` label images in a directory contain it.

        Each image is read as grayscale. A class is counted once per image in
        which it appears. The count for every id 0..num_classes-1 is printed.

        Args:
            src_lable_path: Directory containing the label PNGs.
            num_classes: Length of the count table; pixel values must be smaller.

        Returns:
            np.ndarray (int32, length ``num_classes``) of per-class image counts.

        Raises:
            OSError: if a label image cannot be read.
            IndexError: if an image contains a pixel value >= ``num_classes``.
        """
        src_lable_list = glob.glob(os.path.join(src_lable_path, '*.png'))

        count_list = np.zeros([num_classes], dtype=np.int32)
        for label_file in src_lable_list:
            label = cv.imread(label_file, cv.IMREAD_GRAYSCALE)
            if label is None:  # cv.imread signals failure by returning None
                raise OSError('failed to read label image: {}'.format(label_file))
            # Each class present in the image is counted once (vectorised).
            count_list[np.unique(label)] += 1

        # Print the per-class counts.
        for i in range(num_classes):
            print('类别:', i, ',数量:', count_list[i])
        return count_list
    
    
    def select_20_category(odgt_file, save_path, select_id, root_dataset=None):
        """Remap label images so only the selected classes remain, renumbered 1..K.

        For every sample in ``odgt_file``, the image at
        ``root_dataset/sample['fpath_segm']`` is read as grayscale. A pixel
        equal to ``select_id[k]`` becomes ``k + 1``, and every other pixel
        becomes 0 (background). If an id appears more than once in
        ``select_id``, its last position wins. The result is written into
        ``save_path`` under the original file name.

        Args:
            odgt_file: .odgt path or list of samples (passed to parse_input_list).
            save_path: Output directory; created if it does not exist.
            select_id: Ordered sequence of original class ids to keep.
            root_dataset: Directory that ``fpath_segm`` is relative to. If None,
                the module-level ``root_dataset`` global is used (the old,
                implicit behaviour), so existing callers keep working.

        Raises:
            ValueError: if no ``root_dataset`` is given or defined globally.
            OSError: if a label image cannot be read or written.
        """
        if root_dataset is None:
            # The original code silently depended on this global, which only
            # exists when the script's __main__ section has run.
            root_dataset = globals().get('root_dataset')
            if root_dataset is None:
                raise ValueError('root_dataset must be passed (no module-level default found)')

        print('选择类别数:', len(select_id))
        list_sample = parse_input_list(odgt_file)

        # Lookup table: original uint8 pixel value -> new class index (0 = background).
        # Filling in order makes a later duplicate overwrite an earlier one. Ids
        # outside 0..255 can never occur in a uint8 image and are skipped.
        lut = np.zeros(256, dtype=np.uint8)
        for new_id, old_id in enumerate(select_id, start=1):
            if 0 <= old_id < 256:
                lut[old_id] = new_id

        os.makedirs(save_path, exist_ok=True)
        for sample in list_sample:
            segm_path = os.path.join(root_dataset, sample['fpath_segm'])
            label = cv.imread(segm_path, cv.IMREAD_GRAYSCALE)
            if label is None:  # cv.imread signals failure by returning None
                raise OSError('failed to read label image: {}'.format(segm_path))

            # Vectorised remap replacing the O(h*w*len(select_id)) pixel loop.
            label = lut[label]

            img_save_path = os.path.join(save_path, os.path.basename(segm_path))
            if not cv.imwrite(img_save_path, label):
                raise OSError('failed to write label image: {}'.format(img_save_path))
    
    
    
    
    if __name__ == '__main__':
        # Class ids that select_ADEcategory would reset to 0 (that step is disabled below).
        removed_ids = [1, 13, 136]

        # --- Optional step: strip `removed_ids` from the raw ADE annotations ---
        root_path = "/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016"
        train_label_path = os.path.join(root_path, "annotations/training")
        select_train_lable_path = os.path.join(root_path, "annotations_select0506/training")
        val_lable_path = os.path.join(root_path, "annotations/validation")
        select_val_lable_path = os.path.join(root_path, "annotations_select0506/validation")
        # select_ADEcategory(removed_ids, train_label_path, select_train_lable_path)
        # select_ADEcategory(removed_ids, val_lable_path, select_val_lable_path)

        # --- Optional step: per-class image counts over a directory of label PNGs ---
        src_lable_path = '/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016/test'
        # count_category(src_lable_path)

        # `root_dataset` must keep this exact module-level name:
        # select_20_category reads it as a global.
        root_dataset = '/media/suyuan/U/data/segment/ADE20K'
        odgt_file = '/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016/validation_select.odgt'
        # --- Optional step: per-class image counts driven by the .odgt sample list ---
        # count_category_odgt(root_dataset, odgt_file)

        # --- Active step: keep 20 classes, renumber them 1..20, everything else -> 0 ---
        # NOTE(review): samples come from validation_select.odgt but are written into
        # a folder named "training" — confirm this is intended.
        save_path = '/media/suyuan/U/data/segment/ADE20K/training'
        kept_ids = [8, 9, 11, 13, 15, 16, 18, 19, 20, 23, 24, 25, 28, 34, 43, 45, 48, 58, 83, 90]
        select_20_category(odgt_file, save_path, kept_ids)

        # --- Active step: report percentages from a saved per-class count table ---
        class_counts = np.loadtxt(os.path.join(root_dataset, 'ADEChallengeData2016/temo.txt'))
        # NOTE(review): total_count is computed but never used; the percentages below
        # divide by the hard-coded 9755 instead — confirm which denominator is intended.
        total_count = np.sum(class_counts[1:])
        report_ids = [1, 8, 9, 10, 11, 13, 15, 16, 18, 19, 20, 23, 24, 25, 28,
                      34, 38, 45, 48, 58, 64, 82, 83, 88, 90, 146]
        running_percent = 0
        for class_id in report_ids:
            percent = 100 * class_counts[class_id] / 9755
            running_percent += percent
            print('id:', class_id, ',占比', percent, '%')
        print(running_percent, '%')
    
    

    相关文章

      网友评论

          本文标题:语义分割-筛选ADE数据并重新赋值标签

          本文链接:https://www.haomeiwen.com/subject/etkoghtx.html