The key is the use of yield; this article by Liao Xuefeng explains it very clearly and in detail. Below are two simple ways to generate mini-batches from a training set:
Method 1:
import random
import torch

train_data = torch.tensor(...)
def data_iter(batch_size, train_data, train_labels):
    num_examples = len(train_data)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read the samples in random order
    for i in range(0, num_examples, batch_size):
        # the last batch may contain fewer than batch_size examples
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])
        yield train_data.index_select(0, j), train_labels.index_select(0, j)
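Since the key point is how yield works, here is a minimal sketch in plain Python (the function name minibatch_indices is made up for illustration): each call to yield hands one batch of indices to the caller and pauses the function until the loop asks for the next batch.

```python
def minibatch_indices(num_examples, batch_size):
    # Yields the index list of each mini-batch; execution pauses at
    # every yield and resumes there on the next iteration.
    for i in range(0, num_examples, batch_size):
        yield list(range(i, min(i + batch_size, num_examples)))

batches = list(minibatch_indices(10, 4))
# the last batch is smaller because 10 is not a multiple of 4
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

This is exactly why Method 1 handles a dataset whose size is not a multiple of batch_size: min() clips the final slice, and the generator simply stops after the last yield.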
Method 2:
import torch.utils.data as Data

# combine the features and labels of the dataset
dataset = Data.TensorDataset(features, labels)
# put the dataset into a DataLoader
data_iter = Data.DataLoader(
    dataset=dataset,        # torch TensorDataset format
    batch_size=batch_size,  # mini-batch size
    shuffle=True,           # whether to shuffle the data
    num_workers=2,          # load data with multiple worker processes
)
They are used as follows:
# Method 1
for X, y in data_iter(batch_size, train_data, train_labels):
    pass
# Method 2
for X, y in data_iter:
    pass
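A self-contained sketch of Method 2, with made-up synthetic tensors standing in for features and labels; num_workers=0 keeps loading in the main process, which is simplest for a demo:

```python
import torch
import torch.utils.data as Data

# synthetic data: 10 examples with 2 features each (made up for illustration)
features = torch.randn(10, 2)
labels = torch.randn(10)

dataset = Data.TensorDataset(features, labels)
data_iter = Data.DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,  # 0 = load in the main process
)

for X, y in data_iter:
    # two batches of 4, then a final batch of 2
    print(X.shape, y.shape)
```

Note that a DataLoader can be iterated repeatedly (one pass per epoch), whereas the generator from Method 1 must be re-created by calling data_iter(...) again for each epoch.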