
A walkthrough of pytorch-3dunet-master

Author: zelda2333 | Published 2020-06-14 17:25

Source code: https://github.com/wolny/pytorch-3dunet

Installation

1. Install with conda

2. Install by running setup.py

For background on setup.py, see:
Python study notes | setuptools in Python
How setup.py packages a project
How to package your own Python program: setuptools in detail

setup.py

from setuptools import setup, find_packages
# Execute __version__.py so that __version__ is defined in this namespace
exec(open('pytorch3dunet/__version__.py').read())
setup(
    name="pytorch3dunet",     # package name -- the name of the generated egg
    # find_packages() discovers packages automatically: by default it searches,
    # starting from the directory of setup.py, for directories that contain an
    # __init__.py. exclude=["tests"] keeps the tests package out of the distribution.
    packages=find_packages(exclude=["tests"]),
    version=__version__,      # package version -- version number of the generated egg
    author="Adrian Wolny, Lorenzo Cerrone",
    url="https://github.com/wolny/pytorch-3dunet",  # project homepage
    license="MIT",
    python_requires='>=3.7'   # minimum supported Python version
)
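Two pieces of this file are worth unpacking: the exec() trick and find_packages(). A minimal sketch (the version string and the printed package list are illustrative, not taken from the repo):

from setuptools import find_packages

# 1) __version__.py holds a single assignment such as __version__ = '1.2.3'
#    (hypothetical value). Executing its source defines __version__ here
#    without importing the package, which could fail before its dependencies
#    are installed.
exec("__version__ = '1.2.3'")
print(__version__)  # -> 1.2.3

# 2) find_packages(): run from the repo root, it lists every directory that
#    contains an __init__.py, minus the excluded ones (output illustrative).
print(find_packages(exclude=["tests"]))
# e.g. ['pytorch3dunet', 'pytorch3dunet.datasets', 'pytorch3dunet.unet3d']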

Training

train.py

 0  logger = get_logger('UNet3DTrain')

def main():
      # Load and log experiment configuration
 1    config = load_config()
 2    logger.info(config)

 3    manual_seed = config.get('manual_seed', None)
 4    if manual_seed is not None:
 5        logger.info(f'Seed the RNG for all devices with {manual_seed}')
 6        torch.manual_seed(manual_seed)
          # see https://pytorch.org/docs/stable/notes/randomness.html
 7        torch.backends.cudnn.deterministic = True
 8        torch.backends.cudnn.benchmark = False

       # Create the model
 9   model = get_model(config)
     # use DataParallel if more than 1 GPU available
10   device = config['device']
11   if torch.cuda.device_count() > 1 and not device.type == 'cpu':
12        model = nn.DataParallel(model)
13        logger.info(f'Using {torch.cuda.device_count()} GPUs for training')

    # put the model on GPUs
14  logger.info(f"Sending the model to '{config['device']}'")
15  model = model.to(device)

    # Log the number of learnable parameters
16  logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
17  loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
18  eval_criterion = get_evaluation_metric(config)

    # Create data loaders
19  loaders = get_train_loaders(config)

    # Create the optimizer
20  optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
22  lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
23  trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
24                           loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)
    # Start training
25  trainer.fit()

1. config = load_config()

Loads the YAML config file and logs which device will be used.

Further reading: Python's logging module in detail

logger = utils.get_logger('ConfigLoader')

def load_config():
    parser = argparse.ArgumentParser(description='UNet3D')
    parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
    args = parser.parse_args()
    config = _load_config_yaml(args.config)  # parse the file passed via --config
    # Get a device to train on
    device_str = config.get('device', None)
    if device_str is not None:
        logger.info(f"Device specified in config: '{device_str}'")
        if device_str.startswith('cuda') and not torch.cuda.is_available():
            logger.warn('CUDA not available, using CPU')
            device_str = 'cpu'
    else:
        device_str = "cuda:0" if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using '{device_str}' device")

    device = torch.device(device_str)
    config['device'] = device
    return config

def _load_config_yaml(config_file):
    return yaml.safe_load(open(config_file, 'r'))
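To make the device-selection logic concrete, here is a self-contained sketch with a made-up two-line config:

import yaml
import torch

# A made-up config snippet showing what load_config() parses and how the
# device fallback behaves; the keys mirror the ones used above.
config_text = """
manual_seed: 0
device: cuda:0
"""
config = yaml.safe_load(config_text)

device_str = config.get('device', None)
if device_str is None:
    device_str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
elif device_str.startswith('cuda') and not torch.cuda.is_available():
    device_str = 'cpu'   # same fallback as load_config() above
config['device'] = torch.device(device_str)
print(config['device'])  # 'cpu' on a machine without CUDA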

--------------------------------------------------------------------------------------------------------------------------
import sys
import logging

loggers = {}

def get_logger(name, level=logging.INFO):
    global loggers
    if loggers.get(name) is not None:
        return loggers[name]
    else:
        logger = logging.getLogger(name)   # get (or create) the named logger
        logger.setLevel(level)             # set the logging level
        # Logging to console
        stream_handler = logging.StreamHandler(sys.stdout)   # console handler
        # formatter object
        formatter = logging.Formatter(
            '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
        stream_handler.setFormatter(formatter)    # attach the formatter to the handler
        logger.addHandler(stream_handler)         # attach the handler to the logger

        loggers[name] = logger

        return logger
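The loggers dict is the interesting part: it caches loggers by name, so a second call with the same name does not attach a second handler (which would print every message twice). A quick check, runnable after the definition above:

a = get_logger('UNet3DTrain')
b = get_logger('UNet3DTrain')
assert a is b      # same cached instance
a.info('hello')    # printed once, by the single StreamHandler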

0. logger = get_logger('UNet3DTrain')
2. logger.info(config)

Prints the contents of the file passed via --config.

(Screenshot of the logged config omitted.)

9. model = get_model(config)

This line is straightforward: it builds the model. You can edit model.py to add your own architectures.

5. logger.info(f'Seed the RNG for all devices with {manual_seed}')
14. logger.info(f"Sending the model to '{config['device']}'")
16. logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

(Screenshot of these log lines omitted.)


17. loss_criterion = get_loss_criterion(config)

The loss functions are a bit involved; I'll come back to them later.


18. eval_criterion = get_evaluation_metric(config)

The evaluation metric. I haven't looked at it closely; it shouldn't be hard, and you may need to add metrics of your own, as in the sketch below.
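A custom metric only needs to map (prediction, target) to a score. A generic soft-Dice sketch, my own stand-in rather than the repo's DiceCoefficient class:

import torch

def dice_coefficient(pred, target, epsilon=1e-6):
    # pred, target: (N, C, D, H, W) tensors with values in [0, 1]
    pred = pred.flatten(start_dim=1)
    target = target.flatten(start_dim=1)
    intersection = (pred * target).sum(dim=1)
    denominator = pred.sum(dim=1) + target.sum(dim=1)
    return ((2 * intersection + epsilon) / (denominator + epsilon)).mean()

p = torch.rand(2, 1, 8, 16, 16)
t = (torch.rand(2, 1, 8, 16, 16) > 0.5).float()
print(dice_coefficient(p, t))  # scalar tensor in (0, 1]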


19. loaders = get_train_loaders(config)

This call fans out into many functions. In short, it returns ready-to-use data with the patch/slice indices already built. The walkthrough below is only my rough understanding and may contain mistakes; corrections are welcome.
from pytorch3dunet.datasets.utils import get_train_loaders

def get_train_loaders(config):
    """
    Returns dictionary containing the training and validation loaders (torch.utils.data.DataLoader).

    :param config: a top level configuration object containing the 'loaders' key
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
1     assert 'loaders' in config, 'Could not find data loaders configuration'
2     loaders_config = config['loaders']

3     logger.info('Creating training and validation set loaders...')

     # get dataset class
4     dataset_cls_str = loaders_config.get('dataset', None)  # StandardHDF5Dataset
5     if dataset_cls_str is None:
6         dataset_cls_str = 'StandardHDF5Dataset'
7         logger.warn(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
8     dataset_class = _get_cls(dataset_cls_str)

9     assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \
         "Train and validation 'file_paths' overlap. One cannot use validation data for training!"

10    train_datasets = dataset_class.create_datasets(loaders_config, phase='train')

11    val_datasets = dataset_class.create_datasets(loaders_config, phase='val')

12    num_workers = loaders_config.get('num_workers', 1)
13    logger.info(f'Number of workers for train/val dataloader: {num_workers}')
14    batch_size = loaders_config.get('batch_size', 1)
15    if torch.cuda.device_count() > 1 and not config['device'].type == 'cpu':
16        logger.info(
            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
17        batch_size = batch_size * torch.cuda.device_count()

18    logger.info(f'Batch size for train/val loader: {batch_size}')
    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
19    return {
        'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True,
                            num_workers=num_workers),
        'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=True, num_workers=num_workers)
    }

1.assert 'loaders' in config, 'Could not find data loaders configuration'
2.loaders_config = config['loaders']
3.logger.info('Creating training and validation set loaders...')

Lines 1-3 fetch the 'loaders' section of the config.

(Screenshot of the log line omitted.)

Lines 4-8 work out which dataset class will be used to load the H5 data files.
8.dataset_class = _get_cls(dataset_cls_str)

import importlib

def _get_cls(class_name):
    modules = ['pytorch3dunet.datasets.hdf5', 'pytorch3dunet.datasets.dsb', 'pytorch3dunet.datasets.utils']
    for module in modules:
        m = importlib.import_module(module)
        clazz = getattr(m, class_name, None)
        if clazz is not None:
            return clazz
    raise RuntimeError(f'Unsupported dataset class: {class_name}')

getattr(m, class_name) is equivalent to m.class_name.

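A tiny demonstration of the import_module/getattr pattern, using the standard math module instead of the dataset modules:

import importlib

m = importlib.import_module('math')
sqrt = getattr(m, 'sqrt', None)       # same as math.sqrt
print(sqrt(9.0))                      # 3.0
print(getattr(m, 'no_such', None))    # None -> _get_cls would try the next module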

9.assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']),
"Train and validation 'file_paths' overlap. One cannot use validation data for training!"

set.isdisjoint() returns True when two sets have no elements in common. The assert therefore checks that the training and validation file lists do not overlap, i.e., that validation data is never used for training.
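For example:

assert {'a.h5', 'b.h5'}.isdisjoint({'c.h5'})      # no overlap: assertion passes
assert not {'a.h5', 'b.h5'}.isdisjoint({'b.h5'})  # 'b.h5' in both: the assert in
                                                  # get_train_loaders would fire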
10.train_datasets = dataset_class.create_datasets(loaders_config, phase='train')
This calls create_datasets, defined on AbstractHDF5Dataset(ConfigDataset) in hdf5.py and inherited by StandardHDF5Dataset.
For the 'train' phase it reads the transformer, slice_builder, and file_paths configuration.
file_paths may contain both files and directories; if an entry is a directory, every H5 file inside it is added to the final file_paths list.

    @classmethod
    def create_datasets(cls, dataset_config, phase):
        phase_config = dataset_config[phase]
        transformer_config = phase_config['transformer']
        slice_builder_config = phase_config['slice_builder']
        file_paths = phase_config['file_paths']
        # file_paths may contain files and directories; if an entry is a directory,
        # every H5 file inside it is included in the final list
        file_paths = cls.traverse_h5_paths(file_paths)

        datasets = []
        for file_path in file_paths:
            try:
                logger.info(f'Loading {phase} set from: {file_path}...')
                dataset = cls(file_path=file_path,
                              phase=phase,
                              slice_builder_config=slice_builder_config,
                              transformer_config=transformer_config,
                              mirror_padding=dataset_config.get('mirror_padding', None),
                              raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
                              label_internal_path=dataset_config.get('label_internal_path', 'label'),
                              weight_internal_path=dataset_config.get('weight_internal_path', None))
                datasets.append(dataset)
            except Exception:
                logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
        return datasets

    @staticmethod
    def traverse_h5_paths(file_paths):
        assert isinstance(file_paths, list)  # file_paths must be a list
        results = []
        for file_path in file_paths:
            if os.path.isdir(file_path):
                # if file path is a directory take all H5 files in that directory
                iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']]
                for fp in chain(*iters):
                    results.append(fp)
            else:
                results.append(file_path)
        return results

The cls(file_path=...) call above constructs each dataset, so the dataset's __init__() runs as part of create_datasets.
For the train and val phases no mirror_padding is applied; for the test phase it is, which effectively adds a reflected border around the whole volume, as the sketch below illustrates.
I won't go through __init__() in detail: it applies the configured transforms to the data and records the position of every patch.
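To picture what mirror padding does, here is a numpy sketch with made-up pad widths (the repo takes them from the 'mirror_padding' config entry):

import numpy as np

volume = np.arange(4 * 4 * 4).reshape(4, 4, 4)
# reflect the volume at every border so edge patches still get full context
padded = np.pad(volume, pad_width=((2, 2), (2, 2), (2, 2)), mode='reflect')
print(volume.shape, '->', padded.shape)  # (4, 4, 4) -> (8, 8, 8)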
Line 19 returns the DataLoaders we need.


20. optimizer = _create_optimizer(config, model)

This one is simple: a standard optimizer built from the config.


22. lr_scheduler = _create_lr_scheduler(config, optimizer)

Also simple: a standard learning-rate scheduler built from the config. A rough sketch of both helpers follows.
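A rough sketch of what the two helpers typically look like; the config keys used here ('learning_rate', 'weight_decay', 'name') are illustrative, so check the repo for the exact ones:

import torch.optim as optim
from torch.optim import lr_scheduler

def create_optimizer_sketch(config, model):
    opt_config = config['optimizer']   # e.g. {'learning_rate': 1e-4, 'weight_decay': 1e-5}
    return optim.Adam(model.parameters(),
                      lr=opt_config['learning_rate'],
                      weight_decay=opt_config.get('weight_decay', 0))

def create_lr_scheduler_sketch(config, optimizer):
    # resolve the scheduler class by name, much like _get_cls() above
    sched_config = dict(config.get('lr_scheduler', {'name': 'ReduceLROnPlateau'}))
    clazz = getattr(lr_scheduler, sched_config.pop('name'))
    return clazz(optimizer, **sched_config)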


23. trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)

def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders):
    assert 'trainer' in config, 'Could not find trainer configuration'
    trainer_config = config['trainer']
    # checkpoint to resume an interrupted run from
    resume = trainer_config.get('resume', None)
    # pre-trained model to fine-tune
    pre_trained = trainer_config.get('pre_trained', None)
    # skip evaluating the metric on training batches
    skip_train_validation = trainer_config.get('skip_train_validation', False)

    # get tensorboard formatter
    # (not entirely sure what this does, but it includes min-max normalization)
    tensorboard_formatter = get_tensorboard_formatter(trainer_config.get('tensorboard_formatter', None))

    if resume is not None:
        # continue training from a given checkpoint
        return UNet3DTrainer.from_checkpoint(resume, model,
                                             optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, loaders, tensorboard_formatter=tensorboard_formatter)
    elif pre_trained is not None:
        # fine-tune a given pre-trained model
        return UNet3DTrainer.from_pretrained(pre_trained, model, optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, device=config['device'], loaders=loaders,
                                             max_num_epochs=trainer_config['epochs'],
                                             max_num_iterations=trainer_config['iters'],
                                             validate_after_iters=trainer_config['validate_after_iters'],
                                             log_after_iters=trainer_config['log_after_iters'],
                                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                                             tensorboard_formatter=tensorboard_formatter,
                                             skip_train_validation=skip_train_validation)
    else:
        # start training from scratch
        return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                             config['device'], loaders, trainer_config['checkpoint_dir'],
                             max_num_epochs=trainer_config['epochs'],
                             max_num_iterations=trainer_config['iters'],
                             validate_after_iters=trainer_config['validate_after_iters'],
                             log_after_iters=trainer_config['log_after_iters'],
                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                             tensorboard_formatter=tensorboard_formatter,
                             skip_train_validation=skip_train_validation)

25. trainer.fit()

    def fit(self):
        for _ in range(self.num_epoch, self.max_num_epochs):
            # train for one epoch
            should_terminate = self.train(self.loaders['train'])

            if should_terminate:
                logger.info('Stopping criterion is satisfied. Finishing training')
                return

            self.num_epoch += 1
        logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")

train() itself is a standard training loop: the model is checkpointed based on the evaluation metric on the validation set, and learning_rate, loss_avg, eval_score_avg, prediction images, etc. are logged via TensorBoardX.
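The "save the model based on the validation metric" part boils down to something like this sketch (not the trainer's actual code; the checkpoint path is made up):

import torch

def maybe_save_best(model, eval_score, best_score, higher_is_better=True,
                    path='best_checkpoint.pytorch'):
    # save only when the validation score improves on the best seen so far
    improved = eval_score > best_score if higher_is_better else eval_score < best_score
    if improved:
        torch.save({'model_state_dict': model.state_dict(),
                    'best_eval_score': eval_score}, path)
        return eval_score
    return best_score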


Prediction

predict.py

predict.py is broadly similar to train.py: it saves the predicted probabilities for each .nii volume separately, but performs no evaluation.
To be honest, I didn't fully follow the post-processing of the output probabilities. I only know that it removes some voxels to avoid stitching artifacts, and that it undoes the mirror padding so the prediction matches the original volume (see the sketch after the code below).

def main():
    # Load configuration
    config = load_config()

    # Create the model
    model = get_model(config)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    # use DataParallel if more than 1 GPU available
    device = config['device']
    if torch.cuda.device_count() > 1 and not device.type == 'cpu':
        model = nn.DataParallel(model)
        logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction')

    logger.info(f"Sending the model to '{device}'")
    model = model.to(device)

    logger.info('Loading HDF5 datasets...')
    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")

        output_file = _get_output_file(test_loader.dataset)

        predictor = _get_predictor(model, test_loader, output_file, config)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        predictor.predict()
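As a rough illustration of undoing the mirror padding mentioned above: cropping the same pad widths off each spatial axis restores the original shape (the widths here are made up; the repo derives them from the 'mirror_padding' config):

import numpy as np

pad_z, pad_y, pad_x = 2, 2, 2
padded_pred = np.zeros((1, 12, 12, 12))   # (C, D, H, W) after padding
restored = padded_pred[:, pad_z:-pad_z, pad_y:-pad_y, pad_x:-pad_x]
print(restored.shape)                     # (1, 8, 8, 8)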

References:
A detailed look at Python's getattr() function
The Python set isdisjoint() method

