1.程序
pikaqiu.py
import matplotlib.pyplot as plt
import d2lzh as d2l
from mxnet import gluon, image
from mxnet.gluon import utils as gutils
import os
def _download_pikachu(data_dir):
    """Download the Pikachu detection dataset files into ``data_dir``.

    Each file is fetched with ``gutils.download``, which is given the
    expected SHA-1 hash of the file.

    Args:
        data_dir: Directory in which ``train.rec``, ``train.idx`` and
            ``val.rec`` are stored.
    """
    root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/'
                'gluon/dataset/pikachu/')
    # Maps each file name to its expected SHA-1 hash.
    dataset = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',
               'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',
               'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'}
    for file_name, sha1 in dataset.items():
        gutils.download(root_url + file_name,
                        os.path.join(data_dir, file_name),
                        sha1_hash=sha1)
# This function is also saved in the d2lzh package for later use.
def load_data_pikachu(batch_size, edge_size=256, data_dir='../data/pikachu'):
    """Build the training and validation iterators for the Pikachu dataset.

    The dataset files are downloaded into ``data_dir`` only when at least one
    of them is missing, so files that were placed there manually are used
    as they are.

    Args:
        batch_size: Number of images per batch.
        edge_size: Width and height of the output images.
        data_dir: Directory holding ``train.rec``, ``train.idx`` and
            ``val.rec``.

    Returns:
        A ``(train_iter, val_iter)`` tuple of ``image.ImageDetIter`` objects.
    """
    required_files = ('train.rec', 'train.idx', 'val.rec')
    if not all(os.path.exists(os.path.join(data_dir, name))
               for name in required_files):
        _download_pikachu(data_dir)
    train_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'train.rec'),
        path_imgidx=os.path.join(data_dir, 'train.idx'),
        batch_size=batch_size,
        data_shape=(3, edge_size, edge_size),  # shape of the output images
        shuffle=True,  # read the dataset in random order
        rand_crop=1,  # random cropping is applied with probability 1
        min_object_covered=0.95, max_attempts=200)
    val_iter = image.ImageDetIter(
        path_imgrec=os.path.join(data_dir, 'val.rec'), batch_size=batch_size,
        data_shape=(3, edge_size, edge_size), shuffle=False)
    return train_iter, val_iter
# Demo: read one training batch and display its first ten images together
# with one bounding box per image.
batch_size, edge_size = 32, 256
train_iter, _ = load_data_pikachu(batch_size, edge_size)
batch = train_iter.next()
# A bare expression displays nothing outside a notebook, so print the shapes.
print(batch.data[0].shape, batch.label[0].shape)
# Move the channel axis last (the iterator's data_shape is channel-first)
# and divide by 255 before plotting.
imgs = (batch.data[0][0:10].transpose((0, 2, 3, 1))) / 255
axes = d2l.show_images(imgs, 2, 5).flatten()
for ax, label in zip(axes, batch.label[0][0:10]):
    # Presumably label[0][1:5] holds box coordinates relative to the image
    # size, hence the scaling by edge_size -- TODO confirm against the
    # ImageDetIter label format.
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
plt.show()
在目标检测领域并没有类似MNIST或Fashion-MNIST那样的小数据集。
为了快速测试模型,我们合成了一个小的数据集。
我们首先使用一个开源的皮卡丘3D模型生成了1,000张不同角度和大小的皮卡丘图像。然后我们收集了一系列背景图像,并在每张图的随机位置放置一张随机的皮卡丘图像。
我们使用MXNet提供的im2rec工具将图像转换成二进制的RecordIO格式 [1]。该格式既可以降低数据集在磁盘上的存储开销,又能提高读取效率。如果想了解更多的图像读取方法,可以查阅GluonCV工具包的文档 [2]。
image.png
image.png
2.将皮卡丘变成jpg图片
运行下面的代码前需要先安装opencv-python和gluoncv(例如:pip install opencv-python gluoncv)。
import time
from matplotlib import pyplot as plt
import numpy as np
import mxnet as mx
from mxnet import autograd, gluon
import gluoncv as gcv
from gluoncv.utils import download, viz
import cv2
# Download the Pikachu training record file and its index, then export the
# first images of the dataset as JPG files in the current directory.
url = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/train.rec'
idx_url = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/train.idx'
download(url, path='pikachu_train.rec', overwrite=False)
download(idx_url, path='pikachu_train.idx', overwrite=False)
dataset = gcv.data.RecordFileDetection('pikachu_train.rec')
classes = ['pikachu']  # only one foreground class here
# Export at most 100 images, and never index past the end of the dataset.
num_images = min(100, len(dataset))
for i in range(num_images):
    # Named img_nd rather than "image" so it does not shadow mxnet.image.
    img_nd, label = dataset[i]
    print(img_nd.shape)
    print('label:', label)
    # The image is converted from RGB to BGR before being passed to cv2.imwrite.
    bgr = cv2.cvtColor(img_nd.asnumpy(), cv2.COLOR_RGB2BGR)
    out_path = str(i) + ".jpg"
    if not cv2.imwrite(out_path, bgr):
        # cv2.imwrite reports failure through its return value only.
        print('warning: failed to write', out_path)
    # To display the image with its bounding boxes instead, uncomment:
    # ax = viz.plot_bbox(img_nd, bboxes=label[:, :4], labels=label[:, 4:5], class_names=classes)
    # plt.show()

网友评论