(一)物体检测
前面咱们讨论的都是图片分类的问题,他注重的是图面中的主体,而对于其他的物体,就不会去关注。那么如果画面中有一只狗和一只猫,我们的模型该如何进行分类呢?其实我们更希望他能够做到的是,能发现图里面有一只狗和一只猫并且能够知道它们的位置,这就是物体检测。
(1)边缘框
在目标检测中,我们通常使用边界框(bounding box)来描述对象的空间位置。
边界框是矩形的,由矩形左上角的以及右下角的和坐标决定。
另一种常用的边界框表示方法是边界框中心的轴坐标以及框的宽度和高度。
有两种写法可以将一个物体框出来,
- (左上下,左上y,右上x,右上y)
- (左上下,左上y,宽,高)
(2)目标检测数据集
不能和以前一样一个文件夹里面放一类图片,我们现在可能需要一个单独的文件用来存储图片的标签。例如:(图片名称,物体类别,边缘框)
COCO数据集,一共有80类物体,330K的图片,1.5M个物体。
(二)代码实现
画一个框和数据集
后面的数据集可能会用到一个香蕉集
下载地址:http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip
%matplotlib inline
import torch
from d2l import torch as d2l
import numpy as np
import matplotlib.pyplot as plt
d2l.set_figsize()
img = d2l.plt.imread('../img/catdog.jpg')
d2l.plt.imshow(img)
x = torch.arange(5)
x = (x,x,x)
print(torch.stack(x, axis=-1))
print(torch.stack(x, axis=0))
def box_corner_to_center(boxes):
"""从(左上,右下)转换到(中间,宽高)"""
x1,y1,x2,y2 = boxes[:,0],boxes[:,1],boxes[:,2],boxes[:,3]
cx = (x1+x2)/2
cy = (y1+y2)/2
w = x2 - x1
h = y2 - y1
boxes = torch.stack((cx, cy, w, h), axis=-1)
return boxes
def box_center_to_corner(boxes):
cx,cy,w,h = boxes[:,0],boxes[:,1],boxes[:,2],boxes[:,3]
x1 = cx-0.5*w
x2 = cx+0.5*w
y1 = cy-0.5*h
y2 = cy+0.5*h
boxes = torch.stack((x1, y1, x2, y2), axis=-1)
return boxes
dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]
boxes = torch.tensor((dog_bbox, cat_bbox))
box_center_to_corner(box_corner_to_center(boxes)) == boxes
def bbox_to_rect(bbox, color):
return d2l.plt.Rectangle(
xy=(bbox[0],bbox[1]),
width=bbox[2]-bbox[0],
height=bbox[3]-bbox[1],
fill=False,
edgecolor= color,
linewidth=2
)
fig = d2l.plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(dog_bbox,"blue"))
fig.axes.add_patch(bbox_to_rect(cat_bbox,"red"))
# 目标检测数据集
# 这个数据集叫香蕉集,用来检测香蕉
# 可以手动下载,也可以使用代码下载
# 下载地址:http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip
%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
from PIL import Image
#@save
d2l.DATA_HUB['banana-detection'] = (
d2l.DATA_URL + 'banana-detection.zip',
'5de26c8fce5ccdea9f91267273464dc968d20d72')
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
data_dir = "../data/banana-detection/"
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'images', f'{img_name}')))
# 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
# 其中所有图像都具有相同的香蕉类(索引为0)
targets.append(list(target))
# print(type(images[0]))
# print(type(targets[0]))
return images, torch.tensor(targets).unsqueeze(1) / 256
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
def load_data_bananas(batch_size):
"""加载香蕉检测数据集"""
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
batch_size, shuffle=True)
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
batch_size)
return train_iter, val_iter
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape
# 把通道数移到后面去
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
以上都是书本上的写法,我一开始觉得还挺繁琐,于是自己又重新写了一下。结果发现还得是上面的这种写法效率高。
# 简单的写法, 反面教材
def read_csv(train=True):
base_path = "../data/banana-detection/"
if train: path = base_path + "bananas_train/label.csv"
else: path = base_path + "bananas_val/label.csv"
file = pd.read_csv(path)
train_lable =file.set_index("img_name")
features = []
label = []
for img_name, target in train_lable.iterrows():
if train: features.append(torchvision.io.read_image(base_path+"bananas_train/images/"+img_name))
else: features.append(torchvision.io.read_image(base_path+"bananas_val/images/"+img_name))
label.append(list(target))
# 所有时间都花在下面这个转换了,消耗的时间太多了,不推荐使用
# 我尝试了使用其他的方法例如PIL的Image.open,结果是会消耗更多的时间
# 我认为书本中的写法快的原因是重写了__len__()方法
features = [item.numpy() for item in features]
return (torch.tensor(features),torch.tensor(label).unsqueeze(1) / 256)
def load_bananas(batch_size=32):
train_data = read_csv(True)
test_data = read_csv(False)
train_dataset = torch.utils.data.TensorDataset(*train_data)
test_dataset = torch.utils.data.TensorDataset(*test_data)
return torch.utils.data.DataLoader(train_dataset,shuffle=True,batch_size=batch_size),torch.utils.data.DataLoader(test_dataset,shuffle=True,batch_size=batch_size)
train_iter,test_iter = load_bananas()
from PIL.ImageDraw import Draw as draw
from PIL import Image
batch = next(iter(train_iter))
# 这里的permute和reshape并不一样,参数列表是矩阵的下标
# 可以理解为将原来的(c,h,w)即(0,1,2)转变为了(h,w,c)即(1,2,0)
imgs = batch[0][0].permute(1,2,0)
fig = plt.imshow(imgs)
print(batch[1][0][0])
fig.axes.add_patch(bbox_to_rect((batch[1][0][0][1:5]*256),color="r"))
网友评论