美文网首页
神经网络的批量处理

神经网络的批量处理

作者: 钢笔先生 | 来源:发表于2019-08-04 14:02 被阅读0次

    Time : 2019-08-04
    链接:https://www.youtube.com/watch?v=p1xZ2yWU1eo&list=PLZbbT5o_s2xrfNyHZsM6ufI0iZENK9xgG&index=23

    数据批量处理

    这就需要借助于DataLoader了。

    data_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=10
    )
    
    # 取一个batch
    batch = next(iter(data_loader))
    images, labels = batch
    
    images.shape # (batch_size, in_channels, height, width)
    # torch.Size([10, 1, 28, 28])
    # 批量预测
    preds = net(images)
    preds.shape # torch.Size([10, 10])
    preds.argmax(dim=1)
    preds.argmax(dim=1).eq(labels)# .sum()
    
    

    END.

    相关文章

      网友评论

          本文标题:神经网络的批量处理

          本文链接:https://www.haomeiwen.com/subject/noxydctx.html