Source code: https://github.com/wolny/pytorch-3dunet
Installation
1. Install with conda
2. Run setup.py to install
For background on setup.py, see:
Python study notes | setuptools
How setup.py packages a project
How to package your own Python program with setuptools (in detail)
setup.py
from setuptools import setup, find_packages

# Execute __version__.py so that __version__ is defined in this namespace
exec(open('pytorch3dunet/__version__.py').read())

setup(
    name="pytorch3dunet",  # package name (also the name of the generated egg)
    # find_packages() discovers packages (directories containing __init__.py) next to setup.py;
    # exclude=["tests"] leaves the tests package out of the distribution
    packages=find_packages(exclude=["tests"]),
    version=__version__,  # package version (version of the generated egg)
    author="Adrian Wolny, Lorenzo Cerrone",
    url="https://github.com/wolny/pytorch-3dunet",  # project home page
    license="MIT",
    python_requires='>=3.7'  # minimum required Python version
)
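As a side note, the exec(open(...).read()) line simply executes the contents of pytorch3dunet/__version__.py in the current namespace, so that __version__ becomes defined and can be passed to setup(). A minimal sketch of the trick (the version string here is just an example):

# pytorch3dunet/__version__.py contains a single assignment, e.g.:
# __version__ = '1.2.3'
exec(open('pytorch3dunet/__version__.py').read())  # defines __version__ in this namespace
print(__version__)  # -> '1.2.3'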
Training
train.py
0 logger = get_logger('UNet3DTrain')

def main():
# Load and log experiment configuration
1 config = load_config()
2 logger.info(config)
3 manual_seed = config.get('manual_seed', None)
4 if manual_seed is not None:
5 logger.info(f'Seed the RNG for all devices with {manual_seed}')
6 torch.manual_seed(manual_seed)
# see https://pytorch.org/docs/stable/notes/randomness.html
7 torch.backends.cudnn.deterministic = True
8 torch.backends.cudnn.benchmark = False
# Create the model
9 model = get_model(config)
# use DataParallel if more than 1 GPU available
10 device = config['device']
11 if torch.cuda.device_count() > 1 and not device.type == 'cpu':
12 model = nn.DataParallel(model)
13 logger.info(f'Using {torch.cuda.device_count()} GPUs for training')
# put the model on GPUs
14 logger.info(f"Sending the model to '{config['device']}'")
15 model = model.to(device)
# Log the number of learnable parameters
16 logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')
# Create loss criterion
17 loss_criterion = get_loss_criterion(config)
# Create evaluation metric
18 eval_criterion = get_evaluation_metric(config)
# Create data loaders
19 loaders = get_train_loaders(config)
# Create the optimizer
20 optimizer = _create_optimizer(config, model)
# Create learning rate adjustment strategy
22 lr_scheduler = _create_lr_scheduler(config, optimizer)
# Create model trainer
23 trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
24 loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)
# Start training
25 trainer.fit()
1. config = load_config()
Loads the configuration and logs which device will be used.
![](https://img.haomeiwen.com/i7152393/927865394c2c7d93.png)
import argparse

import torch
import yaml

from pytorch3dunet.unet3d import utils

logger = utils.get_logger('ConfigLoader')

def load_config():
    parser = argparse.ArgumentParser(description='UNet3D')
    parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
    args = parser.parse_args()
    config = _load_config_yaml(args.config)  # load the file passed via --config
    # Get a device to train on
    device_str = config.get('device', None)
    if device_str is not None:
        logger.info(f"Device specified in config: '{device_str}'")
        if device_str.startswith('cuda') and not torch.cuda.is_available():
            logger.warn('CUDA not available, using CPU')
            device_str = 'cpu'
    else:
        device_str = "cuda:0" if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using '{device_str}' device")
    device = torch.device(device_str)
    config['device'] = device
    return config

def _load_config_yaml(config_file):
    return yaml.safe_load(open(config_file, 'r'))
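For reference, yaml.safe_load turns the YAML text into plain Python dicts and lists, which is why the rest of the code can treat config like an ordinary dictionary. A tiny self-contained example (the keys are only illustrative, not a full training config):

import yaml

text = '''
device: cuda:0
manual_seed: 0
loaders:
  batch_size: 1
'''
config = yaml.safe_load(text)
print(config['device'])                 # 'cuda:0'
print(config['loaders']['batch_size'])  # 1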
--------------------------------------------------------------------------------------------------------------------------
import logging
import sys

loggers = {}

def get_logger(name, level=logging.INFO):
    global loggers
    if loggers.get(name) is not None:
        return loggers[name]
    else:
        logger = logging.getLogger(name)  # create/fetch the logger
        logger.setLevel(level)  # set the logger's level
        # Logging to console
        stream_handler = logging.StreamHandler(sys.stdout)  # console handler
        # formatter object
        formatter = logging.Formatter(
            '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
        stream_handler.setFormatter(formatter)  # attach the formatter to the handler
        logger.addHandler(stream_handler)  # attach the handler to the logger
        loggers[name] = logger
        return logger
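Because get_logger caches loggers in the module-level loggers dict, calling it twice with the same name returns the same object and does not attach a second StreamHandler (so messages are not printed twice). A quick usage sketch:

a = get_logger('UNet3DTrain')
b = get_logger('UNet3DTrain')
assert a is b          # same cached logger, the handler is attached only once
a.info('hello')        # printed with the asctime/threadName/levelname format defined above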
0.logger = get_logger('UNet3DTrain')
2.logger.info(config)
Prints the contents of the config file passed via --config.
Screenshot:
![](https://img.haomeiwen.com/i7152393/43f1a19bd637eda8.png)
9. model = get_model(config)
This line is fairly simple: it builds the model. You can edit model.py to add your own model (a hedged sketch of the lookup pattern follows below).
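I did not trace get_model in detail. A hedged sketch of the usual pattern (my own illustration, not the repo's exact code): the model section of the YAML config carries a name key, the class is looked up dynamically (much like _get_cls further down), and the remaining keys are passed to its constructor.

import importlib

def get_model_sketch(config):
    # illustrative only; the real get_model lives in pytorch3dunet and may differ
    model_config = config['model']
    m = importlib.import_module('pytorch3dunet.unet3d.model')     # assumed module path
    model_class = getattr(m, model_config['name'])                # e.g. 'UNet3D'
    kwargs = {k: v for k, v in model_config.items() if k != 'name'}
    return model_class(**kwargs)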
5.logger.info(f'Seed the RNG for all devices with {manual_seed}')
14.logger.info(f"Sending the model to '{config['device']}'")
16.logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')
Screenshot:
![](https://img.haomeiwen.com/i7152393/7b6d02ab79cef793.png)
17.loss_criterion = get_loss_criterion(config)
The loss part is a bit involved; I'll fill it in later.
18.eval_criterion = get_evaluation_metric(config)
The evaluation metric. I haven't looked at it closely and it shouldn't be too hard, but you may need to add your own metric (a sketch follows below).
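If you do add your own metric, the built-in ones are callables that take the network output and the target and return a scalar score; a minimal hedged sketch (the class name and behaviour are mine, not the library's):

import torch

class MeanAbsoluteError:
    # toy evaluation metric: lower is better, so set eval_score_higher_is_better: false
    def __call__(self, input, target):
        # input/target: tensors of shape (N, C, D, H, W)
        return torch.mean(torch.abs(input - target))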
19.loaders = get_train_loaders(config)
This line calls quite a few functions. In short, it returns ready-to-use data with the patch slice indices already built. The walkthrough below may not be entirely accurate; it is only my rough understanding, and corrections are welcome.
from pytorch3dunet.datasets.utils import get_train_loaders
def get_train_loaders(config):
"""
Returns dictionary containing the training and validation loaders (torch.utils.data.DataLoader).
:param config: a top level configuration object containing the 'loaders' key
:return: dict {
'train': <train_loader>
'val': <val_loader>
}
"""
1 assert 'loaders' in config, 'Could not find data loaders configuration'
2 loaders_config = config['loaders']
3 logger.info('Creating training and validation set loaders...')
# get dataset class
4 dataset_cls_str = loaders_config.get('dataset', None) # StandardHDF5Dataset
5 if dataset_cls_str is None:
6 dataset_cls_str = 'StandardHDF5Dataset'
7 logger.warn(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
8 dataset_class = _get_cls(dataset_cls_str)
9 assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \
"Train and validation 'file_paths' overlap. One cannot use validation data for training!"
10 train_datasets = dataset_class.create_datasets(loaders_config, phase='train')
11 val_datasets = dataset_class.create_datasets(loaders_config, phase='val')
12 num_workers = loaders_config.get('num_workers', 1)
13 logger.info(f'Number of workers for train/val dataloader: {num_workers}')
14 batch_size = loaders_config.get('batch_size', 1)
15 if torch.cuda.device_count() > 1 and not config['device'].type == 'cpu':
16 logger.info(
f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
17 batch_size = batch_size * torch.cuda.device_count()
18 logger.info(f'Batch size for train/val loader: {batch_size}')
# when training with volumetric data use batch_size of 1 due to GPU memory constraints
19 return {
'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True,
num_workers=num_workers),
'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=True, num_workers=num_workers)
}
1.assert 'loaders' in config, 'Could not find data loaders configuration'
2.loaders_config = config['loaders']
3.logger.info('Creating training and validation set loaders...')
These read the loader-related parameters from the config.
Screenshot:
![](https://img.haomeiwen.com/i7152393/577f29dabc469fe3.png)
Lines 4 to 8 determine which dataset class will be used to load the HDF5 data files.
8.dataset_class = _get_cls(dataset_cls_str)
import importlib

def _get_cls(class_name):
    modules = ['pytorch3dunet.datasets.hdf5', 'pytorch3dunet.datasets.dsb', 'pytorch3dunet.datasets.utils']
    for module in modules:
        m = importlib.import_module(module)
        clazz = getattr(m, class_name, None)
        if clazz is not None:
            return clazz
    raise RuntimeError(f'Unsupported dataset class: {class_name}')
getattr(m, class_name, None) is equivalent to m.class_name (returning None if the attribute does not exist).
![](https://img.haomeiwen.com/i7152393/758863fadccc2db2.png)
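A tiny example of the importlib.import_module + getattr lookup (module and class chosen only for illustration):

import importlib

m = importlib.import_module('collections')
clazz = getattr(m, 'OrderedDict', None)    # same as collections.OrderedDict
print(clazz)                               # <class 'collections.OrderedDict'>
print(getattr(m, 'NoSuchClass', None))     # None -> _get_cls tries the next module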
9.assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']),
"Train and validation 'file_paths' overlap. One cannot use validation data for training!"
set.isdisjoint() checks whether two sets have no elements in common.
In other words, this asserts that the training and validation sets do not share any files.
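A quick check of isdisjoint:

print({'a.h5', 'b.h5'}.isdisjoint({'c.h5'}))    # True  -> no overlap, the assert passes
print({'a.h5', 'b.h5'}.isdisjoint({'b.h5'}))    # False -> overlap, the assert fails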
10.train_datasets = dataset_class.create_datasets(loaders_config, phase='train')
This calls create_datasets, defined on AbstractHDF5Dataset(ConfigDataset) in hdf5.py and inherited by StandardHDF5Dataset.
It reads the transformer, slice_builder and file_paths settings for the train phase.
file_paths may contain both files and directories; if an entry is a directory, all H5 files inside it are included in the final file_paths.
@classmethod
def create_datasets(cls, dataset_config, phase):
    phase_config = dataset_config[phase]
    transformer_config = phase_config['transformer']
    slice_builder_config = phase_config['slice_builder']
    file_paths = phase_config['file_paths']
    # file_paths may contain both files and directories; if an entry is a directory,
    # all H5 files inside it are added to the final file_paths
    file_paths = cls.traverse_h5_paths(file_paths)
    datasets = []
    for file_path in file_paths:
        try:
            logger.info(f'Loading {phase} set from: {file_path}...')
            dataset = cls(file_path=file_path,
                          phase=phase,
                          slice_builder_config=slice_builder_config,
                          transformer_config=transformer_config,
                          mirror_padding=dataset_config.get('mirror_padding', None),
                          raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
                          label_internal_path=dataset_config.get('label_internal_path', 'label'),
                          weight_internal_path=dataset_config.get('weight_internal_path', None))
            datasets.append(dataset)
        except Exception:
            logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
    return datasets
import glob
import os
from itertools import chain

def traverse_h5_paths(file_paths):
    assert isinstance(file_paths, list)  # make sure file_paths is a list
    results = []
    for file_path in file_paths:
        if os.path.isdir(file_path):
            # if file path is a directory take all H5 files in that directory
            iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']]
            for fp in chain(*iters):
                results.append(fp)
        else:
            results.append(file_path)
    return results
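For example, assuming a directory data/train/ containing a.h5 and b.h5 (paths invented for illustration), the traversal expands it like this:

# hypothetical layout:
#   data/train/a.h5
#   data/train/b.h5
file_paths = ['data/train', 'single_volume.h5']
print(traverse_h5_paths(file_paths))
# -> ['data/train/a.h5', 'data/train/b.h5', 'single_volume.h5']  (glob order may vary)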
The cls(file_path=...) call inside create_datasets is what invokes the dataset's __init__().
For the train and val phases no mirror_padding is applied; for the test phase mirror_padding is applied, which effectively adds a border around the whole volume.
I won't go through __init__() in detail: it applies the configured transforms to the data and records the position of every patch (see the sliding-window sketch below).
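To give a feel for what "recording the position of every patch" means, here is a minimal sliding-window sketch; this is my own illustration of the idea, not the library's actual SliceBuilder:

def build_slices(shape, patch_shape, stride_shape):
    # return a list of slice tuples covering a 3D volume of the given shape
    slices = []
    for z in range(0, shape[0] - patch_shape[0] + 1, stride_shape[0]):
        for y in range(0, shape[1] - patch_shape[1] + 1, stride_shape[1]):
            for x in range(0, shape[2] - patch_shape[2] + 1, stride_shape[2]):
                slices.append((slice(z, z + patch_shape[0]),
                               slice(y, y + patch_shape[1]),
                               slice(x, x + patch_shape[2])))
    return slices

# a 64^3 volume with 32^3 patches and stride 32 gives 2*2*2 = 8 patch positions
print(len(build_slices((64, 64, 64), (32, 32, 32), (32, 32, 32))))  # 8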
Line 19 returns the dict of DataLoaders we need (a usage sketch follows below).
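Using the returned dict is ordinary DataLoader iteration; a usage sketch (for the standard HDF5 dataset a training batch is typically a (raw, label) pair of tensors shaped (batch_size, C, D, H, W)):

loaders = get_train_loaders(config)   # config comes from load_config()
for raw, label in loaders['train']:
    print(raw.shape, label.shape)
    break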
20 optimizer = _create_optimizer(config, model)
This one is straightforward: it builds an ordinary optimizer.
22 lr_scheduler = _create_lr_scheduler(config, optimizer)
Likewise straightforward: an ordinary learning-rate scheduler (a sketch of both helpers follows below).
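A hedged sketch of what these two helpers boil down to (my own simplification; the key names follow the optimizer / lr_scheduler sections of the YAML config and may differ from the repo):

import torch.optim as optim
from torch.optim import lr_scheduler

def _create_optimizer_sketch(config, model):
    opt_config = config['optimizer']
    return optim.Adam(model.parameters(),
                      lr=opt_config['learning_rate'],
                      weight_decay=opt_config.get('weight_decay', 0))

def _create_lr_scheduler_sketch(config, optimizer):
    sched_config = dict(config.get('lr_scheduler', {'name': 'MultiStepLR', 'milestones': [30], 'gamma': 0.2}))
    clazz = getattr(lr_scheduler, sched_config.pop('name'))   # e.g. MultiStepLR, ReduceLROnPlateau
    return clazz(optimizer, **sched_config)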
23 trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders)
def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders):
    assert 'trainer' in config, 'Could not find trainer configuration'
    trainer_config = config['trainer']
    # checkpoint to resume from after an interrupted run
    resume = trainer_config.get('resume', None)
    # pre-trained model to fine-tune
    pre_trained = trainer_config.get('pre_trained', None)
    # whether to skip computing the evaluation metric on training batches
    skip_train_validation = trainer_config.get('skip_train_validation', False)
    # get tensorboard formatter
    # (I'm not entirely sure what this is for, but it includes min-max normalization)
    tensorboard_formatter = get_tensorboard_formatter(trainer_config.get('tensorboard_formatter', None))
    if resume is not None:
        # continue training from a given checkpoint
        return UNet3DTrainer.from_checkpoint(resume, model,
                                             optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, loaders, tensorboard_formatter=tensorboard_formatter)
    elif pre_trained is not None:
        # fine-tune a given pre-trained model
        return UNet3DTrainer.from_pretrained(pre_trained, model, optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, device=config['device'], loaders=loaders,
                                             max_num_epochs=trainer_config['epochs'],
                                             max_num_iterations=trainer_config['iters'],
                                             validate_after_iters=trainer_config['validate_after_iters'],
                                             log_after_iters=trainer_config['log_after_iters'],
                                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                                             tensorboard_formatter=tensorboard_formatter,
                                             skip_train_validation=skip_train_validation)
    else:
        # start training from scratch
        return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                             config['device'], loaders, trainer_config['checkpoint_dir'],
                             max_num_epochs=trainer_config['epochs'],
                             max_num_iterations=trainer_config['iters'],
                             validate_after_iters=trainer_config['validate_after_iters'],
                             log_after_iters=trainer_config['log_after_iters'],
                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                             tensorboard_formatter=tensorboard_formatter,
                             skip_train_validation=skip_train_validation)
25 trainer.fit()
def fit(self):
    for _ in range(self.num_epoch, self.max_num_epochs):
        # train for one epoch
        should_terminate = self.train(self.loaders['train'])
        if should_terminate:
            logger.info('Stopping criterion is satisfied. Finishing training')
            return
        self.num_epoch += 1
    logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")
train() itself is an ordinary training loop: it saves the model according to the evaluation metric on the validation set, and logs the learning_rate, loss_avg, eval_score_avg, prediction images and so on via TensorBoardX (a sketch of the best-checkpoint idea follows below).
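The gist of "save the model according to the validation metric" is roughly the following; this is a heavily simplified illustration, the real UNet3DTrainer also handles iteration limits, optimizer state, TensorBoard logging, and so on:

import torch

def save_if_best(model, eval_score, best_score, higher_is_better, path='best_checkpoint.pytorch'):
    # keep the checkpoint whose validation score is the best seen so far (the file name is just an example)
    is_best = eval_score > best_score if higher_is_better else eval_score < best_score
    if is_best:
        torch.save({'model_state_dict': model.state_dict(),
                    'best_eval_score': eval_score}, path)
        return eval_score
    return best_score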
Prediction
predict.py
predict.py is broadly similar to train.py: the predicted probabilities for each .nii volume are saved separately, but no evaluation is performed.
To be honest, I did not fully follow the post-processing of the predicted probabilities; I only understood that it removes some voxels to prevent artifacts and crops the mirror-padded prediction back to its original size (see the sketch after the code below).
def main():
    # Load configuration
    config = load_config()
    # Create the model
    model = get_model(config)
    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    # use DataParallel if more than 1 GPU available
    device = config['device']
    if torch.cuda.device_count() > 1 and not device.type == 'cpu':
        model = nn.DataParallel(model)
        logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction')
    logger.info(f"Sending the model to '{device}'")
    model = model.to(device)
    logger.info('Loading HDF5 datasets...')
    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")
        output_file = _get_output_file(test_loader.dataset)
        predictor = _get_predictor(model, test_loader, output_file, config)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        predictor.predict()
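My reading of the "restore the mirror-padded prediction" step is essentially a centre crop that removes the padded border again; a sketch of that idea (the pad width and volume are made up):

import numpy as np

def unpad_center(volume, pad):
    # remove a symmetric padding border of `pad` voxels from each side of a 3D volume
    z, y, x = pad
    return volume[z:volume.shape[0] - z, y:volume.shape[1] - y, x:volume.shape[2] - x]

padded = np.pad(np.random.rand(64, 64, 64), 16, mode='reflect')   # mirror padding -> (96, 96, 96)
print(unpad_center(padded, (16, 16, 16)).shape)                    # back to (64, 64, 64)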