美文网首页
PyTrch深度学习简明实战14 - 图片输入的第二种方式(划分

PyTrch深度学习简明实战14 - 图片输入的第二种方式(划分

作者: 薛东弗斯 | 来源:发表于2023-03-26 21:08 被阅读0次

学习笔记15:第二种加载数据的方法 - pbc的成长之路 - 博客园 (cnblogs.com)

创建训练集与测试集

import torch
from torch.utils import data
from PIL import Image   #  pip install pillow. python2中叫PIL
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
import glob

# pytorch读取图片的方法都是通过Image获取,通过transforms进行转换
# 使用glob.glob取出所有路径
all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')

# 获得所有图片的标签
species = ['cloudy', 'rain', 'shine', 'sunrise']
species_to_idx = dict((c,i) for i,c in enumerate(species)) # enumerate会返回分类与位置。将类别数值化
# species_to_dix    # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
idx_to_species = dict((v,k) for k,v in species_to_idx.items())   # 对字典的items迭代,将变换后的结果复原
# idx_to_species #{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
all_labels = []
for img in all_imgs_path:           # 对所有的图片路径进行迭代。img代表其中一张图片的路径
    for i, c in enumerate(species): 
        if c in img:                # 如果类别(cloud/rain...)在路径里面
            all_labels.append(i)    # 将对应的类别序号append到all_labels 列表
            
# 使用transform对图片进行转换
transform = transforms.Compose([
    transforms.Resize((96,96)),
    transforms.ToTensor()
])
            
# 定义数据集类
# 必须创建 __getitem__, __len__, __init__
class Mydataset(data.Dataset):
    def __init__(self, img_paths,labels,transform):
        self.imgs = img_paths    # self.imgs 为所有图片的总路径
        self.labels = labels
        self.transforms = transform
    def __getitem__(self, index):   # 加索引,返回的是图片这个对象。 先读取,再转换后返回。
        img = self.imgs[index]      # 对img进行切片
        label = self.labels[index]  # 对labels进行切片
        pil_img = Image.open(img)     # 
        pil_img = pil_img.convert('RGB')  #防止图片中掺杂黑白图片。这一步建议加
        data = self.transforms(pil_img)   # 将每张图片进行Resize/To Tensor
        return data,label
    def __len__(self):
        return len(self.imgs)
    
weather_dataset = Mydataset(all_imgs_path,all_labels,transform)
weather_dl = data.DataLoader(weather_dataset,batch_size=16,shuffle=True)   #有几个计算核心,num_workers设置为几
imgs_1_batch,labels_1_batch = next(iter(weather_dl))   # 调用next方法,返回迭代一个批次的数据。  
# imgs_1_batch.shape     # torch.Size([16, 3, 96, 96])   batch_size=16, channel=3, w=h=96
# label_1_batch.shape  # torch.Size([16])

index = np.random.permutation(len(all_imgs_path))
# index    # array([ 794,  909,  275, ...,  426, 1037,  857])   利用乱序的index对img和标签同时索引
all_imgs_path = np.array(all_imgs_path)[index]   # 这样,所有的图片按照index进行索引
all_labels = np.array(all_labels)[index]
s = int(len(all_imgs_path)*0.8)   # 取出80%. 这样切分的前提是必须对数据做乱序
train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]

train_ds = Mydataset(train_imgs,train_labels,transform)
test_ds = Mydataset(test_imgs,test_labels,transform)

train_dl = data.DataLoader(train_ds,batch_size=16,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=16)

灵活的使用Dataset类构建输入

import torch
from torch.utils import data
from PIL import Image   #  pip install pillow. python2中叫PIL
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
import glob

# pytorch读取图片的方法都是通过Image获取,通过transforms进行转换
# 使用glob.glob取出所有路径
all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')

# 获得所有图片的标签
species = ['cloudy', 'rain', 'shine', 'sunrise']
species_to_idx = dict((c,i) for i,c in enumerate(species)) # enumerate会返回分类与位置。将类别数值化
# species_to_dix    # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
idx_to_species = dict((v,k) for k,v in species_to_idx.items())   # 对字典的items迭代,将变换后的结果复原
# idx_to_species #{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
all_labels = []
for img in all_imgs_path:           # 对所有的图片路径进行迭代。img代表其中一张图片的路径
    for i, c in enumerate(species): 
        if c in img:                # 如果类别(cloud/rain...)在路径里面
            all_labels.append(i)    # 将对应的类别序号append到all_labels 列表
            
# 使用transform对图片进行转换
transform = transforms.Compose([
    transforms.Resize((96,96)),
    transforms.ToTensor()
])
            
# 定义数据集类
# 必须创建 __getitem__, __len__, __init__
class Mydataset(data.Dataset):
    def __init__(self, img_paths,labels,transform):
        self.imgs = img_paths    # self.imgs 为所有图片的总路径
        self.labels = labels
        self.transforms = transform
    def __getitem__(self, index):   # 加索引,返回的是图片这个对象。 先读取,再转换后返回。
        img = self.imgs[index]      # 对img进行切片
        label = self.labels[index]  # 对labels进行切片
        pil_img = Image.open(img)     # 
        pil_img = pil_img.convert('RGB')  #防止图片中掺杂黑白图片。这一步建议加
        data = self.transforms(pil_img)   # 将每张图片进行Resize/To Tensor
        return data,label
    def __len__(self):
        return len(self.imgs)
    
weather_dataset = Mydataset(all_imgs_path,all_labels,transform)
weather_dl = data.DataLoader(weather_dataset,batch_size=16,shuffle=True)   #有几个计算核心,num_workers设置为几
imgs_1_batch,labels_1_batch = next(iter(weather_dl))   # 调用next方法,返回迭代一个批次的数据。  
# imgs_1_batch.shape     # torch.Size([16, 3, 96, 96])   batch_size=16, channel=3, w=h=96
# label_1_batch.shape  # torch.Size([16])

index = np.random.permutation(len(all_imgs_path))
# index    # array([ 794,  909,  275, ...,  426, 1037,  857])   利用乱序的index对img和标签同时索引
all_imgs_path = np.array(all_imgs_path)[index]   # 这样,所有的图片按照index进行索引
all_labels = np.array(all_labels)[index]
s = int(len(all_imgs_path)*0.8)   # 取出80%. 这样切分的前提是必须对数据做乱序
train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]

train_ds = Mydataset(train_imgs,train_labels,transform)
test_ds = Mydataset(test_imgs,test_labels,transform)

train_dl = data.DataLoader(train_ds,batch_size=16,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=16)

# 创建子类,使用子类对dataset进行灵活转换,而不需要重新创建。
class New_dataset(data.Dataset):
    def __init__(self,some_dataset):
        self.ds = some_dataset
    def __getitem__(self,index):    #使用index进行切片
        img,label = self.ds[index]
        img = img.permute(1,2,0)    # 将channel换到最后一维/  hwc
        return img,label
    def __len__(self):
        return len(self.ds)

train_new_ds = New_dataset(train_ds)
test_new_ds = New_dataset(test_ds)

相关文章

网友评论

      本文标题:PyTrch深度学习简明实战14 - 图片输入的第二种方式(划分

      本文链接:https://www.haomeiwen.com/subject/wywnrdtx.html