# Pair the sample tensor with its label tensor so both are indexed together.
train_dataset = torch.utils.data.TensorDataset(train, label)
batch_size = 5
# Iterate over the dataset in shuffled mini-batches of `batch_size` samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
for batch_idx, (data, target) in enumerate(train_loader):
    # NOTE(review): on PyTorch >= 0.4 the Variable wrapper is deprecated and
    # simply returns the tensor; kept for older versions -- confirm the
    # targeted PyTorch version before removing it.
    data, target = Variable(data), Variable(target)
在上面的循环体内添加:
# Cast the batch to 32-bit floating point (Tensor.float() is equivalent to
# .to(torch.float32)). Presumably meant to run inside the batch loop, right
# after `data` is unpacked -- confirm placement.
data = data.float()
网友评论