美文网首页
Python的迭代器学习

Python的迭代器学习

作者: 星光下的胖子 | 来源:发表于2021-04-16 09:00 被阅读0次

    1.可迭代对象与迭代器

    1)含有内置函数__iter__()是可迭代对象Iterable。
    2)同时含有__iter__()__next__()是迭代器Iterator。
    因此,迭代器一定是可迭代对象。

    2.list的研究

    list含有__iter__(),是一个可迭代对象Iterable;
    但list不含有__next__(),不是一个迭代器Iterator。

    3.for的迭代过程

    1)先调用iter()得到一个迭代器it;
    2)然后对这个迭代器it进行next(it)迭代操作。
    并且,for操作会自动捕获__next__()函数抛出的StopIteration异常。
    注意:只能捕获StopIteration异常,不能自动捕获其他类型的异常。

    4.iter和next的用法

    for操作会自动捕获StopIteration异常,但next操作不会。

    5.模拟list类,创建一个Food类

    6.模拟创建Dataset和DataLoader,用于管理数据集

    import numpy as np
    import random
    import math
    
    
    class Dataset(object):
        def __init__(self, images, labels):
            self.images = images
            self.labels = labels
            
        def __len__(self):
            return len(self.images)
        
        def __getitem__(self, item):
            return self.images[item], self.labels[item]
        
    class DataLoaderIter(object):
        def __init__(self, dataloader):
            self.dataloader = dataloader
            self.cursor = 0
            self.indexs = list(range(self.dataloader.count_data))
            if self.dataloader.shuffle:
                random.shuffle(self.indexs)
        
        def __iter__(self):
            return self
        
        def __next__(self):
            if self.cursor < self.dataloader.count_data:
                items = self.indexs[self.cursor:self.cursor + self.dataloader.batch_size]
                batch_data = self.mergy_data(items)  # 根据索引列表获取对应的数据
                self.cursor += self.dataloader.batch_size
                return batch_data
            else:
                raise StopIteration  # 手动抛出StopIteration异常
              
        # 根据索引列表获取对应的数据
        def mergy_data(self, items):
            images_list = []
            labels_list = []
            for item in items:
                image, label = self.dataloader.dataset[item]
                images_list.append(image)
                labels_list.append(label)
            return np.vstack(images_list), np.vstack(labels_list).squeeze()
    
    class DataLoader(object):
        def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False):
            self.dataset = dataset
            self.batch_size = batch_size
            self.shuffle = shuffle
            self.drop_last = drop_last
            self.count_data = len(dataset)
            if drop_last:
                self.count_data = math.floor(self.count_data / self.batch_size) * self.batch_size
            
        def __len__(self):
            return math.ceil(self.count_data / self.batch_size)
        
        def __iter__(self):
            return DataLoaderIter(self)
    

    指定batch_size,将数据集分为多个批次,进行遍历操作。

    相关文章

      网友评论

          本文标题:Python的迭代器学习

          本文链接:https://www.haomeiwen.com/subject/fpwjlltx.html