美文网首页
Python的迭代器学习

Python的迭代器学习

作者: 星光下的胖子 | 来源:发表于2021-04-16 09:00 被阅读0次

1.可迭代对象与迭代器

1)含有内置函数__iter__()是可迭代对象Iterable。
2)同时含有__iter__()__next__()是迭代器Iterator。
因此,迭代器一定是可迭代对象。

2.list的研究

list含有__iter__(),是一个可迭代对象Iterable;
但list不含有__next__(),不是一个迭代器Iterator。

3.for的迭代过程

1)先调用iter()得到一个迭代器it;
2)然后对这个迭代器it进行next(it)迭代操作。
并且,for操作会自动捕获__next__()函数抛出的StopIteration异常。
注意:只能捕获StopIteration异常,不能自动捕获其他类型的异常。

4.iter和next的用法

for操作会自动捕获StopIteration异常,但next操作不会。

5.模拟list类,创建一个Food类

6.模拟创建Dataset和DataLoader,用于管理数据集

import numpy as np
import random
import math


class Dataset(object):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, item):
        return self.images[item], self.labels[item]
    
class DataLoaderIter(object):
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.cursor = 0
        self.indexs = list(range(self.dataloader.count_data))
        if self.dataloader.shuffle:
            random.shuffle(self.indexs)
    
    def __iter__(self):
        return self
    
    def __next__(self):
        if self.cursor < self.dataloader.count_data:
            items = self.indexs[self.cursor:self.cursor + self.dataloader.batch_size]
            batch_data = self.mergy_data(items)  # 根据索引列表获取对应的数据
            self.cursor += self.dataloader.batch_size
            return batch_data
        else:
            raise StopIteration  # 手动抛出StopIteration异常
          
    # 根据索引列表获取对应的数据
    def mergy_data(self, items):
        images_list = []
        labels_list = []
        for item in items:
            image, label = self.dataloader.dataset[item]
            images_list.append(image)
            labels_list.append(label)
        return np.vstack(images_list), np.vstack(labels_list).squeeze()

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.count_data = len(dataset)
        if drop_last:
            self.count_data = math.floor(self.count_data / self.batch_size) * self.batch_size
        
    def __len__(self):
        return math.ceil(self.count_data / self.batch_size)
    
    def __iter__(self):
        return DataLoaderIter(self)

指定batch_size,将数据集分为多个批次,进行遍历操作。

相关文章

网友评论

      本文标题:Python的迭代器学习

      本文链接:https://www.haomeiwen.com/subject/fpwjlltx.html