美文网首页
数据生成器-适用大数据集小内存

数据生成器-适用大数据集小内存

作者: chunleiml | 来源:发表于2019-01-20 16:59 被阅读8次

    深度学习模型训练中,一般情况下需要很多的训练数据,如果这些数据训练时一次全部加载到内存,会需要很大的内存容量,因此我们一般使用数据生成器往内存根据训练需要逐步加载数据到内存

    #数据生成函数
    def load_train_data(self,para,index1,index2):
        #此路径下把每次需要加载进内存的数据单独保存在一个文件夹中,文件夹命名按整数从0开始
        input_path = r'D:\Data\generator24'
        while(para):
            i = random.randint(index1,index2)
            labels_stomach = input_path + os.sep + str(i) + os.sep + 'std_labels_abdomen3D_320.nii.gz'
            labels_stomach = nb.load(labels_stomach)
            labels_stomach = labels_stomach.get_data()
            imgs = input_path + os.sep + str(i) + os.sep + 'std_imgs3D_320.nii.gz'
            imgs = nb.load(imgs)    
            imgs = imgs.get_data()
            yield (imgs, labels_stomach)
    
          
    def tain_model(self):
        model = self.get_unet()
        model_checkpoint = ModelCheckpoint('seg_abdomen3D_24.h5',
                                           monitor='val_loss',
                                           verbose=1,
                                           save_best_only=True,
                                           save_weights_only = True,
                                           mode='min')
        check_list = [model_checkpoint, TensorBoard(log_dir='./log')]
        #使用keras的fit_generator,每次调用load_train_data都重新从本地路径加载数据进内存
        model.fit_generator(generator = self.load_train_data(True,0,700),
                            steps_per_epoch=600,
                            epochs=300,
                            validation_data=self.load_train_data(True,701,857),
                            validation_steps=156,
                            initial_epoch=0,
                            callbacks=check_list,
                            workers=8)
    

    相关文章

      网友评论

          本文标题:数据生成器-适用大数据集小内存

          本文链接:https://www.haomeiwen.com/subject/bdgxjqtx.html