nnUNet Code Walkthrough: Model Training

Official nnUNet repository

MIC-DKFZ/nnUNet

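Preparation

To install and configure the nnUNet package, see "nn-UNet使用记录–代码配置" (nn-UNet Usage Notes: Code Configuration) on 宁眸's CSDN blog.

Make sure the nnUNet package is installed, the environment variables are set, and nnunet/run/run_training.py runs successfully.

I read the source while stepping through it in a debugger; if you prefer not to debug, you can simply read the source.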

export nnUNet_raw_data_base="/data/.../nnUNet_raw_data_base"
export nnUNet_preprocessed="/data/.../nnUNet_preprocessed"
export RESULTS_FOLDER="/data/.../nnUNet_trained_models"
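
Before launching training, it can help to confirm that the three variables above are actually visible to Python. The following is a minimal sketch; the helper name `missing_nnunet_env_vars` is hypothetical and not part of nnUNet, and the variable names are simply the ones from the export lines above.

```python
import os

# Variable names copied from the export lines above.
REQUIRED_VARS = ("nnUNet_raw_data_base", "nnUNet_preprocessed", "RESULTS_FOLDER")


def missing_nnunet_env_vars(environ=None):
    """Return the required variable names that are unset or empty."""
    env = os.environ if environ is None else environ
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_nnunet_env_vars()
if missing:
    print("Not set:", ", ".join(missing))
else:
    print("All nnUNet path variables are set.")
```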

Below I use kidney tumor segmentation (KiTS19), a 3D image segmentation task, as the example to walk through what nnunet/run/run_training.py does.

Current directory:

/data/.../nnUNet

Command:
python nnunet/run/run_training.py 3d_fullres nnUNetTrainerV2 40 1

3d_fullres is the network configuration, nnUNetTrainerV2 is the trainer class, 40 is the task ID, and 1 is the cross-validation fold index.


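1. Add code

import sys
sys.path.append("/data/.../nnUNet/")    # add the absolute path of the nnUNet repository to the module search path (sys.path)

If you hit a "no module named 'nnunet'" error when using nnUNet, add the two lines above at the top of the script.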
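
If the script is run repeatedly in the same interpreter (for example in a notebook), the same directory can end up on `sys.path` several times. A small sketch that avoids this; `ensure_on_sys_path` is a hypothetical helper name, not something nnUNet provides.

```python
import os
import sys


def ensure_on_sys_path(directory):
    """Append `directory` to sys.path only if it is not there yet.

    Returns the absolute path that was checked.
    """
    abs_dir = os.path.abspath(directory)
    if abs_dir not in sys.path:
        sys.path.append(abs_dir)
    return abs_dir
```

With the path from above this would be called as `ensure_on_sys_path("/data/.../nnUNet/")`.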

# debug on a remote server
import ptvsd
ptvsd.enable_attach(address=('omnisky', 5678))
ptvsd.wait_for_attach()

ptvsd lets you attach a debugger and inspect variables while the program runs, which helps when reading the source.


A close reading of run_training.py

2. Import the required packages

import argparse
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.run.default_configuration import get_default_configuration
from nnunet.paths import default_plans_identifier
from nnunet.run.load_pretrained_weights import load_pretrained_weights
from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes
from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name

3. Set the training parameters

parser = argparse.ArgumentParser()
parser.add_argument("network")  # 3d_fullres
parser.add_argument("network_trainer")  # nnUNetTrainerV2
parser.add_argument("task", help="can be task name or task id")  # 40
parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')  # 1
parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
                    action="store_true")
parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
                    action="store_true")
parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
                    default=default_plans_identifier, required=False)
parser.add_argument("--use_compressed_data", default=False, action="store_true",
                    help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
                    "is much more CPU and RAM intensive and should only be used if you know what you are "
                    "doing", required=False)
parser.add_argument("--deterministic",
                    help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
                    "this is not necessary. Deterministic training will make you overfit to some random seed. "
                    "Don't use that.",
                    required=False, default=False, action="store_true")
parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then nnUNet will "
                    "export npz files of "
                    "predicted segmentations "
                    "in the validation as well. "
                    "This is needed to run the "
                    "ensembling step so unless "
                    "you are developing nnUNet "
                    "you should enable this")
parser.add_argument("--find_lr", required=False, default=False, action="store_true",
                    help="not used here, just for fun")
parser.add_argument("--valbest", required=False, default=False, action="store_true",
                    help="hands off. This is not intended to be used")
parser.add_argument("--fp32", required=False, default=False, action="store_true",
                    help="disable mixed precision training and run old school fp32")
parser.add_argument("--val_folder", required=False, default="validation_raw",
                    help="name of the validation folder. No need to use this for most people")
parser.add_argument("--disable_saving", required=False, action='store_true',
                    help="If set nnU-Net will not save any parameter files (except a temporary checkpoint that "
                    "will be removed at the end of the training). Useful for development when you are "
                    "only interested in the results and want to save some disk space")
parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true',
                    help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
                    "closely observing the model performance on specific configurations. You do not need it "
                    "when applying nnU-Net because the postprocessing for this will be determined only once "
                    "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
                    "running postprocessing on each fold is computationally cheap, but some users have "
                    "reported issues with very large images. If your images are large (>600x600x600 voxels) "
                    "you should consider setting this flag.")
# parser.add_argument("--interp_order", required=False, default=3, type=int,
#                     help="order of interpolation for segmentations. Testing purpose only. Hands off")
# parser.add_argument("--interp_order_z", required=False, default=0, type=int,
#                     help="order of interpolation along z if z is resampled separately. Testing purpose only. "
#                          "Hands off")
# parser.add_argument("--force_separate_z", required=False, default="None", type=str,
#                     help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")
parser.add_argument('--val_disable_overwrite', action='store_false', default=True,
                    help='Validation does not overwrite existing segmentations')
parser.add_argument('--disable_next_stage_pred', action='store_true', default=False,
                    help='do not predict next stage')
parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
                    help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
                    'file, for example model_final_checkpoint.model). Will only be used when actually training. '
                    'Optional. Beta. Use with caution.')

args = parser.parse_args()
  • positional arguments:
  • network
  • network_trainer
  • task can be task name or task id
  • fold 0, 1, ..., 5 or 'all'
  • optional arguments:
  • -h, --help show this help message and exit
  • -val, --validation_only use this if you want to only run the validation
  • -c, --continue_training use this if you want to continue a training
  • -p plans identifier. Only change this if you created a custom experiment planner
  • --use_compressed_data If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data is much more CPU and RAM intensive and should only be used if you know what you are doing
  • --deterministic Makes training deterministic, but reduces training speed substantially. I (Fabian) think this is not necessary. Deterministic training will make you overfit to some random seed. Don't use that.
  • --npz if set then nnUNet will export npz files of predicted segmentations in the validation as well. This is needed to run the ensembling step so unless you are developing nnUNet you should enable this
  • --find_lr not used here, just for fun
  • --valbest hands off. This is not intended to be used
  • --fp32 disable mixed precision training and run old school fp32
  • --val_folder name of the validation folder. No need to use this for most people
  • --disable_saving If set nnU-Net will not save any parameter files (except a temporary checkpoint that will be removed at the end of the training). Useful for development when you are only interested in the results and want to save some disk space
  • --disable_postprocessing_on_folds Running postprocessing on each fold only makes sense when developing with nnU-Net and closely observing the model performance on specific configurations. You do not need it when applying nnU-Net because the postprocessing for this will be determined only once all five folds have been trained and nnUNet_find_best_configuration is called. Usually running postprocessing on each fold is computationally cheap, but some users have reported issues with very large images. If your images are large (>600x600x600 voxels) you should consider setting this flag.
  • --val_disable_overwrite Validation does not overwrite existing segmentations
  • --disable_next_stage_pred do not predict next stage
  • -pretrained_weights path to nnU-Net checkpoint file to be used as pretrained model (use .model file, for example model_final_checkpoint.model). Will only be used when actually training. Optional. Beta. Use with caution.

The arguments above are important and the author has documented them in detail, so they are worth reading a few times.
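
To see how these arguments behave for the example command, here is a reduced sketch that copies only a handful of the `add_argument` calls from above (it is not the full parser). It shows that positional values arrive as strings and that a `store_false` flag such as `--val_disable_overwrite` flips its default of True to False when passed.

```python
import argparse

# A reduced parser with a few of the arguments shown above.
parser = argparse.ArgumentParser()
parser.add_argument("network")
parser.add_argument("network_trainer")
parser.add_argument("task")
parser.add_argument("fold")
parser.add_argument("-c", "--continue_training", action="store_true")
parser.add_argument("--fp32", default=False, action="store_true")
parser.add_argument("--val_disable_overwrite", action="store_false", default=True)

# Same arguments as: python nnunet/run/run_training.py 3d_fullres nnUNetTrainerV2 40 1
args = parser.parse_args(["3d_fullres", "nnUNetTrainerV2", "40", "1"])
print(repr(args.task), repr(args.fold))  # positionals are strings, not ints
print(args.fp32)                         # stays False unless --fp32 is passed
print(args.val_disable_overwrite)        # stays True unless the flag is passed

# Passing the store_false flag switches the value to False.
flagged = parser.parse_args(
    ["3d_fullres", "nnUNetTrainerV2", "40", "1", "--val_disable_overwrite"]
)
print(flagged.val_disable_overwrite)
```

Because `task` and `fold` are parsed as strings, the script has to convert them itself, which is what the next listing does.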

    task = args.task  # 40
    fold = args.fold  # 1
    network = args.network  # 3d_fullres
    network_trainer = args.network_trainer  # nnUNetTrainerV2
    validation_only = args.validation_only  # False
    plans_identifier = args.p  # 'nnUNetPlansv2.1'
    find_lr = args.find_lr  # False
    disable_postprocessing_on_folds = args.disable_postprocessing_on_folds  # False
    
    use_compressed_data = args.use_compressed_data  # False
    decompress_data = not use_compressed_data  # True
    
    deterministic = args.deterministic  # False
    valbest = args.valbest  # False
    
    fp32 = args.fp32  # False, so fp16 mixed precision is used
    run_mixed_precision = not fp32  # True: mixed-precision training
    
    val_folder = args.val_folder  # 'validation_raw', output folder for validation results
    
    if not task.startswith("Task"):
        task_id = int(task)
        task = convert_id_to_task_name(task_id)  # 'Task040_KiTS'
    
    if fold == 'all':
        pass
    else:
        fold = int(fold)
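
The two conversions above can be sketched as small standalone functions. Both names are hypothetical helpers. `task_prefix` only builds the zero-padded `TaskXXX` prefix; it is my assumption that `convert_id_to_task_name` then resolves the full name (such as `Task040_KiTS`) by searching the nnUNet data folders for that prefix, which this sketch does not attempt.

```python
def normalize_fold(fold):
    """Keep the string 'all' as is, otherwise convert the fold to an int."""
    if fold == "all":
        return fold
    return int(fold)


def task_prefix(task_id):
    """Build the zero-padded prefix of a task folder name (assumed convention)."""
    return "Task%03d" % int(task_id)


print(normalize_fold("1"), normalize_fold("all"), task_prefix("40"))
```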
    

    4. Get the default configuration

    get_default_configuration

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)
    

    Inputs:

  • network: 3d_fullres
  • task: Task040_KiTS
  • network_trainer: nnUNetTrainerV2
  • plans_identifier: nnUNetPlansv2.1
  • The file corresponding to nnUNetPlansv2.1 is
    nnUNet_preprocessed/Task040_KiTS/nnUNetPlansv2.1_plans_3D.pkl

    def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier,
                                  search_in=(nnunet.__path__[0], "training", "network_training"),
                                  base_module='nnunet.training.network_training'):
        assert network in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'], \
            "network can only be one of the following: \'3d_lowres\', \'3d_fullres\', \'3d_cascade_fullres\'"
    
        dataset_directory = join(preprocessing_output_dir, task)  # directory of the preprocessed dataset
    
        if network == '2d':
            plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_2D.pkl")
        else:
            plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_3D.pkl")
    
        plans = load_pickle(plans_file)  # load the training plans
        possible_stages = list(plans['plans_per_stage'].keys())
    
        if (network == '3d_cascade_fullres' or network == "3d_lowres") and len(possible_stages) == 1:
            raise RuntimeError("3d_lowres/3d_cascade_fullres only applies if there is more than one stage. This task does "
                               "not require the cascade. Run 3d_fullres instead")
    
        if network == '2d' or network == "3d_lowres":
            stage = 0
        else:
            stage = possible_stages[-1]
    
        trainer_class = recursive_find_python_class([join(*search_in)], network_trainer,
                                                    current_module=base_module)  # <class 'nnunet.training.network_training.nnUNetTrainerV2.nnUNetTrainerV2'>
    
        output_folder_name = join(network_training_output_dir, network, task, network_trainer + "__" + plans_identifier)
    
        print("###############################################")
        print("I am running the following nnUNet: %s" % network)
        print("My trainer class is: ", trainer_class)
        print("For that I will be using the following configuration:")
        summarize_plans(plans_file)
        print("I am using stage %d from these plans" % stage)
    
        if (network == '2d' or len(possible_stages) > 1) and not network == '3d_lowres':
            batch_dice = True
            print("I am using batch dice + CE loss")
        else:
            batch_dice = False
            print("I am using sample dice + CE loss")
    
        print("\nI am using data from this folder: ", join(dataset_directory, plans['data_identifier']))
        print("###############################################")
        return plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class
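
The decisions made in this function (which plans file, which stage, batch dice or sample dice) can be isolated in a small pure function. This is a sketch that restates the branches above without loading any pickle; `summarize_configuration` is a hypothetical name, and `posixpath` is used only so the printed path looks the same on every platform.

```python
import posixpath


def summarize_configuration(network, possible_stages, preprocessed_dir, task,
                            plans_identifier="nnUNetPlansv2.1"):
    """Restate the plans-file, stage and batch-dice choices from the listing."""
    valid = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres")
    if network not in valid:
        raise ValueError("unknown network: %s" % network)

    suffix = "_plans_2D.pkl" if network == "2d" else "_plans_3D.pkl"
    plans_file = posixpath.join(preprocessed_dir, task, plans_identifier + suffix)

    if network in ("3d_lowres", "3d_cascade_fullres") and len(possible_stages) == 1:
        raise RuntimeError("this task has a single stage; run 3d_fullres instead")

    # 2d and 3d_lowres use the first stage, the other networks use the last one.
    stage = 0 if network in ("2d", "3d_lowres") else possible_stages[-1]

    # Batch dice for 2d and for multi-stage plans, but never for 3d_lowres.
    batch_dice = (network == "2d" or len(possible_stages) > 1) and network != "3d_lowres"
    return plans_file, stage, batch_dice


# Single-stage plans with 3d_fullres, as in the KiTS example of this post.
print(summarize_configuration("3d_fullres", [0], "nnUNet_preprocessed", "Task040_KiTS"))
```

For the example in this post (one stage, 3d_fullres) the function selects stage 0 and sample dice, which matches the terminal output shown later.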
    

    The key step in the configuration function:

    trainer_class = recursive_find_python_class([join(*search_in)], network_trainer, current_module=base_module)
    

    The trainer class it resolves to:
    <class 'nnunet.training.network_training.nnUNetTrainerV2.nnUNetTrainerV2'>

    # folder: '/data/omnisky/postgraduate/Yb/nnUNet/nnunet\training\network_training'
    # trainer_name: 'nnUNetTrainerV2'
    # current_module: 'nnunet.training.network_training'
    def recursive_find_python_class(folder, trainer_name, current_module):
        tr = None
        for importer, modname, ispkg in pkgutil.iter_modules(folder):
            # print(modname, ispkg)
            if not ispkg:
                m = importlib.import_module(current_module + "." + modname)
                if hasattr(m, trainer_name):
                    tr = getattr(m, trainer_name)
                    break
    
        if tr is None:
            for importer, modname, ispkg in pkgutil.iter_modules(folder):
                if ispkg:
                    next_current_module = current_module + "." + modname
                    tr = recursive_find_python_class([join(folder[0], modname)], trainer_name, current_module=next_current_module)
                if tr is not None:
                    break
    
        return tr
    # ----------------------------------------------------------------------------------------------------
    # Return value: <class 'nnunet.training.network_training.nnUNetTrainerV2.nnUNetTrainerV2'> (the class object, not a string)
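
To try the lookup without an nnUNet installation, the same idea can be run against a throwaway package. This is a sketch under stated assumptions: the package `demo_trainers`, its sub-package `variants`, and the class `MyTrainer` are invented for the demonstration, and `find_class_in_package` is a simplified restatement of the function above rather than the nnUNet code itself.

```python
import importlib
import os
import pkgutil
import sys
import tempfile


def find_class_in_package(folder, class_name, current_module):
    """Search the modules in `folder`, then its sub-packages, for `class_name`."""
    for _, modname, ispkg in pkgutil.iter_modules([folder]):
        if not ispkg:
            module = importlib.import_module(current_module + "." + modname)
            if hasattr(module, class_name):
                return getattr(module, class_name)
    for _, modname, ispkg in pkgutil.iter_modules([folder]):
        if ispkg:
            found = find_class_in_package(os.path.join(folder, modname), class_name,
                                          current_module + "." + modname)
            if found is not None:
                return found
    return None


# Build a temporary package: demo_trainers/variants/custom.py defines MyTrainer.
root = tempfile.mkdtemp()
pkg_dir = os.path.join(root, "demo_trainers")
sub_dir = os.path.join(pkg_dir, "variants")
os.makedirs(sub_dir)
for directory in (pkg_dir, sub_dir):
    open(os.path.join(directory, "__init__.py"), "w").close()
with open(os.path.join(sub_dir, "custom.py"), "w") as handle:
    handle.write("class MyTrainer:\n    pass\n")

sys.path.insert(0, root)
importlib.invalidate_caches()
trainer_class = find_class_in_package(pkg_dir, "MyTrainer", "demo_trainers")
print(trainer_class)  # a class object that can be instantiated, not a string
```

The point to notice is that the class is found by name anywhere below the starting folder, including sub-packages, and that the result is the class itself.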
    

    After get_default_configuration finishes, the terminal shows the following output:

    ###############################################
    I am running the following nnUNet: 3d_fullres
    My trainer class is:  <class 'nnunet.training.network_training.nnUNetTrainerV2.nnUNetTrainerV2'>
    For that I will be using the following configuration:
    num_classes:  2
    modalities:  {0: 'CT'}
    use_mask_for_norm OrderedDict([(0, False)])
    keep_only_largest_region None
    min_region_size_per_class None
    min_size_per_class None
    normalization_schemes OrderedDict([(0, 'CT')])
    stages...
    
    stage:  0
    {'batch_size': 3, 'num_pool_per_axis': [4, 5, 5], 'patch_size': array([ 80, 160, 160]), 'median_patient_size_in_voxels': array([128, 247, 247]), 'current_spacing': array([3.22000003, 1.62      , 1.62      ]), 'original_spacing': array([3.22000003, 1.62      , 1.62      ]), 'do_dummy_2D_data_aug': False, 'pool_op_kernel_sizes': [[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], 'conv_kernel_sizes': [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]}
    
    I am using stage 0 from these plans
    I am using sample dice + CE loss
    
    I am using data from this folder:  /data/.../nnUNet_preprocessed/Task040_KiTS/nnUNetData_plans_v2.1
    

    5. Create the trainer instance

    Back in run_training.py:

    trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, dataset_directory=dataset_directory,
             batch_dice=batch_dice, stage=stage, unpack_data=decompress_data,deterministic=deterministic,fp16=run_mixed_precision)
    

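    A note on class inheritance: trainer_class here is the nnUNetTrainerV2 class (a class, not a function). nnUNetTrainerV2, defined in nnUNetTrainerV2.py, inherits from nnUNetTrainer in nnUNetTrainer.py, which in turn inherits from NetworkTrainer in network_trainer.py.

    Inheritance chain behind trainer_class (base class first):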

    NetworkTrainer > nnUNetTrainer > nnUNetTrainerV2

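    Therefore, to change training hyperparameters (number of epochs, optimizer, learning rate, and so on), change them in the most derived class, nnUNetTrainerV2, because values set there override those inherited from the parent classes.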
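
The override behaviour can be illustrated with simplified stand-ins. The three classes below only mimic the attribute assignments of the real trainers (their constructors take no arguments here, unlike the real ones), and `nnUNetTrainerV2_ShortRun` is a hypothetical subclass showing an alternative to editing nnUNetTrainerV2 in place.

```python
class NetworkTrainer:
    """Stand-in for the base class: sets the generic defaults."""
    def __init__(self):
        self.max_num_epochs = 1000


class nnUNetTrainer(NetworkTrainer):
    """Stand-in for the middle class: adds its own learning rate."""
    def __init__(self):
        super().__init__()
        self.initial_lr = 3e-4


class nnUNetTrainerV2(nnUNetTrainer):
    """Stand-in for the most derived class: its assignments run last and win."""
    def __init__(self):
        super().__init__()
        self.max_num_epochs = 1000
        self.initial_lr = 1e-2


class nnUNetTrainerV2_ShortRun(nnUNetTrainerV2):
    """Hypothetical subclass: change one value, inherit everything else."""
    def __init__(self):
        super().__init__()
        self.max_num_epochs = 100


trainer = nnUNetTrainerV2_ShortRun()
print(trainer.max_num_epochs, trainer.initial_lr)
```

Since the trainer is looked up by class name, a real subclass placed in a module under nnunet/training/network_training should be selectable by passing its name as the network_trainer argument; I have not verified this beyond reading the lookup code above.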

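    The three listings below are the initialization parts of NetworkTrainer, nnUNetTrainer and nnUNetTrainerV2; it helps to read them alongside the debug.json file that nnUNet writes with the training results.

    Pay particular attention to how initialize, initialize_network and initialize_optimizer_and_scheduler change across the three classes.

    NetworkTrainer: a generic neural-network training class that works for almost any model except RNNs.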

    class NetworkTrainer(object):
        def __init__(self, deterministic=True, fp16=False):
            """
            A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
            as the training loop, tracking of training and validation losses (and the target metric if you implement it)
            Training can be terminated early if the validation loss (or the target metric if implemented) do not improve
            anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get more smooth
            results.
    
            What you need to override:
            - __init__
            - initialize
            - run_online_evaluation (optional)
            - finish_online_evaluation (optional)
            - validate
            - predict_test_case
            """
            self.fp16 = fp16
            self.amp_grad_scaler = None
    
            if deterministic:
                np.random.seed(12345)
                torch.manual_seed(12345)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(12345)
                cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False
            else:
                cudnn.deterministic = False
                torch.backends.cudnn.benchmark = True
    
            ################# SET THESE IN self.initialize() ###################################
            self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
            self.optimizer = None
            self.lr_scheduler = None
            self.tr_gen = self.val_gen = None
            self.was_initialized = False
    
            ################# SET THESE IN INIT ################################################
            self.output_folder = None
            self.fold = None
            self.loss = None
            self.dataset_directory = None
    
            ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
            self.dataset = None  # these can be None for inference mode
            self.dataset_tr = self.dataset_val = None  # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split
    
            ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
            self.patience = 50
            self.val_eval_criterion_alpha = 0.9  # alpha * old + (1-alpha) * new
            # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
            # too high the training will take forever
            self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new
            self.train_loss_MA_eps = 5e-4  # new MA must be at least this much better (smaller)
            self.max_num_epochs = 1000
            self.num_batches_per_epoch = 250
            self.num_val_batches_per_epoch = 50
            self.also_val_in_tr_mode = False
            self.lr_threshold = 1e-6  # the network will not terminate training if the lr is still above this threshold
    
            ################# LEAVE THESE ALONE ################################################
            self.val_eval_criterion_MA = None
            self.train_loss_MA = None
            self.best_val_eval_criterion_MA = None
            self.best_MA_tr_loss_for_patience = None
            self.best_epoch_based_on_MA_tr_loss = None
            self.all_tr_losses = []
            self.all_val_losses = []
            self.all_val_losses_tr_mode = []
            self.all_val_eval_metrics = []  # does not have to be used
            self.epoch = 0
            self.log_file = None
            self.deterministic = deterministic
    
            self.use_progress_bar = False
            if 'nnunet_use_progress_bar' in os.environ.keys():
                self.use_progress_bar = bool(int(os.environ['nnunet_use_progress_bar']))
    
            ################# Settings for saving checkpoints ##################################
            self.save_every = 50
            self.save_latest_only = True  # if false it will not store/overwrite _latest but separate files each
            # time an intermediate checkpoint is created
            self.save_intermediate_checkpoints = True  # whether or not to save checkpoint_latest
            self.save_best_checkpoint = True  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
            self.save_final_checkpoint = True  # whether or not to save the final checkpoint
    
        @abstractmethod
        def initialize(self, training=True):
            """
            create self.output_folder
    
            modify self.output_folder if you are doing cross-validation (one folder per fold)
    
            set self.tr_gen and self.val_gen
    
            call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
    
            finally set self.was_initialized to True
            :param training:
            :return:
            """
        @abstractmethod
        def initialize_network(self):
            """
            initialize self.network here
            :return:
            """
            pass
        @abstractmethod
        def initialize_optimizer_and_scheduler(self):
            """
            initialize self.optimizer and self.lr_scheduler (if applicable) here
            :return:
            """
            pass
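
The comments next to `train_loss_MA_alpha` and `val_eval_criterion_alpha` describe an exponential moving average, `alpha * old + (1 - alpha) * new`. A minimal sketch of that update follows; seeding the average with the first value is my assumption, and `update_moving_average` is a hypothetical name.

```python
def update_moving_average(previous, new_value, alpha=0.93):
    """Exponential moving average: alpha * old + (1 - alpha) * new."""
    if previous is None:  # assumption: the first value seeds the average
        return new_value
    return alpha * previous + (1 - alpha) * new_value


losses = [1.0, 0.8, 0.9, 0.5]
smoothed = None
history = []
for loss in losses:
    smoothed = update_moving_average(smoothed, loss)
    history.append(smoothed)
print(history)
```

With alpha at 0.93 the smoothed value reacts slowly: the sharp drop to 0.5 in the last step moves the average only a little, which is why patience-based stopping on the smoothed curve is less sensitive to single noisy epochs.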
    

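    nnUNetTrainer: a general U-Net training class. It mainly builds the training setup from the plans .pkl file and also defines a number of default parameters.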

    class nnUNetTrainer(NetworkTrainer):
        def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                     unpack_data=True, deterministic=True, fp16=False):
            """
            :param deterministic:
            :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
            None if you wish to load some checkpoint and do inference only
            :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
            :param subfolder_with_preprocessed_data: must be a subfolder of dataset_directory (just the name of the folder,
            not the entire path). This is where the preprocessed data lies that will be used for network training. We made
            this explicitly available so that differently preprocessed data can coexist and the user can choose what to use.
            Can be None if you are doing inference only.
            :param output_folder: where to store parameters, plot progress and to the validation
            :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
            because the split information is stored in this directory. For running prediction only this input is not
            required and may be set to None
            :param batch_dice: compute dice loss for each sample and average over all samples in the batch or pretend the
            batch is a pseudo volume?
            :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
            specified for training:
            if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
            :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
            is considerably slower! Running unpack_data=False with 2d should never be done!
    
            IMPORTANT: If you inherit from nnUNetTrainer and the init args change then you need to redefine self.init_args
            in your init accordingly. Otherwise checkpoints won't load properly!
            """
            super(nnUNetTrainer, self).__init__(deterministic, fp16)
            self.unpack_data = unpack_data
            self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                              deterministic, fp16)
            # set through arguments from init
            self.stage = stage
            self.experiment_name = self.__class__.__name__
            self.plans_file = plans_file
            self.output_folder = output_folder
            self.dataset_directory = dataset_directory
            self.output_folder_base = self.output_folder
            self.fold = fold
    
            self.plans = None
    
            # if we are running inference only then the self.dataset_directory is set (due to checkpoint loading) but it
            # irrelevant
            if self.dataset_directory is not None and isdir(self.dataset_directory):
                self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
            else:
                self.gt_niftis_folder = None
    
            self.folder_with_preprocessed_data = None
    
            # set in self.initialize()
    
            self.dl_tr = self.dl_val = None
            self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
                self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
                self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None  # loaded automatically from plans_file
            self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None
    
            self.batch_dice = batch_dice
            self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})
    
            self.online_eval_foreground_dc = []
            self.online_eval_tp = []
            self.online_eval_fp = []
            self.online_eval_fn = []
    
            self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
                self.min_region_size_per_class = self.min_size_per_class = None
    
            self.inference_pad_border_mode = "constant"
            self.inference_pad_kwargs = {'constant_values': 0}
    
            self.update_fold(fold)
            self.pad_all_sides = None
    
            self.lr_scheduler_eps = 1e-3
            self.lr_scheduler_patience = 30
            self.initial_lr = 3e-4
            self.weight_decay = 3e-5
    
            self.oversample_foreground_percent = 0.33
    
            self.conv_per_stage = None
            self.regions_class_order = None
         
        def initialize(self, training=True, force_load_plans=False):
            """
            For prediction of test cases just set training=False, this will prevent loading of training data and
            training batchgenerator initialization
            :param training:
            :return:
            """
    
            maybe_mkdir_p(self.output_folder)
    
            if force_load_plans or (self.plans is None):
                self.load_plans_file()
    
            self.process_plans(self.plans)
    
            self.setup_DA_params()
    
            if training:
                self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                          "_stage%d" % self.stage)
    
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    self.print_to_log_file("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    self.print_to_log_file("done")
                else:
                    self.print_to_log_file(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")
                self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val, self.data_aug_params[
                                                                         'patch_size_for_spatialtransform'],
                                                                     self.data_aug_params)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
            self.was_initialized = True
    
        def initialize_network(self):
            """
            This is specific to the U-Net and must be adapted for other network architectures
            :return:
            """
            # self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
            # self.print_to_log_file(self.net_conv_kernel_sizes)
    
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
    
            if self.threeD:
                conv_op = nn.Conv3d
                dropout_op = nn.Dropout3d
                norm_op = nn.InstanceNorm3d
            else:
                conv_op = nn.Conv2d
                dropout_op = nn.Dropout2d
                norm_op = nn.InstanceNorm2d
    
            norm_op_kwargs = {'eps': 1e-5, 'affine': True}
            dropout_op_kwargs = {'p': 0, 'inplace': True}
            net_nonlin = nn.LeakyReLU
            net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
            self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool,
                                        self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
                                        net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
                                        self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
            self.network.inference_apply_nonlin = softmax_helper
    
            if torch.cuda.is_available():
                self.network.cuda()
    
        def initialize_optimizer_and_scheduler(self):
            assert self.network is not None, "self.initialize_network must be called first"
            self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, amsgrad=True)
            self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2, patience=self.lr_scheduler_patience,
                                                               verbose=True, threshold=self.lr_scheduler_eps, threshold_mode="abs")
    

    nnUNetTrainerV2: builds on nnUNetTrainer with two main improvements: more diverse data augmentation and a deep-supervision loss.
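
The deep-supervision weighting can be looked at in isolation. The function below is a pure-Python restatement of the weight computation that appears in nnUNetTrainerV2.initialize (the listing uses NumPy); `deep_supervision_weights` is a hypothetical name. Note that, as written in the listing, the mask zeroes only the last entry.

```python
def deep_supervision_weights(net_numpool):
    """Halve the weight per output, drop the last one, normalize to sum 1."""
    weights = [1 / (2 ** i) for i in range(net_numpool)]
    # Same mask as in the listing: every entry is kept except the last one.
    mask = [True] + [i < net_numpool - 1 for i in range(1, net_numpool)]
    weights = [w if keep else 0.0 for w, keep in zip(weights, mask)]
    total = sum(weights)
    return [w / total for w in weights]


# The KiTS plans in this post list 5 pooling operations, so net_numpool = 5.
print([round(w, 4) for w in deep_supervision_weights(5)])
# → [0.5333, 0.2667, 0.1333, 0.0667, 0.0]
```

The highest-resolution output therefore contributes a little over half of the total loss, and each coarser output contributes half as much as the one before it.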

    class nnUNetTrainerV2(nnUNetTrainer):
        """
        Info for Fabian: same as internal nnUNetTrainerV2_2
        """
    
        def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                     unpack_data=True, deterministic=True, fp16=False):
            super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                             deterministic, fp16)
            self.max_num_epochs = 1000
            self.initial_lr = 1e-2
            self.deep_supervision_scales = None
            self.ds_loss_weights = None
    
            self.pin_memory = True
    
        def initialize(self, training=True, force_load_plans=False):
            """
            - replaced get_default_augmentation with get_moreDA_augmentation
            - enforce to only run this code once
            - loss function wrapper for deep supervision
    
            :param training:
            :param force_load_plans:
            :return:
            """
            if not self.was_initialized:
                maybe_mkdir_p(self.output_folder)
    
                if force_load_plans or (self.plans is None):
                    self.load_plans_file()
    
                self.process_plans(self.plans)
    
                self.setup_DA_params()
    
                ################# Here we wrap the loss for deep supervision ############
                # we need to know the number of outputs of the network
                net_numpool = len(self.net_num_pool_op_kernel_sizes)
    
                # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
                # this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
    
                # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1
                mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights
                # now wrap the loss
                self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
                ################# END ###################
    
                self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
                                                          "_stage%d" % self.stage)
                if training:
                    self.dl_tr, self.dl_val = self.get_basic_generators()
                    if self.unpack_data:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    else:
                        print(
                            "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                            "will wait all winter for your model to finish!")
    
                    self.tr_gen, self.val_gen = get_moreDA_augmentation(
                        self.dl_tr, self.dl_val,
                        self.data_aug_params[
                            'patch_size_for_spatialtransform'],
                        self.data_aug_params,
                        deep_supervision_scales=self.deep_supervision_scales,
                        pin_memory=self.pin_memory,
                        use_nondetMultiThreadedAugmenter=False
                    )
                    self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                           also_print_to_console=False)
                    self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                           also_print_to_console=False)
                else:
                    pass
    
                self.initialize_network()
                self.initialize_optimizer_and_scheduler()
    
                assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
            else:
                self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
            self.was_initialized = True
    
        def initialize_network(self):
            """
            - momentum 0.99
            - SGD instead of Adam
            - self.lr_scheduler = None because we do poly_lr
            - deep supervision = True
            - i am sure I forgot something here
    
            Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
            :return:
            """
            if self.threeD:
                conv_op = nn.Conv3d
                dropout_op = nn.Dropout3d
                norm_op = nn.InstanceNorm3d
    
            else:
                conv_op = nn.Conv2d
                dropout_op = nn.Dropout2d
                norm_op = nn.InstanceNorm2d
    
            norm_op_kwargs = {'eps': 1e-5, 'affine': True}
            dropout_op_kwargs = {'p': 0, 'inplace': True}
            net_nonlin = nn.LeakyReLU
            net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
            self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
                                        len(self.net_num_pool_op_kernel_sizes),
                                        self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                        dropout_op_kwargs,
                                        net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
                                        self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
            if torch.cuda.is_available():
                self.network.cuda()
            self.network.inference_apply_nonlin = softmax_helper
    
        def initialize_optimizer_and_scheduler(self):
            assert self.network is not None, "self.initialize_network must be called first"
            self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                             momentum=0.99, nesterov=True)
            self.lr_scheduler = None
    

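    The three methods initialize, initialize_network and initialize_optimizer_and_scheduler are overridden layer by layer. If you train with nnUNetTrainerV2, you only need to read the initialization code in the nnUNetTrainerV2 class (although some default attributes are defined in nnUNetTrainer and NetworkTrainer and never appear in nnUNetTrainerV2).

    The comments in the source are very detailed. For now I only want to get the network architecture straight; the finer details can be studied later.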
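The deep-supervision weighting set up in nnUNetTrainerV2.initialize can be reproduced by hand. The sketch below is plain Python rather than the NumPy code used in the trainer, and assumes five pooling stages. `weighted_sum` is a hypothetical stand-in for how a wrapper such as MultipleOutputLoss2 presumably combines the per-resolution losses; its body is not shown in this post, so treat that part as an assumption.

```python
def deep_supervision_weights(net_numpool):
    """Mirror the weight computation in nnUNetTrainerV2.initialize()."""
    # Each lower-resolution output gets half the weight of the previous one.
    weights = [1 / (2 ** i) for i in range(net_numpool)]
    # Same mask as in the trainer: every entry except the last one is kept.
    mask = [True] + [i < net_numpool - 1 for i in range(1, net_numpool)]
    weights = [w if keep else 0.0 for w, keep in zip(weights, mask)]
    # Normalize so that the remaining weights sum to 1.
    total = sum(weights)
    return [w / total for w in weights]


def weighted_sum(per_output_losses, weights):
    """Hypothetical helper: weighted sum of one loss value per network output."""
    return sum(w * l for w, l in zip(weights, per_output_losses))


ds_loss_weights = deep_supervision_weights(5)
print([round(w, 4) for w in ds_loss_weights])
# [0.5333, 0.2667, 0.1333, 0.0667, 0.0]

# Made-up loss values, ordered from the full-resolution output to the coarsest.
example_losses = [0.40, 0.50, 0.60, 0.70, 0.80]
total_loss = weighted_sum(example_losses, ds_loss_weights)
```

With five pooling stages only the coarsest output ends up with weight 0, and the full-resolution output carries a little over half of the total loss.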


    6. Network model

    Note that the U-Net model is defined in initialize_network of the nnUNetTrainerV2 class:

    self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
                                        len(self.net_num_pool_op_kernel_sizes),
                                        self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
                                        dropout_op_kwargs,
                                        net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
                                        self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
    

    Taking one of my models as an example:

  • num_input_channels=1
  • base_num_features=32
  • num_classes=3
  • net_num_pool_op_kernel_sizes=[[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]]
  • conv_per_stage=2
  • feat_map_mul_on_downscale=2
  • conv_op=nn.Conv3d
  • norm_op=nn.InstanceNorm3d
  • norm_op_kwargs={'eps': 1e-5, 'affine': True}
  • dropout_op=nn.Dropout3d
  • dropout_op_kwargs={'p': 0, 'inplace': True}
  • nonlin=nn.LeakyReLU
  • nonlin_kwargs={'negative_slope': 1e-2, 'inplace': True}
  • deep_supervision=True
  • dropout_in_localization=False
  • final_nonlin=lambda x: x
  • weightInitializer=InitWeights_He(1e-2)
  • net_conv_kernel_sizes=[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
  • upscale_logits=False
  • convolutional_pooling=True
  • convolutional_upsampling=True

    Building the network is also layered: the Generic_UNet class in generic_UNet.py inherits from the SegmentationNetwork class in neural_network.py, which in turn inherits from the NeuralNetwork class in the same file.
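To make these numbers concrete, the following sketch recomputes three quantities that Generic_UNet derives from this configuration: the number of feature maps per encoder stage, the per-axis factor the patch size must be divisible by, and the padding chosen for each convolution kernel. It is plain Python that mirrors the arithmetic in the Generic_UNet constructor, not the library code itself; the cap of 320 is MAX_NUM_FILTERS_3D, and the function name is made up for illustration.

```python
from math import prod

base_num_features = 32
feat_map_mul_on_downscale = 2
max_num_features = 320  # Generic_UNet.MAX_NUM_FILTERS_3D
pool_op_kernel_sizes = [[2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]]
conv_kernel_sizes = [[3, 3, 3]] * 6


def encoder_feature_maps(base, num_pool, mul, cap):
    """Output channels of each encoder stage, followed by the bottleneck."""
    features, out = [], base
    for _ in range(num_pool + 1):  # num_pool stages plus the bottleneck
        features.append(out)
        out = min(int(round(out * mul)), cap)
    return features


features = encoder_feature_maps(base_num_features, len(pool_op_kernel_sizes),
                                feat_map_mul_on_downscale, max_num_features)
print(features)  # [32, 64, 128, 256, 320, 320]

# Per-axis product of the pooling kernels (np.prod(..., 0) in the original code).
divisible_by = [prod(axis) for axis in zip(*pool_op_kernel_sizes)]
print(divisible_by)  # [16, 32, 32]

# A kernel size of 3 gets padding 1, anything else gets padding 0.
conv_pad_sizes = [[1 if k == 3 else 0 for k in kernel] for kernel in conv_kernel_sizes]
print(conv_pad_sizes[0])  # [1, 1, 1]
```

So for this model the patch size has to be a multiple of 16 along the first axis and of 32 along the other two, and the channel count stops growing once it reaches 320.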

    Inheritance chain of the network:
    NeuralNetwork > SegmentationNetwork > Generic_UNet

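    If you want to train your own model, it looks like you can simply swap out the network, i.e. set self.network = MyModel(...) in an overridden initialize_network, and then train it with nnUnet. Two constraints are visible in the trainer code above: initialize asserts that self.network is a SegmentationNetwork (or nn.DataParallel), and because the loss is wrapped for deep supervision, the model has to return a tuple of outputs at several resolutions, as Generic_UNet.forward does.

    Parts of the NeuralNetwork, SegmentationNetwork and Generic_UNet classes are reproduced below.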

    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
    
        def get_device(self):
            if next(self.parameters()).device.type == "cpu":
                return "cpu"
            else:
                return next(self.parameters()).device.index
    
        def set_device(self, device):
            if device == "cpu":
                self.cpu()
            else:
                self.cuda(device)
    
        def forward(self, x):
            raise NotImplementedError
    
    
    class SegmentationNetwork(NeuralNetwork):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
    
            # if we have 5 pooling then our patch size must be divisible by 2**5
            self.input_shape_must_be_divisible_by = None  # for example in a 2d network that does 5 pool in x and 6 pool
            # in y this would be (32, 64)
    
            # we need to know this because we need to know if we are a 2d or a 3d netowrk
            self.conv_op = None  # nn.Conv2d or nn.Conv3d
    
            # this tells us how many channels we have in the output. Important for preallocation in inference
            self.num_classes = None  # number of channels in the output
    
            # depending on the loss, we do not hard code a nonlinearity into the architecture. To aggregate predictions
            # during inference, we need to apply the nonlinearity, however. So it is important to let the newtork know what
            # to apply in inference. For the most part this will be softmax
            self.inference_apply_nonlin = lambda x: x  # softmax_helper
    
            # This is for saving a gaussian importance map for inference. It weights voxels higher that are closer to the
            # center. Prediction at the borders are often less accurate and are thus downweighted. Creating these Gaussians
            # can be expensive, so it makes sense to save and reuse them.
            self._gaussian_3d = self._patch_size_for_gaussian_3d = None
            self._gaussian_2d = self._patch_size_for_gaussian_2d = None
    
    class Generic_UNet(SegmentationNetwork):
        DEFAULT_BATCH_SIZE_3D = 2
        DEFAULT_PATCH_SIZE_3D = (64, 192, 160)
        SPACING_FACTOR_BETWEEN_STAGES = 2
        BASE_NUM_FEATURES_3D = 30
        MAX_NUMPOOL_3D = 999
        MAX_NUM_FILTERS_3D = 320
    
        DEFAULT_PATCH_SIZE_2D = (256, 256)
        BASE_NUM_FEATURES_2D = 30
        DEFAULT_BATCH_SIZE_2D = 50
        MAX_NUMPOOL_2D = 999
        MAX_FILTERS_2D = 480
    
        use_this_for_batch_size_computation_2D = 19739648
        use_this_for_batch_size_computation_3D = 520000000  # 505789440
    
        def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2,
                     feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
                     norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                     dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                     nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False,
                     final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
                     conv_kernel_sizes=None,
                     upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
                     max_num_features=None, basic_block=ConvDropoutNormNonlin,
                     seg_output_use_bias=False):
            """
            basically more flexible than v1, architecture is the same
    
            Does this look complicated? Nah bro. Functionality > usability
    
            This does everything you need, including world peace.
    
            Questions? -> f.isensee@dkfz.de
            """
            super(Generic_UNet, self).__init__()
            self.convolutional_upsampling = convolutional_upsampling
            self.convolutional_pooling = convolutional_pooling
            self.upscale_logits = upscale_logits
            if nonlin_kwargs is None:
                nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
            if dropout_op_kwargs is None:
                dropout_op_kwargs = {'p': 0.5, 'inplace': True}
            if norm_op_kwargs is None:
                norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
    
            self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}
    
            self.nonlin = nonlin
            self.nonlin_kwargs = nonlin_kwargs
            self.dropout_op_kwargs = dropout_op_kwargs
            self.norm_op_kwargs = norm_op_kwargs
            self.weightInitializer = weightInitializer
            self.conv_op = conv_op
            self.norm_op = norm_op
            self.dropout_op = dropout_op
            self.num_classes = num_classes
            self.final_nonlin = final_nonlin
            self._deep_supervision = deep_supervision
            self.do_ds = deep_supervision
    
            if conv_op == nn.Conv2d:
                upsample_mode = 'bilinear'
                pool_op = nn.MaxPool2d
                transpconv = nn.ConvTranspose2d
                if pool_op_kernel_sizes is None:
                    pool_op_kernel_sizes = [(2, 2)] * num_pool
                if conv_kernel_sizes is None:
                    conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
            elif conv_op == nn.Conv3d:
                upsample_mode = 'trilinear'
                pool_op = nn.MaxPool3d
                transpconv = nn.ConvTranspose3d
                if pool_op_kernel_sizes is None:
                    pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
                if conv_kernel_sizes is None:
                    conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
            else:
                raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op))
    
            self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64)
            self.pool_op_kernel_sizes = pool_op_kernel_sizes
            self.conv_kernel_sizes = conv_kernel_sizes
    
            self.conv_pad_sizes = []
            for krnl in self.conv_kernel_sizes:
                self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])
    
            if max_num_features is None:
                if self.conv_op == nn.Conv3d:
                    self.max_num_features = self.MAX_NUM_FILTERS_3D
                else:
                    self.max_num_features = self.MAX_FILTERS_2D
            else:
                self.max_num_features = max_num_features
    
            self.conv_blocks_context = []
            self.conv_blocks_localization = []
            self.td = []
            self.tu = []
            self.seg_outputs = []
    
            output_features = base_num_features
            input_features = input_channels
    
            for d in range(num_pool):
                # determine the first stride
                if d != 0 and self.convolutional_pooling:
                    first_stride = pool_op_kernel_sizes[d - 1]
                else:
                    first_stride = None
    
                self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
                self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
                # add convolutions
                self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage,
                                                                  self.conv_op, self.conv_kwargs, self.norm_op,
                                                                  self.norm_op_kwargs, self.dropout_op,
                                                                  self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs,
                                                                  first_stride, basic_block=basic_block))
                if not self.convolutional_pooling:
                    self.td.append(pool_op(pool_op_kernel_sizes[d]))
                input_features = output_features
                output_features = int(np.round(output_features * feat_map_mul_on_downscale))
    
                output_features = min(output_features, self.max_num_features)
    
            # now the bottleneck.
            # determine the first stride
            if self.convolutional_pooling:
                first_stride = pool_op_kernel_sizes[-1]
            else:
                first_stride = None
    
            # the output of the last conv must match the number of features from the skip connection if we are not using
            # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be
            # done by the transposed conv
            if self.convolutional_upsampling:
                final_num_features = output_features
            else:
                final_num_features = self.conv_blocks_context[-1].output_channels
    
            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
            self.conv_blocks_context.append(nn.Sequential(
                StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs,
                                  self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                                  self.nonlin_kwargs, first_stride, basic_block=basic_block),
                StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs,
                                  self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                                  self.nonlin_kwargs, basic_block=basic_block)))
    
            # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here
            if not dropout_in_localization:
                old_dropout_p = self.dropout_op_kwargs['p']
                self.dropout_op_kwargs['p'] = 0.0
    
            # now lets build the localization pathway
            for u in range(num_pool):
                nfeatures_from_down = final_num_features
                nfeatures_from_skip = self.conv_blocks_context[
                    -(2 + u)].output_channels  # self.conv_blocks_context[-1] is bottleneck, so start with -2
                n_features_after_tu_and_concat = nfeatures_from_skip * 2
    
                # the first conv reduces the number of features to match those of skip
                # the following convs work on that number of features
                # if not convolutional upsampling then the final conv reduces the num of features again
                if u != num_pool - 1 and not self.convolutional_upsampling:
                    final_num_features = self.conv_blocks_context[-(3 + u)].output_channels
                else:
                    final_num_features = nfeatures_from_skip
    
                if not self.convolutional_upsampling:
                    self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode))
                else:
                    self.tu.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)],
                                              pool_op_kernel_sizes[-(u + 1)], bias=False))
    
                self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)]
                self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)]
                self.conv_blocks_localization.append(nn.Sequential(
                    StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
                                      self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op,
                                      self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block),
                    StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
                                      self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                                      self.nonlin, self.nonlin_kwargs, basic_block=basic_block)
                ))
    
            for ds in range(len(self.conv_blocks_localization)):
                self.seg_outputs.append(conv_op(self.conv_blocks_localization[ds][-1].output_channels, num_classes,
                                                1, 1, 0, 1, 1, seg_output_use_bias))
    
            self.upscale_logits_ops = []
            cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1]
            for usl in range(num_pool - 1):
                if self.upscale_logits:
                    self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl + 1]]),
                                                            mode=upsample_mode))
                else:
                    self.upscale_logits_ops.append(lambda x: x)
    
            if not dropout_in_localization:
                self.dropout_op_kwargs['p'] = old_dropout_p
    
            # register all modules properly
            self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization)
            self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)
            self.td = nn.ModuleList(self.td)
            self.tu = nn.ModuleList(self.tu)
            self.seg_outputs = nn.ModuleList(self.seg_outputs)
            if self.upscale_logits:
                self.upscale_logits_ops = nn.ModuleList(
                    self.upscale_logits_ops)  # lambda x:x is not a Module so we need to distinguish here
    
            if self.weightInitializer is not None:
                self.apply(self.weightInitializer)
                # self.apply(print_module_training_status)
    
        def forward(self, x):
            skips = []
            seg_outputs = []
            for d in range(len(self.conv_blocks_context) - 1):
                x = self.conv_blocks_context[d](x)
                skips.append(x)
                if not self.convolutional_pooling:
                    x = self.td[d](x)
    
            x = self.conv_blocks_context[-1](x)
    
            for u in range(len(self.tu)):
                x = self.tu[u](x)
                x = torch.cat((x, skips[-(u + 1)]), dim=1)
                x = self.conv_blocks_localization[u](x)
                seg_outputs.append(self.final_nonlin(self.seg_outputs[u](x)))
    
            if self._deep_supervision and self.do_ds:
                return tuple([seg_outputs[-1]] + [i(j) for i, j in
                                                  zip(list(self.upscale_logits_ops)[::-1], seg_outputs[:-1][::-1])])
            else:
                return seg_outputs[-1]
    

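    The docstring above made me laugh: "This does everything you need, including world peace." The authors were clearly having fun.

    Continuing from the previous step, where the trainer object was created, we return to nnunet/run/run_training.py.

    Initializing the trainer

    trainer.initialize(not validation_only)  # see nnUNetTrainerV2.initialize()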
    

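    7. Training process

    Everything is in place, so training can start:

    trainer.run_training()  # see nnUNetTrainerV2.run_training()
    

    nnUNetTrainerV2.run_training() overrides nnUNetTrainer.run_training() and calls it through super().

    nnUNetTrainer.run_training() in turn calls NetworkTrainer.run_training(), which contains the actual training loop.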
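The way the three run_training methods chain together can be illustrated with a toy class hierarchy. The class names below are made up and the method bodies only record what happens; the point is the order of the calls, and the fact that the V2 trainer switches deep supervision on for the duration of training and restores the previous value afterwards.

```python
class ToyNetworkTrainer:
    def __init__(self):
        self.do_ds = False  # stands in for self.network.do_ds
        self.log = []

    def run_training(self):
        # The real method contains the epoch loop.
        self.log.append("NetworkTrainer: epoch loop (do_ds=%s)" % self.do_ds)


class ToyTrainer(ToyNetworkTrainer):
    def run_training(self):
        self.log.append("nnUNetTrainer: save_debug_information")
        super().run_training()


class ToyTrainerV2(ToyTrainer):
    def run_training(self):
        self.log.append("nnUNetTrainerV2: maybe_update_lr")
        previous = self.do_ds
        self.do_ds = True      # deep supervision must be on while training
        ret = super().run_training()
        self.do_ds = previous  # restore the previous setting
        return ret


trainer = ToyTrainerV2()
trainer.run_training()
print("\n".join(trainer.log))
print("do_ds after training:", trainer.do_ds)
```

The log shows the most derived method running first and the epoch loop running last, with do_ds set to True only while the loop is active.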

        # nnUNetTrainerV2
        def run_training(self):
            """
            if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
            continued epoch with self.initial_lr
    
            we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
            :return:
            """
            self.maybe_update_lr(self.epoch)  # if we dont overwrite epoch then self.epoch+1 is used which is not what we
            # want at the start of the training
            ds = self.network.do_ds
            self.network.do_ds = True
            ret = super().run_training()
            self.network.do_ds = ds
            return ret
        
        # nnUNetTrainer
        def run_training(self):
            self.save_debug_information()
            super(nnUNetTrainer, self).run_training()
        
        # NetworkTrainer
        def run_training(self):
            if not torch.cuda.is_available():
                self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")
    
            _ = self.tr_gen.next()
            _ = self.val_gen.next()
    
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
            self._maybe_init_amp()
    
            maybe_mkdir_p(self.output_folder)        
            self.plot_network_architecture()
    
            if cudnn.benchmark and cudnn.deterministic:
                warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
                     "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
                     "If you want deterministic then set benchmark=False")
    
            if not self.was_initialized:
                self.initialize(True)
    
            while self.epoch < self.max_num_epochs:
                self.print_to_log_file("\nepoch: ", self.epoch)
                epoch_start_time = time()
                train_losses_epoch = []
    
                # train one epoch
                self.network.train()
    
                if self.use_progress_bar:
                    with trange(self.num_batches_per_epoch) as tbar:
                        for b in tbar:
                            tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))
    
                            l = self.run_iteration(self.tr_gen, True)
    
                            tbar.set_postfix(loss=l)
                            train_losses_epoch.append(l)
                else:
                    for _ in range(self.num_batches_per_epoch):
                        l = self.run_iteration(self.tr_gen, True)
                        train_losses_epoch.append(l)
    
                self.all_tr_losses.append(np.mean(train_losses_epoch))
                self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])
    
                with torch.no_grad():
                    # validation with train=False
                    self.network.eval()
                    val_losses = []
                    for b in range(self.num_val_batches_per_epoch):
                        l = self.run_iteration(self.val_gen, False, True)
                        val_losses.append(l)
                    self.all_val_losses.append(np.mean(val_losses))
                    self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])
    
                    if self.also_val_in_tr_mode:
                        self.network.train()
                        # validation with train=True
                        val_losses = []
                        for b in range(self.num_val_batches_per_epoch):
                            l = self.run_iteration(self.val_gen, False)
                            val_losses.append(l)
                        self.all_val_losses_tr_mode.append(np.mean(val_losses))
                        self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])
    
                self.update_train_loss_MA()  # needed for lr scheduler and stopping of training
    
                continue_training = self.on_epoch_end()
    
                epoch_end_time = time()
    
                if not continue_training:
                    # allows for early stopping
                    break
    
                self.epoch += 1
                self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))
    
            self.epoch -= 1  # if we don't do this we can get a problem with loading model_final_checkpoint.
    
            if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
            # now we can delete latest as it will be identical with final
            if isfile(join(self.output_folder, "model_latest.model")):
                os.remove(join(self.output_folder, "model_latest.model"))
            if isfile(join(self.output_folder, "model_latest.model.pkl")):
                os.remove(join(self.output_folder, "model_latest.model.pkl"))
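Since nnUNetTrainerV2 sets self.lr_scheduler to None, the learning rate is adjusted by maybe_update_lr instead, and the initialize_network docstring names the schedule: poly_lr. A minimal sketch of a polynomial decay schedule follows. The exponent 0.9 is an assumption; it does not appear in the code quoted here, so check the nnUnet source before relying on it.

```python
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    """Polynomial decay from initial_lr at epoch 0 towards 0 at max_epochs."""
    return initial_lr * (1 - epoch / max_epochs) ** exponent


initial_lr = 1e-2      # nnUNetTrainerV2.initial_lr
max_num_epochs = 1000  # nnUNetTrainerV2.max_num_epochs

# Print the learning rate at a few points of the schedule.
for epoch in (0, 250, 500, 750, 999):
    print("epoch %4d  lr %.6f" % (epoch, poly_lr(epoch, max_num_epochs, initial_lr)))
```

This also explains why run_training calls maybe_update_lr(self.epoch) before the loop starts: when training is resumed with -c, the first continued epoch should use the rate that belongs to that epoch rather than initial_lr.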
    

    Iterating over the batches of one epoch

    for _ in range(self.num_batches_per_epoch):
        l = self.run_iteration(self.tr_gen, True)
        train_losses_epoch.append(l)      
    

    A single iteration

    run_iteration

        def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
            data_dict = next(data_generator)
            data = data_dict['data']
            target = data_dict['target']
    
            data = maybe_to_torch(data)
            target = maybe_to_torch(target)
    
            if torch.cuda.is_available():
                data = to_cuda(data)
                target = to_cuda(target)
    
            self.optimizer.zero_grad()
    
            if self.fp16:
                with autocast():
                    output = self.network(data)
                    del data
                    l = self.loss(output, target)
    
                if do_backprop:
                    self.amp_grad_scaler.scale(l).backward()
                    self.amp_grad_scaler.step(self.optimizer)
                    self.amp_grad_scaler.update()
            else:
                output = self.network(data)
                del data
                l = self.loss(output, target)
    
                if do_backprop:
                    l.backward()
                    self.optimizer.step()
    
            if run_online_evaluation:
                self.run_online_evaluation(output, target)
    
            del target
    
            return l.detach().cpu().numpy()
    

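    It is the first day of the holidays and I was not very productive, so I will stop here for now.


    What nnUnet is best known for is its data preprocessing, data augmentation and post-processing. My current plan is to replace nnUnet's default network with my own model and start by reusing the rest of the pipeline as it is.

    I am still reading the source code of the data-processing part. I think that once you have really mastered nnUnet's data-processing pipeline, you have a solid footing in 3D image segmentation.

    Source: 宁眸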
