pytorch可以自己定义 Dataset类, 然后用dataloader 函数来获取输入以及对应标签。下面是个简单的例子:
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
class TrainDataset(Dataset):
def __init__(self, root_dir, csv_file, transform):
self.root_dir = root_dir
self.labels = pd.read_csv(csv_file)
self.transform = transform
def __len__(self):
return self.labels.shape[0]
def __getitem__(self, index):
img_name = os.path.join(self.root_dir,
self.labels.iloc[index, 0])
image = Image.open(img_name+'.jpg')
label = self.labels.iloc[index,1:].astype(int).to_numpy()
label = np.argmax(label)
if self.transform:
image = self.transform(image)
return image, label
dataset = TrainDataset(
root_dir='./data/Input',
csv_file=csv_file,
transform=transforms.Compose([
transforms.Resize(224, 224),
transforms.HorizontalFlip(p=0.5),
transforms.VerticalFlip(p=0.5),
transforms.Rotate(limit=(-90,90)),
transforms.RandomBrightnessContrast(),
])
)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True,num_workers=1)
for inputs, labels in data_loader:
img = torchvision.utils.make_grid(inputs[0])
img_nm = img.numpy()
img_trans = np.transpose(img_nm, (1, 2, 0))
plt.imshow(img_trans)
plt.show()
这样就可以使用自己定义的数据集了。
但是如果想要让数据集保持自己原来的尺寸,也就是说如果不用 transforms.Resize(224, 224), 把图片都缩放到224,而是保持他们原来各自不同的尺寸,需要怎么做呢?
只需要加一个自定义的collate_fn函数就可以了。在默认情况下,pytorch将图片叠在一起,成为一个NCH*W的张量,因此每个batch里的每个图像必须是相同的尺寸。所以如果想要接受不同尺寸的输入图片,我们就要自己定义collate_fn。
对于图像分类,collate_fn的输入大小是batch_size 大小的list, list里每个元素是一个元组,元组里第一个是图片,第二个是标签。对于不同大小的输入图片,我们可以使用list来储存。具体实现如下(Dataset类里面去掉resize):
def my_collate(batch):
data = [item[0] for item in batch]
target = [item[1] for item in batch]
target = torch.LongTensor(target)
return [data, target]
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, collate_fn =my_collate)
trainiter = iter(data_loader)
imgs, labels = trainiter.next()
然后就可以得到保留了原尺寸的图片了。
不过要注意这里得到的 imgs是一个list,用的时候注意数据类型。
网友评论