学习笔记15:第二种加载数据的方法 - pbc的成长之路 - 博客园 (cnblogs.com)
自定义创建DataSet子类
继承至data.Dataset父类,并且创建getitem,和len方法。
import torch
from torch.utils import data
from PIL import Image # pip install pillow
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
import glob
# 取出所有路径
all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')
# all_imgs_path[:5]
# ['./data/dataset2\\cloudy1.jpg',
# './data/dataset2\\cloudy10.jpg',
# './data/dataset2\\cloudy100.jpg',
# './data/dataset2\\cloudy101.jpg',
# './data/dataset2\wea\cloudy102.jpg']
# 获得所有标签
species = ['cloudy', 'rain', 'shine', 'sunrise']
all_labels = []
for img in all_imgs_path:
for i, c in enumerate(species):
if c in img:
all_labels.append(i)
# 定义数据集类
# 必须创建 __getitem__, __len__, __init__
class Mydataset(data.Dataset):
def __init__(self, root):
self.imgs_path = root # self.imgs_path 为所有图片的总路径
def __getitem__(self, index):
img_path = self.imgs_path[index] # img_path 为单个图片路径
return img_path
def __len__(self):
return len(self.imgs_path)
weather_dataset = Mydataset(all_imgs_path)
# print(len(weather_dataset)) # 1122 共1122张图片
# weather_dataset[1] # 可以对图片进行切片 './data/dataset2\\cloudy10.jpg'
wh_dl = torch.utils.data.DataLoader(weather_dataset,batch_size=4) # 创建dataloader
next(iter(wh_dl)) # 调用next方法,返回迭代一个批次的数据。 一个批次,一次返回4张图片的路径
# ['./data/dataset2\\cloudy1.jpg',
# './data/dataset2\\cloudy10.jpg',
# './data/dataset2\\cloudy100.jpg',
# './data/dataset2\\cloudy101.jpg']
举例说明,如何用
获取标签
# pytorch读取图片的方法都是通过Image获取,通过transforms进行转换
# 使用glob.glob取出所有路径
all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')
# 获得所有图片的标签
species = ['cloudy', 'rain', 'shine', 'sunrise']
species_to_idx = dict((c,i) for i,c in enumerate(species)) # enumerate会返回分类与位置。将类别数值化
# species_to_dix # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
idx_to_species = dict((v,k) for k,v in species_to_idx.items()) # 对字典的items迭代,将变换后的结果复原
# idx_to_species #{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
all_labels = []
for img in all_imgs_path: # 对所有的图片路径进行迭代。img代表其中一张图片的路径
for i, c in enumerate(species):
if c in img: # 如果类别(cloud/rain...)在路径里面
all_labels.append(i) # 将对应的类别序号append到all_labels 列表
绘制图片
import torch
from torch.utils import data
from PIL import Image # pip install pillow. python2中叫PIL
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
import glob
# pytorch读取图片的方法都是通过Image获取,通过transforms进行转换
# 使用glob.glob取出所有路径
all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')
# 获得所有图片的标签
species = ['cloudy', 'rain', 'shine', 'sunrise']
species_to_idx = dict((c,i) for i,c in enumerate(species)) # enumerate会返回分类与位置。将类别数值化
# species_to_dix # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
idx_to_species = dict((v,k) for k,v in species_to_idx.items()) # 对字典的items迭代,将变换后的结果复原
# idx_to_species #{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
all_labels = []
for img in all_imgs_path: # 对所有的图片路径进行迭代。img代表其中一张图片的路径
for i, c in enumerate(species):
if c in img: # 如果类别(cloud/rain...)在路径里面
all_labels.append(i) # 将对应的类别序号append到all_labels 列表
# 使用transform对图片进行转换
transform = transforms.Compose([
transforms.Resize((96,96)),
transforms.ToTensor()
])
# 定义数据集类
# 必须创建 __getitem__, __len__, __init__
class Mydataset(data.Dataset):
def __init__(self, img_paths,labels,transform):
self.imgs = img_paths # self.imgs 为所有图片的总路径
self.labels = labels
self.transforms = transform
def __getitem__(self, index): # 加索引,返回的是图片这个对象。 先读取,再转换后返回。
img = self.imgs[index] # 对img进行切片
label = self.labels[index] # 对labels进行切片
pil_img = Image.open(img) #
pil_img = pil_img.convert('RGB') #防止图片中掺杂黑白图片。这一步建议加
data = self.transforms(pil_img) # 将每张图片进行Resize/To Tensor
return data,label
def __len__(self):
return len(self.imgs)
weather_dataset = Mydataset(all_imgs_path,all_labels,transform)
weather_dl = data.DataLoader(weather_dataset,batch_size=16,shuffle=True) #有几个计算核心,num_workers设置为几
imgs_1_batch,labels_1_batch = next(iter(weather_dl)) # 调用next方法,返回迭代一个批次的数据。
# imgs_1_batch.shape # torch.Size([16, 3, 96, 96]) batch_size=16, channel=3, w=h=96
# label_1_batch.shape # torch.Size([16])
绘制前6张
plt.figure(figsize=(12,8))
for i,(img,label) in enumerate(zip(imgs_1_batch[:6],labels_1_batch[:6])): # 绘制前6张
img = img.permute(1,2,0).numpy() # permute方法用于更改顺序,将channel放到后面
plt.subplot(2,3,i+1) # 2行3列,从第1开始
plt.title(idx_to_species.get(label.item())) # 获取id对应的分类
plt.imshow(img)
image.png
绘制后6张
plt.figure(figsize=(12,8))
for i,(img,label) in enumerate(zip(imgs_1_batch[-6:],labels_1_batch[-6:])): # 绘制后6张
img = img.permute(1,2,0).numpy() # permute方法用于更改顺序,将channel放到后面
plt.subplot(2,3,i+1) # 2行3列,从第1开始
plt.title(idx_to_species.get(label.item())) # 获取id对应的分类
plt.imshow(img)
image.png
网友评论