def data_iter(batch_size, features, labels):
    """Yield shuffled mini-batches of ``(features, labels)``.

    The example indices ``0 .. len(features) - 1`` are shuffled with
    ``random.shuffle`` on each call, then split into consecutive batches of
    ``batch_size`` indices. The final batch is shorter when
    ``len(features)`` is not a multiple of ``batch_size``.

    Args:
        batch_size: Number of examples per batch; must be positive.
        features: Collection supporting ``len()`` and ``.take(indices)``
            (presumably an MXNet NDArray whose first axis indexes examples —
            TODO confirm against callers).
        labels: Collection indexed like ``features``; assumed to have at
            least ``len(features)`` entries along the indexed axis.

    Yields:
        ``(features.take(j), labels.take(j))`` where ``j`` is an ``nd.array``
        of the batch's shuffled indices.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_examples = len(features)
    indices = list(range(num_examples))
    # Fresh random order on every call.
    random.shuffle(indices)
    for start in range(0, num_examples, batch_size):
        # Slicing clamps at the end of the list, so the last batch is just shorter.
        j = nd.array(indices[start:start + batch_size])
        # take() returns the elements at the given indices.
        yield features.take(j), labels.take(j)
# 网友评论 ("User comments") — stray web-page text from the original copy/paste; not part of the code.