美文网首页pytorch学习笔记深度学习目标跟踪&&目标检测
pytroch学习(二十五)—目标检测(数据集制作)

pytroch学习(二十五)—目标检测(数据集制作)

作者: 侠之大者_7d3f | 来源:发表于2019-01-23 12:06 被阅读0次

前言

之前,测试通了pytroch版的yolo-v2/v3, ssd-mobilenetv1/v2目标检测代码。 相对于测试,如何用自己的数据训练一个目标检测模型才更令人兴奋。俗话曰:兵马未动,粮草先行, 在训练之前,首先需要准备好训练数据。

在许多例子中,一般都用VOC, COCO格式的数据集进行训练和测试。对于我们自己的数据,一般不是VOC/COCO格式的数据,所以一个比较笨的方法就是写一个脚本进行数据格式转换,再不济可以手动创建文件夹,直接把相应的数据复制到制定的目录,这样很麻烦。麻烦的对方主要在于:1. VOC中标签都是1张图像对应一个xml文件, xml结构数据本身相对解析麻烦,不如JSON,YAML轻巧。 2. 电脑中需要将原始数据复制2份,一份用作VOC格式数据, 另一份是原始数据。

下面将直接使用原始数据,使用pytroch提供的类对数据进行简单封装,实现数据集的索引和读取。

快速起见, 采用一个公开数据集,Wider Face, 这个数据集用于做人脸检测,训练集合包含12k的图像,而且提供人脸矩形框标签


开发环境

  • Ubuntu 18.04
  • pycharm
  • Anaconda3, python3.6
  • pytroch 1.0, torchvision

widerFace 人脸检测数据集

标签

image.png image.png image.png

训练集

简单起见,将wider_face_train_bbx_gt.txt 复制到训练集合所在路径, images文件夹包含图像。

image.png image.png
  • 代码
    在pytroch中,数据集定义很简单,按照pytroch提供的套路就可以。 一般的, 首先定义一个类继承troch.utils.data.Dataset, 然后override __len()__, __getitem()__ 方法。
  1. __len()__ : 返回数据集容量大小

  2. __getitem()__: 返回数据集迭代时候每一个样本及其标签数据。

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transfroms
import matplotlib.pyplot as plt
import os
import PIL.Image as Image
import PIL
import cv2
import numpy as np

class WiderFaceDataset(Dataset):
    def __init__(self, images_folder, ground_truth_file, transform=None, target_transform=None):
        super(WiderFaceDataset, self).__init__()
        self.images_folder = images_folder
        self.ground_truth_file = ground_truth_file
        self.images_name_list = []
        self.ground_truth = []
        with open(ground_truth_file, 'r') as f:
            for i in f:
                self.images_name_list.append(i.rstrip())
                self.ground_truth.append(i.rstrip())

        self.images_name_list = list(filter(lambda x: x.endswith('.jpg') or x.endswith('.bmp'),
                                       self.images_name_list))

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images_name_list)

    def __getitem__(self, index):
        image_name = self.images_name_list[index]
        # 查找文件名
        loc = self._search(image_name)
        # 解析人脸个数
        face_nums = int(self.ground_truth[loc + 1])
        # 读取矩形框
        rects = []
        for i in range(loc + 2, loc + 2 + face_nums):
            line = self.ground_truth[i]
            x, y, w, h = line.split(' ')[:4]
            x, y, w, h = list(map(lambda k: int(k), [x, y, w, h]))
            rects.append([x, y, w, h])

        # 图像
        image = PIL.Image.open(os.path.join(self.images_folder, image_name))

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            rects = list(map(lambda x: self.target_transform(x), rects))

        return {'image': image, 'label': rects, 'image_name': os.path.join(self.images_folder, image_name)}

    def _search(self, image_name):
        for i, line in enumerate(self.ground_truth):
            if image_name == line:
                return i


if __name__ == '__main__':
   images_folder = '/media/weipenghui/Extra/WiderFace/WIDER_train/images'
   ground_truth_file = open('/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt', 'r')

   dataset = WiderFaceDataset(images_folder='/media/weipenghui/Extra/WiderFace/WIDER_train/images',
                              ground_truth_file='/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt',
                              transform=transfroms.ToTensor(),
                              target_transform=lambda x: torch.tensor(x))

   var = next(iter(dataset))
   image_transformed = var['image']
   label_transformed = var['label']
   image_name = var['image_name']
   #plt.figure()
   image_transformed = image_transformed.numpy().transpose((1, 2, 0))
   image_transformed = np.floor(image_transformed * 255).astype(np.uint8)
   image = cv2.imread(image_name)
   for rect in label_transformed:
       x, y, w, h = rect
       x, y, w, h = list(map(lambda k: k.item(), [x, y, w, h]))
       cv2.rectangle(image, pt1=(x, y), pt2=(x + w, y + h),color=(255,0,0))

   cv2.imshow('image',image)
   cv2.waitKey(0)
   plt.imshow(image_transformed)
   plt.show()

   # for i, sample in enumerate(dataset):
   #     print(i, sample['image'])
   # 
   # print(len(dataset))


image.png image.png image.png image.png

相关文章

网友评论

    本文标题:pytroch学习(二十五)—目标检测(数据集制作)

    本文链接:https://www.haomeiwen.com/subject/qqjjjqtx.html