当你需要从硬盘加载大量数据，或者需要在训练时动态构造组合特征时，比较适合使用数据生成器。
下面列举几种常用的数据生成器实现方式。
1.1 直接使用 Python 生成器生成数据
实现比较简单，但缺少很多 Keras 内建的功能（例如每个 epoch 结束时的 on_epoch_end 回调，以及多进程加载时保证每个样本在一个 epoch 内只被使用一次）。
#X,保存的是ID
#y,保存的是标签
#transform_function,保存的是对X的转换方案
def generate(X, y, transform_function, batch_size=64, *, dim, channel):
    """Yield ``(samples, labels)`` batches built by transforming each item of ``X``.

    Only full batches are produced: the trailing ``len(X) % batch_size`` items
    are dropped. The generator is exhausted after a single pass over the data.

    Args:
        X: Sequence of raw items (e.g. sample IDs); each item is passed to
            ``transform_function``.
        y: Sequence of labels aligned with ``X``.
        transform_function: Callable mapping one item of ``X`` to an array
            that fits shape ``(*dim, channel)``.
        batch_size: Number of samples per batch; must be positive.
        dim: Shape of one transformed sample without the channel axis,
            e.g. ``(height, width)``.
        channel: Number of channels of one transformed sample.

    Yields:
        ``(samples, labels)`` where ``samples`` is a float array of shape
        ``(batch_size, *dim, channel)`` and ``labels`` is a float array of
        shape ``(batch_size,)``.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised on first ``next``).
    """
    # `dim` and `channel` are keyword-only because Python forbids required
    # parameters after a defaulted one (`batch_size`); this keeps `batch_size`
    # usable positionally.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    for i in range(len(X) // batch_size):
        start, stop = i * batch_size, (i + 1) * batch_size
        sample = np.zeros((batch_size, *dim, channel))
        labels = np.zeros(batch_size)
        sample[:] = [transform_function(x) for x in X[start:stop]]
        labels[:] = y[start:stop]
        yield sample, labels
1.2 继承 tf.keras.utils.Sequence 类
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves fixed-size batches drawn from a DataFrame.

    Samples are addressed by row position; when ``shuffle`` is true the order
    is shuffled at construction and again after every epoch. Only full
    batches are served: the trailing ``len(df) % batch_size`` rows of the
    current order are skipped each epoch.

    The default ``_get_data`` reads features directly from ``x_col`` and labels
    from ``y_col``. When samples must be loaded or transformed (e.g. files on
    disk), subclass and override ``_get_data``.
    """

    def __init__(self, df, x_col, y_col=None, batch_size=32, num_classes=None, shuffle=True):
        """Create the generator.

        Args:
            df: DataFrame with one row per sample. NOTE(review): rows are
                fetched with ``df.loc`` by index label, so the index is assumed
                to be unique.
            x_col: Column name (or list of column names) holding the inputs.
            y_col: Column name holding the labels, or ``None`` for unlabeled
                data; batches then contain inputs only.
            batch_size: Number of rows per batch; must be positive.
            num_classes: If given, integer labels are one-hot encoded into
                this many classes.
            shuffle: Whether to shuffle sample order each epoch.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.batch_size = batch_size
        self.df = df
        # Index labels of every row, in DataFrame order.
        self.indices = self.df.index.tolist()
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.x_col = x_col
        self.y_col = y_col
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.indices) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, y)``, or ``X`` if unlabeled."""
        positions = self.index[index * self.batch_size:(index + 1) * self.batch_size]
        batch = [self.indices[k] for k in positions]
        X, y = self._get_data(batch)
        if y is None:
            return X
        return X, y

    def on_epoch_end(self):
        """Reset the row-position order, shuffling it if requested."""
        # `self.index` holds positions into `self.indices`; it is what gets
        # permuted, so the DataFrame itself is never reordered.
        self.index = np.arange(len(self.indices))
        if self.shuffle:
            np.random.shuffle(self.index)

    def _get_data(self, batch):
        """Build ``(X, y)`` arrays for the rows whose index labels are ``batch``.

        ``y`` is ``None`` when no ``y_col`` was configured. Override this in a
        subclass to load or transform samples instead of reading columns.
        """
        X = self.df.loc[batch, self.x_col].to_numpy()
        if self.y_col is None:
            return X, None
        y = self.df.loc[batch, self.y_col].to_numpy()
        if self.num_classes is not None:
            y = tf.keras.utils.to_categorical(y, num_classes=self.num_classes)
        return X, y
网友评论