美文网首页
pytorch学习笔记-dataloader输出不同尺寸的输入图

pytorch学习笔记-dataloader输出不同尺寸的输入图

作者: 升不上三段的大鱼 | 来源:发表于2020-08-18 05:46 被阅读0次

pytorch可以自己定义 Dataset类, 然后用dataloader 函数来获取输入以及对应标签。下面是个简单的例子:

from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader

class TrainDataset(Dataset):
    def __init__(self, root_dir, csv_file, transform):

        self.root_dir = root_dir
        self.labels = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, index):
        img_name = os.path.join(self.root_dir,
                                self.labels.iloc[index, 0])
        image = Image.open(img_name+'.jpg')
        label = self.labels.iloc[index,1:].astype(int).to_numpy()
        label = np.argmax(label)

        if self.transform:
            image = self.transform(image)

        return image, label

 dataset = TrainDataset(
        root_dir='./data/Input',
        csv_file=csv_file,
        transform=transforms.Compose([
           transforms.Resize(224, 224),
           transforms.HorizontalFlip(p=0.5),
           transforms.VerticalFlip(p=0.5),
           transforms.Rotate(limit=(-90,90)),
           transforms.RandomBrightnessContrast(),
        ])
    )

data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True,num_workers=1)

for inputs, labels in data_loader:
     img = torchvision.utils.make_grid(inputs[0])
     img_nm = img.numpy()
     img_trans = np.transpose(img_nm, (1, 2, 0))
     plt.imshow(img_trans)
     plt.show()

这样就可以使用自己定义的数据集了。

但是如果想要让数据集保持自己原来的尺寸,也就是说如果不用 transforms.Resize(224, 224), 把图片都缩放到224,而是保持他们原来各自不同的尺寸,需要怎么做呢?

只需要加一个自定义的collate_fn函数就可以了。在默认情况下,pytorch将图片叠在一起,成为一个NCH*W的张量,因此每个batch里的每个图像必须是相同的尺寸。所以如果想要接受不同尺寸的输入图片,我们就要自己定义collate_fn。
对于图像分类,collate_fn的输入大小是batch_size 大小的list, list里每个元素是一个元组,元组里第一个是图片,第二个是标签。对于不同大小的输入图片,我们可以使用list来储存。具体实现如下(Dataset类里面去掉resize):

def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]

data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True, collate_fn =my_collate)
trainiter = iter(data_loader)
imgs, labels = trainiter.next()

然后就可以得到保留了原尺寸的图片了。
不过要注意这里得到的 imgs是一个list,用的时候注意数据类型。

相关文章

  • pytorch学习笔记-dataloader输出不同尺寸的输入图

    pytorch可以自己定义 Dataset类, 然后用dataloader 函数来获取输入以及对应标签。下面是个简...

  • Dataloader重要参数与内部机制

    @[TOC] 一、pytorch数据输入 Dataset负责生产数据,DataLoader负责数据的分批(batc...

  • pytorch学习笔记-dataloader忽略异常值

    在使用自己的数据的时候,如果希望输入的数据满足一些条件,不满足条件的数据不会用于训练,一个方法是预处理,把不满足条...

  • 循环神经网络pytorch实现

    RNN pytorch 实现 LSTM 输入门: 遗忘门: 输出门: pytorch 实现 GRU 更新门: 候选...

  • 生动形象的DataLoader

    整理一下 PyTorch 的 DataLoader 。 先来看看官方文档: PyTorch 出这两个类的目的是想将...

  • 3、学习误区

    学习误区图 英语为例子 英语学习的正确输出输入图 没有中文翻译环节,直接到位 错误的输入输出图 中文的介入导致思维...

  • SPP金字塔池化

    为了满足不同尺度图像的输出,提出了金字塔池化的方法。 具体来说操作是这样的,我们的输入特征图尺寸是变化的,但是由于...

  • pytorch 函数DataLoader

    参考链接:https://www.e-learn.cn/content/qita/850153本文仅作为学术分享,...

  • Pytorch dataloader用法

    献给莹莹 该文章适用于两种方法读取数据集,采用Pytorch框架参考文献:https://www.cnblogs....

  • pytorch--1数据加载

    构建数据Dataset和DataLoader 构建网络 参考: PyTorch之保存加载模型pytorchyolo...

网友评论

      本文标题:pytorch学习笔记-dataloader输出不同尺寸的输入图

      本文链接:https://www.haomeiwen.com/subject/gpxhjktx.html