Code excerpted from data_helper.py in https://github.com/RandolphVI/Multi-Label-Text-Classification:
```python
import numpy as np


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Because it contains `yield`, this is not an ordinary function but a generator.
    What it does: it makes num_epochs passes (epochs) over data. At the start of each
    epoch, if shuffle=True, data is reshuffled; the data is then yielded batch by batch,
    each batch holding batch_size items (the last batch of an epoch may be smaller).
    Each epoch yields int((len(data) - 1) / batch_size) + 1 batches, i.e.
    ceil(len(data) / batch_size), so num_epochs times that many batches in total.
    Args:
        data: The data
        batch_size: The size of the data batch
        num_epochs: The number of epochs
        shuffle: Shuffle or not (default: True)
    Returns:
        A generator that yields batches of the data set
    """
    data = np.array(data)
    data_size = len(data)
    # Ceiling division: a final partial batch still counts as one batch
    num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
    for epoch in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            # Clamp the end index so the last batch does not run past the data
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]
```
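To check the per-epoch batch count, the sketch below re-runs only the index arithmetic of the inner loop, without NumPy and without shuffling (shuffling reorders items but does not change batch sizes). The helper names `num_batches_per_epoch` and `batch_lengths` are hypothetical, introduced only for this illustration; they are not part of data_helper.py.

```python
import math
from typing import List


def num_batches_per_epoch(data_size: int, batch_size: int) -> int:
    # Same expression as in batch_iter; equals ceil(data_size / batch_size) for data_size >= 1
    return int((data_size - 1) / batch_size) + 1


def batch_lengths(data_size: int, batch_size: int) -> List[int]:
    # Hypothetical helper: mirrors the start/end slicing of batch_iter's inner loop
    # and records only the length of each batch produced in one epoch
    lengths = []
    for batch_num in range(num_batches_per_epoch(data_size, batch_size)):
        start_index = batch_num * batch_size
        end_index = min((batch_num + 1) * batch_size, data_size)
        lengths.append(end_index - start_index)
    return lengths


for n in (10, 12):
    # Columns: data size, batch lengths, ceil(n / 4), naive n // 4 + 1
    print(n, batch_lengths(n, 4), math.ceil(n / 4), n // 4 + 1)
# 10 [4, 4, 2] 3 3
# 12 [4, 4, 4] 3 4
```

With 10 samples and batch_size=4 the generator produces 3 batches (4 + 4 + 2), and a naive count of int(len(data) / batch_size) + 1 happens to agree. With 12 samples it produces exactly 3 full batches, while that naive count would predict 4; the `(data_size - 1)` term is what avoids an empty trailing batch when the data divides evenly.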