前言
之前,测试通了pytroch版的yolo-v2/v3, ssd-mobilenetv1/v2目标检测代码。 相对于测试,如何用自己的数据训练一个目标检测模型才更令人兴奋。俗话曰:兵马未动,粮草先行, 在训练之前,首先需要准备好训练数据。
在许多例子中,一般都用VOC, COCO格式的数据集进行训练和测试。对于我们自己的数据,一般不是VOC/COCO格式的数据,所以一个比较笨的方法就是写一个脚本进行数据格式转换,再不济可以手动创建文件夹,直接把相应的数据复制到制定的目录,这样很麻烦。麻烦的对方主要在于:1. VOC中标签都是1张图像对应一个xml文件, xml结构数据本身相对解析麻烦,不如JSON,YAML轻巧。 2. 电脑中需要将原始数据复制2份,一份用作VOC格式数据, 另一份是原始数据。
下面将直接使用原始数据,使用pytroch提供的类对数据进行简单封装,实现数据集的索引和读取。
快速起见, 采用一个公开数据集,Wider Face, 这个数据集用于做人脸检测,训练集合包含12k的图像,而且提供人脸矩形框标签
开发环境
- Ubuntu 18.04
- pycharm
- Anaconda3, python3.6
- pytroch 1.0, torchvision
widerFace 人脸检测数据集
-
Wider Face
-
BaiDuYun链接: https://pan.baidu.com/s/1HjEsIzkQtS5ea2mOVoRFtA
image.png
标签
训练集
简单起见,将wider_face_train_bbx_gt.txt 复制到训练集合所在路径, images文件夹包含图像。
image.png image.png- 代码
在pytroch中,数据集定义很简单,按照pytroch提供的套路就可以。 一般的, 首先定义一个类继承troch.utils.data.Dataset
, 然后override__len()__
,__getitem()__
方法。
-
__len()__
: 返回数据集容量大小 -
__getitem()__
: 返回数据集迭代时候每一个样本及其标签数据。
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transfroms
import matplotlib.pyplot as plt
import os
import PIL.Image as Image
import PIL
import cv2
import numpy as np
class WiderFaceDataset(Dataset):
def __init__(self, images_folder, ground_truth_file, transform=None, target_transform=None):
super(WiderFaceDataset, self).__init__()
self.images_folder = images_folder
self.ground_truth_file = ground_truth_file
self.images_name_list = []
self.ground_truth = []
with open(ground_truth_file, 'r') as f:
for i in f:
self.images_name_list.append(i.rstrip())
self.ground_truth.append(i.rstrip())
self.images_name_list = list(filter(lambda x: x.endswith('.jpg') or x.endswith('.bmp'),
self.images_name_list))
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.images_name_list)
def __getitem__(self, index):
image_name = self.images_name_list[index]
# 查找文件名
loc = self._search(image_name)
# 解析人脸个数
face_nums = int(self.ground_truth[loc + 1])
# 读取矩形框
rects = []
for i in range(loc + 2, loc + 2 + face_nums):
line = self.ground_truth[i]
x, y, w, h = line.split(' ')[:4]
x, y, w, h = list(map(lambda k: int(k), [x, y, w, h]))
rects.append([x, y, w, h])
# 图像
image = PIL.Image.open(os.path.join(self.images_folder, image_name))
if self.transform:
image = self.transform(image)
if self.target_transform:
rects = list(map(lambda x: self.target_transform(x), rects))
return {'image': image, 'label': rects, 'image_name': os.path.join(self.images_folder, image_name)}
def _search(self, image_name):
for i, line in enumerate(self.ground_truth):
if image_name == line:
return i
if __name__ == '__main__':
images_folder = '/media/weipenghui/Extra/WiderFace/WIDER_train/images'
ground_truth_file = open('/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt', 'r')
dataset = WiderFaceDataset(images_folder='/media/weipenghui/Extra/WiderFace/WIDER_train/images',
ground_truth_file='/media/weipenghui/Extra/WiderFace/WIDER_train/wider_face_train_bbx_gt.txt',
transform=transfroms.ToTensor(),
target_transform=lambda x: torch.tensor(x))
var = next(iter(dataset))
image_transformed = var['image']
label_transformed = var['label']
image_name = var['image_name']
#plt.figure()
image_transformed = image_transformed.numpy().transpose((1, 2, 0))
image_transformed = np.floor(image_transformed * 255).astype(np.uint8)
image = cv2.imread(image_name)
for rect in label_transformed:
x, y, w, h = rect
x, y, w, h = list(map(lambda k: k.item(), [x, y, w, h]))
cv2.rectangle(image, pt1=(x, y), pt2=(x + w, y + h),color=(255,0,0))
cv2.imshow('image',image)
cv2.waitKey(0)
plt.imshow(image_transformed)
plt.show()
# for i, sample in enumerate(dataset):
# print(i, sample['image'])
#
# print(len(dataset))
image.png
image.png
image.png
image.png
网友评论