批处理

作者: zyyupup | 来源:发表于2018-12-14 20:07 被阅读0次
    import torch
    import torch.utils.data as Data
    torch.manual_seed(2)
    
    BATCH_SIZE = 8
    
    if __name__ == '__main__':
    
        x = torch.linspace(1,10,10)
        y = torch.linspace(10,1,10)
    
        #先转换成torch能识别的dataset
        torch_dataset = Data.TensorDataset(x,y)
        #把dataset放入DataLoader
        loader = Data.DataLoader(
            dataset=torch_dataset,      # torch TensorDataset format
            batch_size=BATCH_SIZE,      # mini batch size
            shuffle=True,               # 要不要打乱数据 (打乱比较好)
            num_workers=2,              # 多线程来读数据
        )
        for epoch in range(3):
            for step,(batch_x,batch_y) in enumerate(loader):
                print('Epoch:',epoch,'| Step:',step,'| batch x:',batch_x.numpy(),'| batch y:',batch_y.numpy())
    
    '''
    Epoch: 0 | Step: 0 | batch x: [2. 1. 9. 4. 5. 6. 7. 8.] | batch y: [ 9. 10.  2.  7.  6.  5.  4.  3.]
    Epoch: 0 | Step: 1 | batch x: [ 3. 10.] | batch y: [8. 1.]
    Epoch: 1 | Step: 0 | batch x: [ 6.  5. 10.  7.  3.  4.  1.  9.] | batch y: [ 5.  6.  1.  4.  8.  7. 10.  2.]
    Epoch: 1 | Step: 1 | batch x: [8. 2.] | batch y: [3. 9.]
    Epoch: 2 | Step: 0 | batch x: [ 2.  8.  7.  5.  1.  9. 10.  3.] | batch y: [ 9.  3.  4.  6. 10.  2.  1.  8.]
    Epoch: 2 | Step: 1 | batch x: [6. 4.] | batch y: [5. 7.]
    '''
    
    

    相关文章

      网友评论

          本文标题:批处理

          本文链接:https://www.haomeiwen.com/subject/wbughqtx.html