Time: 2019-08-04
Link: https://www.youtube.com/watch?v=p1xZ2yWU1eo&list=PLZbbT5o_s2xrfNyHZsM6ufI0iZENK9xgG&index=23
Batch processing of data
For this we need DataLoader, which wraps a Dataset and hands its samples out in batches.
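To see what batch_size does, here is a minimal sketch. A small synthetic TensorDataset stands in for train_set (an assumption: 25 random 1x28x28 "images" with labels 0-9). It shows how DataLoader splits the samples into batches, with the last batch holding whatever is left over.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for train_set: 25 fake grayscale 28x28 images with labels 0-9
images = torch.randn(25, 1, 28, 28)
labels = torch.randint(0, 10, (25,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=10)

# Each iteration yields one (images, labels) pair holding up to batch_size samples
batch_sizes = [imgs.shape[0] for imgs, lbls in loader]
print(len(loader))   # 3 batches
print(batch_sizes)   # [10, 10, 5]: the last batch holds the remainder
```

Passing drop_last=True would discard the final partial batch. Passing shuffle=True would reorder the samples on every pass over the loader.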
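import torch

# train_set (the training Dataset) and net (the network) are assumed to be defined earlier in this series
data_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=10
)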
# Take one batch
batch = next(iter(data_loader))
images, labels = batch
images.shape # (batch_size, in_channels, height, width)
# torch.Size([10, 1, 28, 28])
# Predict on the whole batch at once
preds = net(images)
preds.shape # torch.Size([10, 10])
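preds.argmax(dim=1)  # predicted class index for each image
preds.argmax(dim=1).eq(labels)  # True where the prediction matches the label
preds.argmax(dim=1).eq(labels).sum()  # number of correct predictions in the batch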
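Below is a self-contained sketch of the batch-prediction steps above, running end to end. Some parts are assumptions rather than the tutorial's own objects. The random FashionMNIST-shaped TensorDataset and the untrained `nn.Sequential` linear classifier stand in for `train_set` and `net`. The names `pred_classes`, `correct_mask`, `num_correct` and `accuracy` are hypothetical illustrations. Because the model is untrained, the correct count is essentially random. The point is the shapes and the argmax/eq/sum chain.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins: random "images" shaped like FashionMNIST and an untrained linear classifier
train_set = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

data_loader = DataLoader(train_set, batch_size=10)
images, labels = next(iter(data_loader))

with torch.no_grad():        # inference only, so skip gradient tracking
    preds = net(images)      # shape (10, 10): one row of 10 class scores per image

pred_classes = preds.argmax(dim=1)       # index of the highest score in each row
correct_mask = pred_classes.eq(labels)   # bool tensor, True where prediction == label
num_correct = correct_mask.sum().item()  # count of True values as a Python int
accuracy = num_correct / len(labels)

print(preds.shape, pred_classes.shape, num_correct, accuracy)
```

Calling `.item()` turns the one-element tensor returned by `.sum()` into a plain Python number. That makes it easy to accumulate across batches when counting correct predictions over an entire epoch.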
END.