美文网首页
物体检测之加载数据集和画框

物体检测之加载数据集和画框

作者: 小黄不头秃 | 来源:发表于2022-09-19 01:21 被阅读0次

    (一)物体检测

    前面咱们讨论的都是图片分类的问题,他注重的是图面中的主体,而对于其他的物体,就不会去关注。那么如果画面中有一只狗和一只猫,我们的模型该如何进行分类呢?其实我们更希望他能够做到的是,能发现图里面有一只狗和一只猫并且能够知道它们的位置,这就是物体检测。

    (1)边缘框

    在目标检测中,我们通常使用边界框(bounding box)来描述对象的空间位置。
    边界框是矩形的,由矩形左上角的以及右下角的xy坐标决定。
    另一种常用的边界框表示方法是边界框中心的(x, y)轴坐标以及框的宽度和高度。

    有两种写法可以将一个物体框出来,

    • (左上下,左上y,右上x,右上y)
    • (左上下,左上y,宽,高)
    (2)目标检测数据集

    不能和以前一样一个文件夹里面放一类图片,我们现在可能需要一个单独的文件用来存储图片的标签。例如:(图片名称,物体类别,边缘框)

    COCO数据集,一共有80类物体,330K的图片,1.5M个物体。

    (二)代码实现

    画一个框和数据集
    后面的数据集可能会用到一个香蕉集
    下载地址:http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip

    %matplotlib inline
    import torch
    from d2l import torch as d2l
    import numpy as np
    import matplotlib.pyplot as plt
    
    d2l.set_figsize()
    img = d2l.plt.imread('../img/catdog.jpg')
    d2l.plt.imshow(img)
    
    x = torch.arange(5)
    x = (x,x,x)
    print(torch.stack(x, axis=-1))
    print(torch.stack(x, axis=0))
    
    def box_corner_to_center(boxes):
        """从(左上,右下)转换到(中间,宽高)"""
        x1,y1,x2,y2 = boxes[:,0],boxes[:,1],boxes[:,2],boxes[:,3]
        cx = (x1+x2)/2
        cy = (y1+y2)/2
        w = x2 - x1
        h = y2 - y1
        boxes = torch.stack((cx, cy, w, h), axis=-1)
        return boxes
    
    def box_center_to_corner(boxes):
        cx,cy,w,h = boxes[:,0],boxes[:,1],boxes[:,2],boxes[:,3]
        x1 = cx-0.5*w
        x2 = cx+0.5*w
        y1 = cy-0.5*h
        y2 = cy+0.5*h
        boxes = torch.stack((x1, y1, x2, y2), axis=-1)
        return boxes
    
    dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]
    boxes = torch.tensor((dog_bbox, cat_bbox))
    box_center_to_corner(box_corner_to_center(boxes)) == boxes
    
    def bbox_to_rect(bbox, color):
        return d2l.plt.Rectangle(
            xy=(bbox[0],bbox[1]),
            width=bbox[2]-bbox[0],
            height=bbox[3]-bbox[1],
            fill=False,
            edgecolor= color,
            linewidth=2
        )
    
    fig = d2l.plt.imshow(img)
    fig.axes.add_patch(bbox_to_rect(dog_bbox,"blue"))
    fig.axes.add_patch(bbox_to_rect(cat_bbox,"red"))
    

    # 目标检测数据集
    # 这个数据集叫香蕉集,用来检测香蕉
    # 可以手动下载,也可以使用代码下载
    # 下载地址:http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip
    %matplotlib inline
    import os 
    import pandas as pd
    import torch
    import torchvision
    from d2l import torch as d2l
    from PIL import Image
    
    #@save
    d2l.DATA_HUB['banana-detection'] = (
        d2l.DATA_URL + 'banana-detection.zip',
        '5de26c8fce5ccdea9f91267273464dc968d20d72')
    
    def read_data_bananas(is_train=True):
        """读取香蕉检测数据集中的图像和标签"""
        data_dir = "../data/banana-detection/"
        csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
                                 else 'bananas_val', 'label.csv')
        csv_data = pd.read_csv(csv_fname)
        csv_data = csv_data.set_index('img_name')
        images, targets = [], []
        for img_name, target in csv_data.iterrows():
            images.append(torchvision.io.read_image(
                os.path.join(data_dir, 'bananas_train' if is_train else
                             'bananas_val', 'images', f'{img_name}')))
            # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
            # 其中所有图像都具有相同的香蕉类(索引为0)
            targets.append(list(target))
        # print(type(images[0]))
        # print(type(targets[0]))
        return images, torch.tensor(targets).unsqueeze(1) / 256
    
    class BananasDataset(torch.utils.data.Dataset):
        """一个用于加载香蕉检测数据集的自定义数据集"""
        def __init__(self, is_train):
            self.features, self.labels = read_data_bananas(is_train)
            print('read ' + str(len(self.features)) + (f' training examples' if
                  is_train else f' validation examples'))
    
        def __getitem__(self, idx):
            return (self.features[idx].float(), self.labels[idx])
    
        def __len__(self):
            return len(self.features)
    
    def load_data_bananas(batch_size):
        """加载香蕉检测数据集"""
        train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                                 batch_size, shuffle=True)
        val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                               batch_size)
        return train_iter, val_iter
    
    batch_size, edge_size = 32, 256
    train_iter, _ = load_data_bananas(batch_size)
    batch = next(iter(train_iter))
    batch[0].shape, batch[1].shape
    
    # 把通道数移到后面去
    imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
    axes = d2l.show_images(imgs, 2, 5, scale=2)
    for ax, label in zip(axes, batch[1][0:10]):
        d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
    

    以上都是书本上的写法,我一开始觉得还挺繁琐,于是自己又重新写了一下。结果发现还得是上面的这种写法效率高。

    # 简单的写法, 反面教材
    def read_csv(train=True):
        base_path = "../data/banana-detection/"
        if train: path = base_path + "bananas_train/label.csv"
        else: path = base_path + "bananas_val/label.csv"
        file = pd.read_csv(path)
        train_lable =file.set_index("img_name")
        features = []
        label = []
        for img_name, target in train_lable.iterrows():
            if train: features.append(torchvision.io.read_image(base_path+"bananas_train/images/"+img_name))
            else: features.append(torchvision.io.read_image(base_path+"bananas_val/images/"+img_name))
            label.append(list(target))
        # 所有时间都花在下面这个转换了,消耗的时间太多了,不推荐使用
        # 我尝试了使用其他的方法例如PIL的Image.open,结果是会消耗更多的时间
        # 我认为书本中的写法快的原因是重写了__len__()方法
        features = [item.numpy() for item in features]
        return (torch.tensor(features),torch.tensor(label).unsqueeze(1) / 256)
    
    def load_bananas(batch_size=32):
        train_data = read_csv(True)
        test_data = read_csv(False)
        train_dataset = torch.utils.data.TensorDataset(*train_data)
        test_dataset = torch.utils.data.TensorDataset(*test_data)
        return torch.utils.data.DataLoader(train_dataset,shuffle=True,batch_size=batch_size),torch.utils.data.DataLoader(test_dataset,shuffle=True,batch_size=batch_size)
    
    train_iter,test_iter = load_bananas()
    
    from PIL.ImageDraw import Draw as draw
    from PIL import Image
    batch = next(iter(train_iter))
    # 这里的permute和reshape并不一样,参数列表是矩阵的下标
    # 可以理解为将原来的(c,h,w)即(0,1,2)转变为了(h,w,c)即(1,2,0)
    imgs = batch[0][0].permute(1,2,0)
    
    fig = plt.imshow(imgs)
    print(batch[1][0][0])
    fig.axes.add_patch(bbox_to_rect((batch[1][0][0][1:5]*256),color="r"))
    
    

    相关文章

      网友评论

          本文标题:物体检测之加载数据集和画框

          本文链接:https://www.haomeiwen.com/subject/ywmonrtx.html