美文网首页
PyTrch深度学习简明实战13 - 图片输入的第二种方式

PyTrch深度学习简明实战13 - 图片输入的第二种方式

作者: 薛东弗斯 | 来源:发表于2023-03-25 21:20 被阅读0次

    学习笔记15:第二种加载数据的方法 - pbc的成长之路 - 博客园 (cnblogs.com)

    自定义创建DataSet子类

    继承至data.Dataset父类,并且创建getitem,和len方法。

    import torch
    from torch.utils import data
    from PIL import Image   #  pip install pillow
    import numpy as np
    from torchvision import transforms
    import matplotlib.pyplot as plt
    %matplotlib inline
    import glob
    
    # 取出所有路径
    all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')
    # all_imgs_path[:5]
    # ['./data/dataset2\\cloudy1.jpg',
    #  './data/dataset2\\cloudy10.jpg',
    #  './data/dataset2\\cloudy100.jpg',
    #  './data/dataset2\\cloudy101.jpg',
    #  './data/dataset2\wea\cloudy102.jpg']
    
    # 获得所有标签
    species = ['cloudy', 'rain', 'shine', 'sunrise']
    all_labels = []
    for img in all_imgs_path:
        for i, c in enumerate(species):
            if c in img:
                all_labels.append(i)
                
    # 定义数据集类
    # 必须创建 __getitem__, __len__, __init__
    class Mydataset(data.Dataset):
        def __init__(self, root):
            self.imgs_path = root    # self.imgs_path 为所有图片的总路径
        def __getitem__(self, index):
            img_path = self.imgs_path[index]   # img_path  为单个图片路径
            return img_path
        def __len__(self):
            return len(self.imgs_path)
        
    weather_dataset = Mydataset(all_imgs_path)
    # print(len(weather_dataset))      # 1122  共1122张图片
    # weather_dataset[1]   # 可以对图片进行切片  './data/dataset2\\cloudy10.jpg'
    
    wh_dl = torch.utils.data.DataLoader(weather_dataset,batch_size=4)   # 创建dataloader
    next(iter(wh_dl))   # 调用next方法,返回迭代一个批次的数据。  一个批次,一次返回4张图片的路径
    # ['./data/dataset2\\cloudy1.jpg',
    #  './data/dataset2\\cloudy10.jpg',
    #  './data/dataset2\\cloudy100.jpg',
    #  './data/dataset2\\cloudy101.jpg']
    

    举例说明,如何用

    获取标签

    # pytorch读取图片的方法都是通过Image获取,通过transforms进行转换
    # 使用glob.glob取出所有路径
    all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')
    
    # 获得所有图片的标签
    species = ['cloudy', 'rain', 'shine', 'sunrise']
    species_to_idx = dict((c,i) for i,c in enumerate(species)) # enumerate会返回分类与位置。将类别数值化
    # species_to_dix    # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
    idx_to_species = dict((v,k) for k,v in species_to_idx.items())   # 对字典的items迭代,将变换后的结果复原
    # idx_to_species #{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
    all_labels = []
    for img in all_imgs_path:           # 对所有的图片路径进行迭代。img代表其中一张图片的路径
        for i, c in enumerate(species): 
            if c in img:                # 如果类别(cloud/rain...)在路径里面
                all_labels.append(i)    # 将对应的类别序号append到all_labels 列表
    

    绘制图片

    import torch
    from torch.utils import data
    from PIL import Image   #  pip install pillow. python2中叫PIL
    import numpy as np
    from torchvision import transforms
    import matplotlib.pyplot as plt
    %matplotlib inline
    import glob
    
    # pytorch读取图片的方法都是通过Image获取,通过transforms进行转换
    # 使用glob.glob取出所有路径
    all_imgs_path = glob.glob(r'./data/dataset2/*.jpg')
    
    # 获得所有图片的标签
    species = ['cloudy', 'rain', 'shine', 'sunrise']
    species_to_idx = dict((c,i) for i,c in enumerate(species)) # enumerate会返回分类与位置。将类别数值化
    # species_to_dix    # {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
    idx_to_species = dict((v,k) for k,v in species_to_idx.items())   # 对字典的items迭代,将变换后的结果复原
    # idx_to_species #{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}
    all_labels = []
    for img in all_imgs_path:           # 对所有的图片路径进行迭代。img代表其中一张图片的路径
        for i, c in enumerate(species): 
            if c in img:                # 如果类别(cloud/rain...)在路径里面
                all_labels.append(i)    # 将对应的类别序号append到all_labels 列表
                
    # 使用transform对图片进行转换
    transform = transforms.Compose([
        transforms.Resize((96,96)),
        transforms.ToTensor()
    ])
                
    # 定义数据集类
    # 必须创建 __getitem__, __len__, __init__
    class Mydataset(data.Dataset):
        def __init__(self, img_paths,labels,transform):
            self.imgs = img_paths    # self.imgs 为所有图片的总路径
            self.labels = labels
            self.transforms = transform
        def __getitem__(self, index):   # 加索引,返回的是图片这个对象。 先读取,再转换后返回。
            img = self.imgs[index]      # 对img进行切片
            label = self.labels[index]  # 对labels进行切片
            pil_img = Image.open(img)     # 
            pil_img = pil_img.convert('RGB')  #防止图片中掺杂黑白图片。这一步建议加
            data = self.transforms(pil_img)   # 将每张图片进行Resize/To Tensor
            return data,label
        def __len__(self):
            return len(self.imgs)
        
    weather_dataset = Mydataset(all_imgs_path,all_labels,transform)
    weather_dl = data.DataLoader(weather_dataset,batch_size=16,shuffle=True)   #有几个计算核心,num_workers设置为几
    imgs_1_batch,labels_1_batch = next(iter(weather_dl))   # 调用next方法,返回迭代一个批次的数据。  
    # imgs_1_batch.shape     # torch.Size([16, 3, 96, 96])   batch_size=16, channel=3, w=h=96
    # label_1_batch.shape  # torch.Size([16])
    

    绘制前6张

    plt.figure(figsize=(12,8))
    for i,(img,label) in enumerate(zip(imgs_1_batch[:6],labels_1_batch[:6])):  # 绘制前6张
        img = img.permute(1,2,0).numpy()    # permute方法用于更改顺序,将channel放到后面
        plt.subplot(2,3,i+1)      # 2行3列,从第1开始
        plt.title(idx_to_species.get(label.item()))  # 获取id对应的分类
        plt.imshow(img)
    
    image.png

    绘制后6张

    plt.figure(figsize=(12,8))
    for i,(img,label) in enumerate(zip(imgs_1_batch[-6:],labels_1_batch[-6:])):  # 绘制后6张
        img = img.permute(1,2,0).numpy()    # permute方法用于更改顺序,将channel放到后面
        plt.subplot(2,3,i+1)      # 2行3列,从第1开始
        plt.title(idx_to_species.get(label.item()))  # 获取id对应的分类
        plt.imshow(img)
    
    image.png

    相关文章

      网友评论

          本文标题:PyTrch深度学习简明实战13 - 图片输入的第二种方式

          本文链接:https://www.haomeiwen.com/subject/xwdnrdtx.html