美文网首页
神经网络的批量处理

神经网络的批量处理

作者: 钢笔先生 | 来源:发表于2019-08-04 14:02 被阅读0次

Time : 2019-08-04
链接:https://www.youtube.com/watch?v=p1xZ2yWU1eo&list=PLZbbT5o_s2xrfNyHZsM6ufI0iZENK9xgG&index=23

数据批量处理

这就需要借助于DataLoader了。

data_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=10
)

# 取一个batch
batch = next(iter(data_loader))
images, labels = batch

images.shape # (batch_size, in_channels, height, width)
# torch.Size([10, 1, 28, 28])
# 批量预测
preds = net(images)
preds.shape # torch.Size([10, 10])
preds.argmax(dim=1)
preds.argmax(dim=1).eq(labels)# .sum()

END.

相关文章

网友评论

      本文标题:神经网络的批量处理

      本文链接:https://www.haomeiwen.com/subject/noxydctx.html