
TF - Data Generator

Author: 大地瓜_ | Published 2019-01-11 21:07

    Generator

    • The ADEChallengeData dataset

    Dataset download link: http://sceneparsing.csail.mit.edu/results2016.html

    ADEChallengeData
       |
       | - images
       |       |
       |       | - training        20210
       |       | - validation      2000
       |
       | - annotations  
              |
              | - training          20210
              | - validation        2000
    
    

    The training and validation sets are kept in separate directories.

    • Core idea of batched data generation
      The full set of images is too large to read into memory directly. Instead, only the image names and their annotation (mask) names are read into an array that serves as a task queue; names are then drawn at random from the shuffled queue, and only at that point is each image re-opened and read from disk by name. The three steps below outline the flow; a minimal sketch follows them.
    Step 1: Build the task queue
          
            Image_Obj_1  |   Image_Obj_2 | ... |  Image_Obj_N
    
           * Image_Obj: { "image": "1.jpg", "annotation": "1.png" }
    
    Step 2: Randomly draw batch_size entries from the shuffled queue
           
           Image_Obj_x | ... |  Image_Obj_x+n
         
    Step 3: Read the actual images for the batch
    
          image = read(Image_Obj_x.image)
          annotation = read(Image_Obj_x.annotation)
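
    The three steps can be condensed into a minimal runnable sketch. Every name here (records, read, the *.jpg paths) is an illustrative placeholder, not part of the original code:

    import random
    import numpy as np

    def read(path):
        # Stand-in for a real image loader (e.g. scipy.misc.imread);
        # returns a dummy array so the sketch runs without real files.
        return np.zeros((8, 8, 3), dtype=np.uint8)

    # Step 1: build the task queue of {image, annotation} name pairs.
    records = [{'image': '%d.jpg' % i, 'annotation': '%d.png' % i}
               for i in range(1, 6)]

    # Step 2: shuffle the queue and slice off batch_size entries.
    random.shuffle(records)
    batch_size = 2
    batch_records = records[:batch_size]

    # Step 3: only now open the actual files by name.
    images = [read(r['image']) for r in batch_records]
    annotations = [read(r['annotation']) for r in batch_records]

    With that picture in mind, the full implementation follows.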
    
    # -*- coding:utf-8 -*-
    import numpy as np
    import os
    import random
    from six.moves import cPickle as pickle
    from tensorflow.python.platform import gfile
    import glob
    
    import TensorflowUtils as utils
    
    # DATA_URL = 'http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip'
    DATA_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'
    
    
    # input_dir = FLAGS.data_dir = MIT_SceneParsing
    # data_dir = MIT_SceneParsing/
    def read_dataset(data_dir):
        pickle_filename = "MITSceneParsing.pickle"
        pickle_filepath = os.path.join(data_dir, pickle_filename)
        print(pickle_filepath)
        if not os.path.exists(pickle_filepath):
            utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
            # splitext separates the file name from its extension
            SceneParsing_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]
            # the path passed in is MIT_SceneParsing/ADEChallengeData2016
            # result is a dict of training and validation record lists
            result = create_image_lists(os.path.join(data_dir, SceneParsing_folder))
            print("Pickling ...")
            with open(pickle_filepath, 'wb') as f:
                pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
        else:
            print ("Found pickle file!")
    
        with open(pickle_filepath, 'rb') as f:
            result = pickle.load(f)
            training_records = result['training']
            validation_records = result['validation']
            del result
    
        # training_records is the list of training-set record dicts
        # validation_records is the list of validation-set record dicts
        return training_records, validation_records
    
    
    def create_image_lists(image_dir):
        '''
        Builds the training and validation record lists, e.g.
        image_list['training'] = [{'image': f, 'annotation': annotation_file, 'filename': filename}, ...]
        :param image_dir: root directory of the extracted dataset
        :return: dict with 'training' and 'validation' record lists
        '''
        if not gfile.Exists(image_dir):
            print("Image directory '" + image_dir + "' not found.")
            return None
        directories = ['training', 'validation']
        image_list = {}
    
        for directory in directories:
            file_list = []
            image_list[directory] = []
            file_glob = os.path.join(image_dir, "images", directory, '*.jpg')
            file_list.extend(glob.glob(file_glob))
    
            if not file_list:
                print('No files found')
            else:
                for f in file_list:
                    filename = os.path.splitext(f.split("/")[-1])[0]
                    annotation_file = os.path.join(image_dir, "annotations", directory, filename + '.png')
                    if os.path.exists(annotation_file):
                        record = {'image': f, 'annotation': annotation_file, 'filename': filename}
                        image_list[directory].append(record)
                    else:
                        print("Annotation file not found for %s - Skipping" % filename)
    
            random.shuffle(image_list[directory])
            no_of_images = len(image_list[directory])
            print ('No. of %s files: %d' % (directory, no_of_images))
    
        return image_list
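
    A quick usage sketch for the loader above. The "MIT_SceneParsing/" path follows the comments in read_dataset and is an assumption about the local layout:

    # Assumes TensorflowUtils is importable and the dataset can be downloaded.
    train_records, valid_records = read_dataset("MIT_SceneParsing/")
    print(len(train_records), len(valid_records))  # expected: 20210 2000
    print(train_records[0])
    # e.g. {'image': '.../images/training/ADE_train_00000001.jpg',
    #       'annotation': '.../annotations/training/ADE_train_00000001.png',
    #       'filename': 'ADE_train_00000001'}

    The second half is a batch reader that loads these records into memory and serves batches: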
    
    import numpy as np
    import scipy.misc as misc  # imread/imresize were removed from newer SciPy; this code needs an older release (with Pillow installed)
    
    class BatchDatset:
        files = []
        images = []
        annotations = []
        image_options = {}
        batch_offset = 0
        epochs_completed = 0
    
        def __init__(self, records_list, image_options={}):
            """
            Initialize a generic file reader with batching for a list of files
            :param records_list: list of file records to read -
            sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
            :param image_options: A dictionary of options for modifying the output image
            Available options:
            resize = True/ False
            resize_size = #size of output image - does bilinear resize
            color=True/False
            """
            print("Initializing Batch Dataset Reader...")
            print(image_options)
            self.files = records_list
            self.image_options = image_options
            self._read_images()
    
        def _read_images(self):
            self.__channels = True
            self.images = np.array([self._transform(filename['image']) for filename in self.files])
            self.__channels = False
            # expand each (h, w) mask to (h, w, 1); axis=2 rather than the
            # original axis=3, which newer NumPy rejects for a 2-D array
            self.annotations = np.array(
                [np.expand_dims(self._transform(filename['annotation']), axis=2) for filename in self.files])
            print(self.images.shape)
            print(self.annotations.shape)
    
        def _transform(self, filename):
            image = misc.imread(filename)
            if self.__channels and len(image.shape) < 3:
                # stack a grayscale image into shape (h, w, 3); the original
                # np.array([image for i in range(3)]) produced (3, h, w)
                image = np.stack([image] * 3, axis=2)

            if self.image_options.get("resize", False):
                resize_size = int(self.image_options["resize_size"])
                resize_image = misc.imresize(image,
                                             [resize_size, resize_size], interp='nearest')
            else:
                resize_image = image

            return np.array(resize_image)
    
        def get_records(self):
            return self.images, self.annotations
    
        def reset_batch_offset(self, offset=0):
            self.batch_offset = offset
    
        def next_batch(self, batch_size):
            start = self.batch_offset
            self.batch_offset += batch_size
            if self.batch_offset > self.images.shape[0]:
                # Finished epoch
                self.epochs_completed += 1
                print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
                # Shuffle the data
                perm = np.arange(self.images.shape[0])
                np.random.shuffle(perm)
                self.images = self.images[perm]
                self.annotations = self.annotations[perm]
                # Start next epoch
                start = 0
                self.batch_offset = batch_size
    
            end = self.batch_offset
            return self.images[start:end], self.annotations[start:end]
    
        def get_random_batch(self, batch_size):
            indexes = np.random.randint(0, self.images.shape[0], size=[batch_size]).tolist()
            return self.images[indexes], self.annotations[indexes]
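
    A short usage sketch tying the two parts together. The image_options values are arbitrary examples; 224 is not mandated by the original code:

    train_records, valid_records = read_dataset("MIT_SceneParsing/")
    image_options = {'resize': True, 'resize_size': 224}
    train_dataset = BatchDatset(train_records, image_options)

    # Sequential batches; the reader shuffles and restarts after each epoch.
    images, annotations = train_dataset.next_batch(batch_size=2)
    print(images.shape)       # expected: (2, 224, 224, 3)
    print(annotations.shape)  # expected: (2, 224, 224, 1)

    # Or sample a random batch (with replacement).
    rand_images, rand_annotations = train_dataset.get_random_batch(2)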
    
