美文网首页
测试机器(无网络)加载cifar数据集

测试机器(无网络)加载cifar数据集

作者: 西瓜雪梨桔子汁 | 来源:发表于2018-12-10 16:24 被阅读0次

           PyTorch的数据集包揽了常用的数据集,使用时只需要引入、配置download=True即可联网下载。
           迫于手里测试集不能连接外网,只能查看源码的实现,以期自己下载数据文件、再自行导入数据集。在torchvision的源码中找到了datasets/cifar.py文件,负载下载及导入。其它支持的数据集加载源码如下:

    PyTorch支持数据集

    1. 下载数据文件

    查看源码可知道,如果获取cifar对象时,指定了download=True就会触发下载文件:

    # 如果download为True,联网下载数据文件
    if download:
         self.download()
    

    下载文件逻辑:

        def download(self):
            import tarfile
            # 如果已经下载且校验完成性的,直接返回
            if self._check_integrity():
                print('Files already downloaded and verified')
                return
    
            #  否则下载文件
            download_url(self.url, self.root, self.filename, self.tgz_md5)
    
            # extract file 下载后解压文件
            with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
                tar.extractall(path=self.root)
    

           下载文件的url为https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,下载之后存储文件名为:cifar-10-python.tar.gz,下载后需要校验文件的MD5值,为:c58f30108f718f92721af3b95e74349a。最重要的,下载之后需要解压文件、解压到指定的目录。文件可以从如下两个地址下载得到:

    https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
    

    上传到测试服务器指定目录:


    下载数据文件

    2.检查文件完整性

    # 检查文件完整性
    if not self._check_integrity():
        raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')
    

    检查文件明细

        def _check_integrity(self):
            root = self.root
            # 检查训练数据文件、测试数据文件是否都在
            for fentry in (self.train_list + self.test_list):
                filename, md5 = fentry[0], fentry[1]
                fpath = os.path.join(root, self.base_folder, filename)
                if not check_integrity(fpath, md5):
                    return False
            return True
    

    具体的文件明细及其对应的MD5值也在类变量定义好了:

     train_list = [
            ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
            ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
            ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
            ['data_batch_4', '634d18415352ddfa80567beed471001a'],
            ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
        ]
    
        test_list = [
            ['test_batch', '40351d587109b95175f43aff81a1287e'],
        ]
    

    3.加载数据

            # 如果是 training set,使用train_list的文件列表
            if self.train:
                downloaded_list = self.train_list
            # 如果是 testset,使用test_list的文件列表
            else:
                downloaded_list = self.test_list
    
            self.data = []
            self.targets = []
    
            # now load the picked numpy arrays 开始加载数据文件
            for file_name, checksum in downloaded_list:
                file_path = os.path.join(self.root, self.base_folder, file_name)
                with open(file_path, 'rb') as f:
                    if sys.version_info[0] == 2:
                        entry = pickle.load(f)
                    else:
                        entry = pickle.load(f, encoding='latin1')
                    self.data.append(entry['data'])
                    if 'labels' in entry:
                        self.targets.extend(entry['labels'])
                    else:
                        self.targets.extend(entry['fine_labels'])
    
            # 将数据调整了尺寸:RGB3通道、大小32*32,不论图片张数
            self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
            # convert to HWC,转换为 H*W*C,将channel放到最后
            self.data = self.data.transpose((0, 2, 3, 1))  
    
            # 加载meta文件,得到类别信息
            self._load_meta()
    

           以加载训练集为例,会将文件:data_batch_1, data_batch_2, data_batch_3, data_batch_4, data_batch_5全部加载;这些文件路径会在os.path.join(self.root, self.base_folder)中。
           以我的环境为例,数据集目录都放在了工作目录下的cifar目录、base_folder是脚本指定的解压目录,目录名cifar-10-batches-py,拼接各个数据集文件名称就得到完成的数据文件全路径。

    cifar-10-batches-py文件明细
           然后,进行必要的数据调整,以适应tensor的格式,包括图片的大小、通道等等,便于后续进行batch操作。
           加载meta文件逻辑:
        def _load_meta(self):
            path = os.path.join(self.root, self.base_folder, self.meta['filename'])
            if not check_integrity(path, self.meta['md5']):
                raise RuntimeError('Dataset metadata file not found or corrupted.' +
                                   ' You can use download=True to download it')
            with open(path, 'rb') as infile:
                if sys.version_info[0] == 2:
                    data = pickle.load(infile)
                else:
                    data = pickle.load(infile, encoding='latin1')
                self.classes = data[self.meta['key']]
            self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
    

    meta文件配置:

    meta = {
            'filename': 'batches.meta',
            'key': 'label_names',
            'md5': '5ff9c542aee3614f3951f8cda6e48888',
        }
    

    可见训练集的meta文件为batches.meta,标签名叫label_names

    4.加载数据集方法

    #!/usr/bin/python
    # -*- coding: UTF-8 -*-
    
    import sys
    reload(sys)
    sys.setdefaultencoding("utf-8")
    import os,sys
    import numpy as np 
    import torch
    import torchvision
    from torchvision import transforms, datasets
    
    def get_cifar10_dataloader(data_dir):
        '''
        使用torch的DataLoader加载数据、规则化、分批
        '''
        # 准备数据集并预处理
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  #先四周填充0,再将图像随机裁剪成32*32
            transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #R,G,B每层的归一化用到的均值和方差
        ])
    
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    
       
        #训练数据集
        train_dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_train)
        valid_dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)
    
        # 对数据进行分批
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=True)
        
        # 返回数据
        return train_loader, valid_loader
        
            
    if __name__ == "__main__":
        data_dir = '/app/quanyq/metis/torch/cifar'
        train_loader, valid_loader = get_cifar10_dataloader(data_dir)
        print '训练集大小:%d' % len(train_loader.dataset)
        print '测试集大小:%d' % len(valid_loader.dataset)
        print 'Finished...'
    
    加载数据集情况

    5.加载分类100种图片数据集

    def get_cifar100_dataloader(data_dir):
        '''
        使用torch的DataLoader加载数据、规则化、分批
        '''
        # 准备数据集并预处理
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  #先四周填充0,再将图像随机裁剪成32*32
            transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #R,G,B每层的归一化用到的均值和方差
        ])
    
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    
       
        #训练数据集:CIFAR100则会加载数据目录下cifar-100-python的文件
        train_dataset = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=transform_train)
        valid_dataset = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform_test)
    
        # 对数据进行分批
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=True)
        
        # 返回数据
        return train_loader, valid_loader
    

           代码结构一模一样,只是选择类为datasets.CIFAR100,这样内部处理会选择加载数据目录下cifar-100-python文件夹中的数据。

    cifar-100-python文件夹中的数据

    6.总结

           整体来说,PyTorch对cifar数据集支持比较好,这里只是跳过了下载、解压的流程,然后将数据目录传给datasets.CIFAR100或者datasets.CIFAR10类就可以完成数据加载。

    相关文章

      网友评论

          本文标题:测试机器(无网络)加载cifar数据集

          本文链接:https://www.haomeiwen.com/subject/eewrhqtx.html