import torch
import torch.utils.data as Data
torch.manual_seed(2)  # fix the global RNG so the shuffle order is reproducible
BATCH_SIZE = 8


def make_loader(batch_size=BATCH_SIZE, num_workers=2):
    """Return a shuffling DataLoader over a 10-sample toy dataset.

    The dataset pairs x = 1, 2, ..., 10 with y = 10, 9, ..., 1, so every
    (x, y) sample satisfies x + y == 11 no matter how it is shuffled.

    Args:
        batch_size: samples per mini-batch. The last batch holds the
            remainder (10 samples with the default of 8 give batches of
            8 and 2).
        num_workers: number of worker subprocesses used to load data;
            0 loads everything in the main process.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(batch_x, batch_y)``.
    """
    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    # Wrap the tensors in a Dataset that torch can index sample-by-sample.
    torch_dataset = Data.TensorDataset(x, y)
    # The DataLoader handles batching, shuffling and parallel loading.
    return Data.DataLoader(
        dataset=torch_dataset,
        batch_size=batch_size,
        shuffle=True,               # reshuffle the samples every epoch
        num_workers=num_workers,    # worker subprocesses (not threads)
    )


def main(num_epochs=3, num_workers=2):
    """Print every mini-batch of the toy dataset for ``num_epochs`` epochs."""
    loader = make_loader(num_workers=num_workers)
    for epoch in range(num_epochs):
        # Each new pass over the loader draws a fresh shuffle order.
        for step, (batch_x, batch_y) in enumerate(loader):
            print('Epoch:', epoch, '| Step:', step,
                  '| batch x:', batch_x.numpy(),
                  '| batch y:', batch_y.numpy())


# The guard matters when num_workers > 0: under the "spawn" start method
# (the default on Windows and macOS) each worker re-imports this module, so
# unguarded top-level code would be re-run in every worker.
if __name__ == '__main__':
    main()

# Sample output from one run (the exact shuffle order may differ across
# PyTorch versions):
# Epoch: 0 | Step: 0 | batch x: [2. 1. 9. 4. 5. 6. 7. 8.] | batch y: [ 9. 10. 2. 7. 6. 5. 4. 3.]
# Epoch: 0 | Step: 1 | batch x: [ 3. 10.] | batch y: [8. 1.]
# Epoch: 1 | Step: 0 | batch x: [ 6. 5. 10. 7. 3. 4. 1. 9.] | batch y: [ 5. 6. 1. 4. 8. 7. 10. 2.]
# Epoch: 1 | Step: 1 | batch x: [8. 2.] | batch y: [3. 9.]
# Epoch: 2 | Step: 0 | batch x: [ 2. 8. 7. 5. 1. 9. 10. 3.] | batch y: [ 9. 3. 4. 6. 10. 2. 1. 8.]
# Epoch: 2 | Step: 1 | batch x: [6. 4.] | batch y: [5. 7.]