美文网首页
1.2 自定义keras数据生成器

1.2 自定义keras数据生成器

作者: 纵春水东流 | 来源:发表于2021-04-27 22:37 被阅读0次

当你从硬盘加载大数据,或做组合特征,是比较适合用数据生成器
列举了几个常用的数据生成器方式

1.1 直接使用python的生成器生成数据
比较简单,但是缺少很多keras内建的功能

#X,保存的是ID
#y,保存的是标签
#transform_function,保存的是最X转换方案
def generate(X,y,transform_function,batch_size=64,dim,channel):
    for i in range(len(X)//batch_size):
        sample = np.zeros((batch_size,*dim,channel))
        labels = np.zeros(batch_size)

        sample[:] = [transform(x) for x in X[i*batch_size:(i+1)*batch_size]] 
        labels[:] = y[i*batch_size:(i+1)*batch_size]
        
        yield sample,labels

1.2继承tf.keras.utils.Sequence类

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, df, x_col, y_col=None, batch_size=32, num_classes=None, shuffle=True):
        self.batch_size = batch_size
        self.df = dataframe
        self.indices = self.df.index.tolist()
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.x_col = x_col
        self.y_col = y_col
        self.on_epoch_end()

    def __len__(self):
        return len(self.indices) // self.batch_size)

    def __getitem__(self, index):
        index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
        batch = [self.indices[k] for k in index]
        
        X, y = self.__get_data(batch)
        return X, y

    def on_epoch_end(self):
        self.index = np.arange(len(self.indices))
        if self.shuffle == True:
            np.random.shuffle(self.index)

    def __get_data(self, batch):
        X = # logic
        y = # logic
        
        for i, id in enumerate(batch):
            X[i,] = # logic
            y[i] = # labels

        return X, y

相关文章

网友评论

      本文标题:1.2 自定义keras数据生成器

      本文链接:https://www.haomeiwen.com/subject/wpoprltx.html