美文网首页pytorch学习笔记深度学习目标跟踪&&目标检测
pytroch学习(二十五)—目标检测(数据集制作)

pytroch学习(二十五)—目标检测(数据集制作)

作者: 侠之大者_7d3f | 来源:发表于2019-01-23 12:06 被阅读0次

    前言

    之前,测试通了pytroch版的yolo-v2/v3, ssd-mobilenetv1/v2目标检测代码。 相对于测试,如何用自己的数据训练一个目标检测模型才更令人兴奋。俗话曰:兵马未动,粮草先行, 在训练之前,首先需要准备好训练数据。

    在许多例子中,一般都用VOC, COCO格式的数据集进行训练和测试。对于我们自己的数据,一般不是VOC/COCO格式的数据,所以一个比较笨的方法就是写一个脚本进行数据格式转换,再不济可以手动创建文件夹,直接把相应的数据复制到制定的目录,这样很麻烦。麻烦的对方主要在于:1. VOC中标签都是1张图像对应一个xml文件, xml结构数据本身相对解析麻烦,不如JSON,YAML轻巧。 2. 电脑中需要将原始数据复制2份,一份用作VOC格式数据, 另一份是原始数据。

    下面将直接使用原始数据,使用pytroch提供的类对数据进行简单封装,实现数据集的索引和读取。

    快速起见, 采用一个公开数据集,Wider Face, 这个数据集用于做人脸检测,训练集合包含12k的图像,而且提供人脸矩形框标签


    开发环境

    • Ubuntu 18.04
    • pycharm
    • Anaconda3, python3.6
    • pytroch 1.0, torchvision

    widerFace 人脸检测数据集

    标签

    image.png image.png image.png

    训练集

    简单起见,将wider_face_train_bbx_gt.txt 复制到训练集合所在路径, images文件夹包含图像。

    image.png image.png
    • 代码
      在pytroch中,数据集定义很简单,按照pytroch提供的套路就可以。 一般的, 首先定义一个类继承troch.utils.data.Dataset, 然后override __len()__, __getitem()__ 方法。
    1. __len()__ : 返回数据集容量大小

    2. __getitem()__: 返回数据集迭代时候每一个样本及其标签数据。

    import torch
    from torch.utils.data import Dataset
    import torchvision.transforms as transfroms
    import matplotlib.pyplot as plt
    import os
    import PIL.Image as Image
    import PIL
    import cv2
    import numpy as np
    
    class WiderFaceDataset(Dataset):
        def __init__(self, images_folder, ground_truth_file, transform=None, target_transform=None):
            super(WiderFaceDataset, self).__init__()
            self.images_folder = images_folder
            self.ground_truth_file = ground_truth_file
            self.images_name_list = []
            self.ground_truth = []
            with open(ground_truth_file, 'r') as f:
                for i in f:
                    self.images_name_list.append(i.rstrip())
                    self.ground_truth.append(i.rstrip())
    
            self.images_name_list = list(filter(lambda x: x.endswith('.jpg') or x.endswith('.bmp'),
                                           self.images_name_list))
    
            self.transform = transform
            self.target_transform = target_transform
    
        def __len__(self):
            return len(self.images_name_list)
    
        def __getitem__(self, index):
            image_name = self.images_name_list[index]
            # 查找文件名
            loc = self._search(image_name)
            # 解析人脸个数
            face_nums = int(self.ground_truth[loc + 1])
            # 读取矩形框
            rects = []
            for i in range(loc + 2, loc + 2 + face_nums):
                line = self.ground_truth[i]
                x, y, w, h = line.split(' ')[:4]
                x, y, w, h = list(map(lambda k: int(k), [x, y, w, h]))
                rects.append([x, y, w, h])
    
            # 图像
            image = PIL.Image.open(os.path.join(self.images_folder, image_name))
    
            if self.transform:
                image = self.transform(image)
    
            if self.target_transform:
                rects = list(map(lambda x: self.target_transform(x), rects))
    
            return {'image': image, 'label': rects, 'image_name': os.path.join(self.images_folder, image_name)}
    
        def _search(self, image_name):
            for i, line in enumerate(self.ground_truth):
                if image_name == line:
                    return i
    
    
    if __name__ == '__main__':
       images_folder = '/media/weipenghui/Extra/WiderFace/WIDER_train/images'
       ground_truth_file = open('/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt', 'r')
    
       dataset = WiderFaceDataset(images_folder='/media/weipenghui/Extra/WiderFace/WIDER_train/images',
                                  ground_truth_file='/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt',
                                  transform=transfroms.ToTensor(),
                                  target_transform=lambda x: torch.tensor(x))
    
       var = next(iter(dataset))
       image_transformed = var['image']
       label_transformed = var['label']
       image_name = var['image_name']
       #plt.figure()
       image_transformed = image_transformed.numpy().transpose((1, 2, 0))
       image_transformed = np.floor(image_transformed * 255).astype(np.uint8)
       image = cv2.imread(image_name)
       for rect in label_transformed:
           x, y, w, h = rect
           x, y, w, h = list(map(lambda k: k.item(), [x, y, w, h]))
           cv2.rectangle(image, pt1=(x, y), pt2=(x + w, y + h),color=(255,0,0))
    
       cv2.imshow('image',image)
       cv2.waitKey(0)
       plt.imshow(image_transformed)
       plt.show()
    
       # for i, sample in enumerate(dataset):
       #     print(i, sample['image'])
       # 
       # print(len(dataset))
    
    
    
    image.png image.png image.png image.png

    相关文章

      网友评论

        本文标题:pytroch学习(二十五)—目标检测(数据集制作)

        本文链接:https://www.haomeiwen.com/subject/qqjjjqtx.html