美文网首页
1.2 自定义keras数据生成器

1.2 自定义keras数据生成器

作者: 纵春水东流 | 来源:发表于2021-04-27 22:37 被阅读0次

    当你从硬盘加载大数据,或做组合特征,是比较适合用数据生成器
    列举了几个常用的数据生成器方式

    1.1 直接使用python的生成器生成数据
    比较简单,但是缺少很多keras内建的功能

    #X,保存的是ID
    #y,保存的是标签
    #transform_function,保存的是最X转换方案
    def generate(X,y,transform_function,batch_size=64,dim,channel):
        for i in range(len(X)//batch_size):
            sample = np.zeros((batch_size,*dim,channel))
            labels = np.zeros(batch_size)
    
            sample[:] = [transform(x) for x in X[i*batch_size:(i+1)*batch_size]] 
            labels[:] = y[i*batch_size:(i+1)*batch_size]
            
            yield sample,labels
    

    1.2继承tf.keras.utils.Sequence类

    class DataGenerator(tf.keras.utils.Sequence):
        def __init__(self, df, x_col, y_col=None, batch_size=32, num_classes=None, shuffle=True):
            self.batch_size = batch_size
            self.df = dataframe
            self.indices = self.df.index.tolist()
            self.num_classes = num_classes
            self.shuffle = shuffle
            self.x_col = x_col
            self.y_col = y_col
            self.on_epoch_end()
    
        def __len__(self):
            return len(self.indices) // self.batch_size)
    
        def __getitem__(self, index):
            index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
            batch = [self.indices[k] for k in index]
            
            X, y = self.__get_data(batch)
            return X, y
    
        def on_epoch_end(self):
            self.index = np.arange(len(self.indices))
            if self.shuffle == True:
                np.random.shuffle(self.index)
    
        def __get_data(self, batch):
            X = # logic
            y = # logic
            
            for i, id in enumerate(batch):
                X[i,] = # logic
                y[i] = # labels
    
            return X, y
    

    相关文章

      网友评论

          本文标题:1.2 自定义keras数据生成器

          本文链接:https://www.haomeiwen.com/subject/wpoprltx.html