批处理

作者: zyyupup | 来源:发表于2018-12-14 20:07 被阅读0次
import torch
import torch.utils.data as Data
torch.manual_seed(2)

BATCH_SIZE = 8

if __name__ == '__main__':

    x = torch.linspace(1,10,10)
    y = torch.linspace(10,1,10)

    #先转换成torch能识别的dataset
    torch_dataset = Data.TensorDataset(x,y)
    #把dataset放入DataLoader
    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               # 要不要打乱数据 (打乱比较好)
        num_workers=2,              # 多线程来读数据
    )
    for epoch in range(3):
        for step,(batch_x,batch_y) in enumerate(loader):
            print('Epoch:',epoch,'| Step:',step,'| batch x:',batch_x.numpy(),'| batch y:',batch_y.numpy())

'''
Epoch: 0 | Step: 0 | batch x: [2. 1. 9. 4. 5. 6. 7. 8.] | batch y: [ 9. 10.  2.  7.  6.  5.  4.  3.]
Epoch: 0 | Step: 1 | batch x: [ 3. 10.] | batch y: [8. 1.]
Epoch: 1 | Step: 0 | batch x: [ 6.  5. 10.  7.  3.  4.  1.  9.] | batch y: [ 5.  6.  1.  4.  8.  7. 10.  2.]
Epoch: 1 | Step: 1 | batch x: [8. 2.] | batch y: [3. 9.]
Epoch: 2 | Step: 0 | batch x: [ 2.  8.  7.  5.  1.  9. 10.  3.] | batch y: [ 9.  3.  4.  6. 10.  2.  1.  8.]
Epoch: 2 | Step: 1 | batch x: [6. 4.] | batch y: [5. 7.]
'''

相关文章

网友评论

      本文标题:批处理

      本文链接:https://www.haomeiwen.com/subject/wbughqtx.html