1.可迭代对象与迭代器
1)含有内置函数__iter__()
是可迭代对象Iterable。
2)同时含有__iter__()
和__next__()
是迭代器Iterator。
因此,迭代器一定是可迭代对象。
2.list的研究
list含有__iter__()
,是一个可迭代对象Iterable;
但list不含有__next__()
,不是一个迭代器Iterator。
3.for的迭代过程
1)先调用iter()得到一个迭代器it;
2)然后对这个迭代器it进行next(it)迭代操作。
并且,for操作会自动捕获__next__()
函数抛出的StopIteration异常。
注意:只能捕获StopIteration异常,不能自动捕获其他类型的异常。
4.iter和next的用法
for操作会自动捕获StopIteration异常,但next操作不会。
5.模拟list类,创建一个Food类
6.模拟创建Dataset和DataLoader,用于管理数据集
import numpy as np
import random
import math
class Dataset(object):
def __init__(self, images, labels):
self.images = images
self.labels = labels
def __len__(self):
return len(self.images)
def __getitem__(self, item):
return self.images[item], self.labels[item]
class DataLoaderIter(object):
def __init__(self, dataloader):
self.dataloader = dataloader
self.cursor = 0
self.indexs = list(range(self.dataloader.count_data))
if self.dataloader.shuffle:
random.shuffle(self.indexs)
def __iter__(self):
return self
def __next__(self):
if self.cursor < self.dataloader.count_data:
items = self.indexs[self.cursor:self.cursor + self.dataloader.batch_size]
batch_data = self.mergy_data(items) # 根据索引列表获取对应的数据
self.cursor += self.dataloader.batch_size
return batch_data
else:
raise StopIteration # 手动抛出StopIteration异常
# 根据索引列表获取对应的数据
def mergy_data(self, items):
images_list = []
labels_list = []
for item in items:
image, label = self.dataloader.dataset[item]
images_list.append(image)
labels_list.append(label)
return np.vstack(images_list), np.vstack(labels_list).squeeze()
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.count_data = len(dataset)
if drop_last:
self.count_data = math.floor(self.count_data / self.batch_size) * self.batch_size
def __len__(self):
return math.ceil(self.count_data / self.batch_size)
def __iter__(self):
return DataLoaderIter(self)
指定batch_size,将数据集分为多个批次,进行遍历操作。
网友评论