文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
本文主要介绍如何使用PyTorch的TensorDataset和DataLoader按小批量(mini-batch)加载数据并进行迭代训练。
import torch
import torch.utils.data as Data
from torch.autograd import Variable
# Mini-batch size used by the data loader
BATCH_SIZE = 5

# Sample data: x counts up from 1 to 10, y counts down from 10 to 1
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

# Single-argument print(...) behaves identically under Python 2 and Python 3,
# unlike the Python 2-only `print value` statement.
print(x.numpy())
print(y.numpy())
[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10.]
[ 10. 9. 8. 7. 6. 5. 4. 3. 2. 1.]
# Wrap the two tensors in a dataset of (x, y) pairs. The tensors are passed
# positionally because the data_tensor/target_tensor keywords were removed
# in PyTorch 0.4; the positional form works on both old and new versions.
dataset = Data.TensorDataset(x, y)

# Data loader: yields BATCH_SIZE samples per step, reshuffled every epoch.
# NOTE(review): with num_workers > 0 on platforms that spawn worker processes
# (e.g. Windows), this script likely needs an `if __name__ == '__main__':`
# guard -- confirm for the target platform.
loader = Data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Training loop: 5 passes over the data
for epoch in range(5):
    for step, (batch_x, batch_y) in enumerate(loader):
        # A real training step would go here; this demo only prints each batch.
        # The text is built as one string so the output is the same under
        # Python 2 and Python 3 (the 'betch' spelling is kept to match the
        # sample output shown in the article).
        print('Epoch:  {} | Step:  {} | batch x:  {} | betch y:  {}'.format(
            epoch, step, batch_x.numpy(), batch_y.numpy()))
Epoch: 0 | Step: 0 | batch x: [ 7. 4. 8. 5. 2.] | betch y: [ 4. 7. 3. 6. 9.]
Epoch: 0 | Step: 1 | batch x: [ 10. 6. 3. 1. 9.] | betch y: [ 1. 5. 8. 10. 2.]
Epoch: 1 | Step: 0 | batch x: [ 6. 7. 10. 1. 3.] | betch y: [ 5. 4. 1. 10. 8.]
Epoch: 1 | Step: 1 | batch x: [ 9. 4. 5. 8. 2.] | betch y: [ 2. 7. 6. 3. 9.]
Epoch: 2 | Step: 0 | batch x: [ 5. 4. 7. 3. 8.] | betch y: [ 6. 7. 4. 8. 3.]
Epoch: 2 | Step: 1 | batch x: [ 6. 9. 2. 10. 1.] | betch y: [ 5. 2. 9. 1. 10.]
Epoch: 3 | Step: 0 | batch x: [ 9. 1. 5. 3. 10.] | betch y: [ 2. 10. 6. 8. 1.]
Epoch: 3 | Step: 1 | batch x: [ 8. 6. 4. 2. 7.] | betch y: [ 3. 5. 7. 9. 4.]
Epoch: 4 | Step: 0 | batch x: [ 10. 5. 9. 7. 3.] | betch y: [ 1. 6. 2. 4. 8.]
Epoch: 4 | Step: 1 | batch x: [ 6. 8. 2. 4. 1.] | betch y: [ 5. 3. 9. 7. 10.]
网友评论