筛选ADE数据并重新赋值标签
ADE数据集共有150个标签类别,而当前项目主要针对室内场景的语义分割,因此需要筛选出特定的类别,并对图像中的标签重新编号。原始label图像中,每个像素值即为该像素所属类别的编号;筛选时,选定的类别按新的顺序重新赋值,其余类别的像素值统一设为0(即背景类)。
#-*- coding:utf-8 -*-
# author:suyuan
# datetime:2020/5/6 上午10:18
# software: PyCharm
import glob
import os
from PIL import Image
import cv2 as cv
import numpy as np
import json
####解析odgt数据
def parse_input_list(odgt, max_sample=-1, start_idx=-1, end_idx=-1):
    """Load the sample records described by an ``.odgt`` index.

    Args:
        odgt: Either an already-parsed list of sample records, or the path
            (``str`` or ``os.PathLike``) of a text file holding one JSON
            object per line. Blank lines in the file are skipped.
        max_sample: When positive, keep only the first ``max_sample`` records.
        start_idx: Together with ``end_idx``, selects
            ``[start_idx:end_idx]`` from the (possibly truncated) list.
            Applied only when both indices are non-negative.
        end_idx: See ``start_idx``.

    Returns:
        The list of selected sample records.

    Raises:
        TypeError: If ``odgt`` is neither a list nor a path.
        AssertionError: If no sample remains after filtering.
    """
    if isinstance(odgt, list):
        list_sample = odgt
    elif isinstance(odgt, (str, os.PathLike)):
        # Context manager so the file handle is closed deterministically.
        with open(odgt, 'r', encoding='utf-8') as odgt_lines:
            list_sample = [
                json.loads(line) for line in odgt_lines if line.strip()
            ]
    else:
        # Previously this fell through and failed later on an unbound local.
        raise TypeError(
            'odgt must be a list of samples or a file path, got {}'.format(
                type(odgt).__name__))

    if max_sample > 0:
        list_sample = list_sample[0:max_sample]
    if start_idx >= 0 and end_idx >= 0:  # divide file list
        list_sample = list_sample[start_idx:end_idx]

    num_sample = len(list_sample)
    assert num_sample > 0, 'no samples left after filtering'
    print('# samples: {}'.format(num_sample))
    return list_sample
###筛选出一些特定的类别
def select_ADEcategory(lable_list, src_lable_path, dst_lable_path):
    """Zero out the listed class ids in every PNG label map of a directory.

    Every ``*.png`` directly inside ``src_lable_path`` is read as a
    single-channel grayscale image. Pixels whose value appears in
    ``lable_list`` are set to 0; all other pixels keep their value. The
    result is written to ``dst_lable_path`` under the same base name, and
    the output path is printed.

    Args:
        lable_list: Iterable of class ids whose pixels are reset to 0.
        src_lable_path: Directory holding the source label PNGs.
        dst_lable_path: Directory receiving the rewritten label PNGs;
            created if it does not exist.

    Raises:
        IOError: If a label image cannot be read or written.
    """
    # cv.imwrite does not create missing directories, so make sure the
    # destination exists before any image is processed.
    os.makedirs(dst_lable_path, exist_ok=True)
    ids_to_clear = list(lable_list)

    for label_file in glob.glob(os.path.join(src_lable_path, '*.png')):
        label = cv.imread(label_file, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('cannot read label image: {}'.format(label_file))

        # One vectorised mask replaces the per-pixel, per-id Python loops.
        label[np.isin(label, ids_to_clear)] = 0

        # basename/splitext keep dotted file names intact and do not depend
        # on '/' being the path separator.
        img_id = os.path.splitext(os.path.basename(label_file))[0]
        rewrite_lable_file_path = os.path.join(dst_lable_path, img_id + '.png')

        if not cv.imwrite(rewrite_lable_file_path, label):
            raise IOError(
                'cannot write label image: {}'.format(rewrite_lable_file_path))
        print(rewrite_lable_file_path)
###统计类别图像中类别的个数
#根据odgt文件进行读取
def count_category_odgt(root_dataset, odgt_file):
    """Print, per class id, how many label images of an odgt list contain it.

    For each sample in ``odgt_file`` the label map at
    ``root_dataset / sample['fpath_segm']`` is read as grayscale. A class
    id is counted once per image in which at least one pixel has that
    value. For ids 1..150 the function prints the image count and that
    count divided by the sum of the counts of ids 1..150 (id 0 is left
    out of both the listing and the denominator).

    Args:
        root_dataset: Directory that the ``fpath_segm`` entries are
            relative to.
        odgt_file: Path of the odgt index (or an already-parsed list),
            forwarded to ``parse_input_list``.

    Raises:
        IOError: If a label image cannot be read.
        IndexError: If a label image contains a value above 150.
    """
    num_ids = 151  # ids 0..150
    list_sample = parse_input_list(odgt_file)
    count_list = np.zeros([num_ids], dtype=np.int32)

    for sample in list_sample:
        segm_path = os.path.join(root_dataset, sample['fpath_segm'])
        label = cv.imread(segm_path, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('cannot read label image: {}'.format(segm_path))
        # np.unique yields each id present in the image exactly once, so a
        # single fancy-indexed increment replaces the per-pixel loop.
        count_list[np.unique(label)] += 1

    count = np.sum(count_list[1:])
    for i in range(1, num_ids):
        # Guard the ratio so an all-background dataset does not divide by 0.
        ratio = count_list[i] / count if count else 0.0
        print('类别:', i, ',数量:', count_list[i], ',占比:', ratio)
#直接读取路径下的图片
def count_category(src_lable_path):
    """Print, per class id, how many label PNGs in a directory contain it.

    Every ``*.png`` directly inside ``src_lable_path`` is read as
    grayscale. A class id is counted once per image in which at least one
    pixel has that value. One line is printed for each id 0..150 with the
    number of images containing it.

    Args:
        src_lable_path: Directory holding the label PNGs.

    Raises:
        IOError: If a label image cannot be read.
        IndexError: If a label image contains a value above 150.
    """
    num_ids = 151  # ids 0..150
    count_list = np.zeros([num_ids], dtype=np.int32)

    for label_file in glob.glob(os.path.join(src_lable_path, '*.png')):
        label = cv.imread(label_file, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('cannot read label image: {}'.format(label_file))
        # np.unique yields each id present in the image exactly once, so a
        # single fancy-indexed increment replaces the per-pixel loop.
        count_list[np.unique(label)] += 1

    for i in range(num_ids):
        print('类别:', i, ',数量:', count_list[i])
def select_20_category(odgt_file, save_path, select_id, root_dataset=None):
    """Remap label images so that only the selected class ids remain.

    For every sample listed in ``odgt_file`` the label map at
    ``root_dataset / sample['fpath_segm']`` is read as grayscale. A pixel
    whose value equals ``select_id[k]`` becomes ``k + 1``; every other
    pixel becomes 0. If an id occurs more than once in ``select_id``, its
    last position is used. The remapped image is written to ``save_path``
    under the original file name.

    Args:
        odgt_file: Path of the odgt index (or an already-parsed list),
            forwarded to ``parse_input_list``.
        save_path: Directory receiving the remapped label images; created
            if it does not exist.
        select_id: Sequence of class ids to keep, in their new order.
        root_dataset: Directory that the ``fpath_segm`` entries are
            relative to. When omitted, the module-level ``root_dataset``
            variable is used, which is what this function read implicitly
            before the parameter existed.

    Raises:
        ValueError: If no dataset root is available, or if ``select_id``
            holds more than 255 entries (new ids must fit in 8 bits).
        IOError: If a label image cannot be read or written.
    """
    if root_dataset is None:
        # Backward compatibility with callers that only set the global.
        root_dataset = globals().get('root_dataset')
        if root_dataset is None:
            raise ValueError(
                'root_dataset must be passed or defined at module level')
    if len(select_id) > 255:
        raise ValueError('at most 255 classes fit in an 8-bit label image')

    print('选择类别数:', len(select_id))
    list_sample = parse_input_list(odgt_file)

    # Lookup table indexed by the old 8-bit pixel value; entries left at 0
    # send every unselected class to background.
    lut = np.zeros(256, dtype=np.uint8)
    for new_id, old_id in enumerate(select_id, start=1):
        if 0 <= old_id <= 255:  # other values can never match a pixel
            lut[old_id] = new_id

    os.makedirs(save_path, exist_ok=True)
    for sample in list_sample:
        segm_path = os.path.join(root_dataset, sample['fpath_segm'])
        label = cv.imread(segm_path, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('cannot read label image: {}'.format(segm_path))

        remapped = lut[label]

        img_save_path = os.path.join(save_path, os.path.basename(segm_path))
        if not cv.imwrite(img_save_path, remapped):
            raise IOError('cannot write label image: {}'.format(img_save_path))
if __name__ == '__main__':
    # Class ids used by the (currently disabled) select_ADEcategory runs.
    lable_list = [1, 13, 136]

    # Source and destination label directories for those runs.
    root_path = "/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016"
    train_label_path = os.path.join(root_path, "annotations/training")
    select_train_lable_path = os.path.join(
        root_path, "annotations_select0506/training")
    # select_ADEcategory(lable_list, train_label_path, select_train_lable_path)
    val_lable_path = os.path.join(root_path, "annotations/validation")
    select_val_lable_path = os.path.join(
        root_path, "annotations_select0506/validation")
    # select_ADEcategory(lable_list, val_lable_path, select_val_lable_path)

    # Per-class image counts, read straight from a directory of label images.
    src_lable_path = '/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016/test'
    # count_category(src_lable_path)

    # Per-class image counts, driven by an odgt index.
    # NOTE: root_dataset stays a module-level name; select_20_category reads it.
    root_dataset = '/media/suyuan/U/data/segment/ADE20K'
    odgt_file = '/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016/validation_select.odgt'
    # count_category_odgt(root_dataset, odgt_file)

    # Keep the chosen classes and renumber them.
    save_path = '/media/suyuan/U/data/segment/ADE20K/training'
    select_id = [8, 9, 11, 13, 15, 16, 18, 19, 20, 23,
                 24, 25, 28, 34, 43, 45, 48, 58, 83, 90]
    select_20_category(odgt_file, save_path, select_id)

    # Report each listed class as a percentage of a fixed total of 9755.
    class_counts = np.loadtxt(
        os.path.join(root_dataset, 'ADEChallengeData2016/temo.txt'))
    count = np.sum(class_counts[1:])  # computed but not used below
    report_ids = [1, 8, 9, 10, 11, 13, 15, 16, 18, 19, 20, 23, 24,
                  25, 28, 34, 38, 45, 48, 58, 64, 82, 83, 88, 90, 146]
    total_percent = 0
    for class_id in report_ids:
        percent = 100 * class_counts[class_id] / 9755
        total_percent += percent
        print('id:', class_id, ',占比', percent, '%')
    print(total_percent, '%')
网友评论