Source code for ivadomed.main

import json
import argparse
import copy
import joblib
import torch.backends.cudnn as cudnn
import nibabel as nib
import sys
import platform
import multiprocessing
import re

from ivadomed.loader.bids_dataframe import BidsDataframe
from ivadomed import evaluation as imed_evaluation
from ivadomed import config_manager as imed_config_manager
from ivadomed import testing as imed_testing
from ivadomed import training as imed_training
from ivadomed import transforms as imed_transforms
from ivadomed import utils as imed_utils
from ivadomed import metrics as imed_metrics
from ivadomed import inference as imed_inference
from ivadomed.loader import utils as imed_loader_utils, loader as imed_loader, film as imed_film
from loguru import logger
from pathlib import Path

cudnn.benchmark = True

# List of not-default available models i.e. different from Unet
MODEL_LIST = ['Modified3DUNet', 'HeMISUnet', 'FiLMedUnet', 'resnet18', 'densenet121', 'Countception']

def get_parser():
    parser = argparse.ArgumentParser(add_help=False)

    command_group = parser.add_mutually_exclusive_group(required=False)

    command_group.add_argument("--train", dest='train', action='store_true',
                               help="Perform training on data.")
    command_group.add_argument("--test", dest='test', action='store_true',
                               help="Perform testing on trained model.")
    command_group.add_argument("--segment", dest='segment', action='store_true',
                               help="Perform segmentation on data.")

    parser.add_argument("-c", "--config", required=True, type=str,
                        help="Path to configuration file.")

    optional_args = parser.add_argument_group('OPTIONAL ARGUMENTS')

    optional_args.add_argument("-pd", "--path-data", dest="path_data", required=False, type=str,
                               nargs="*", help="""Path to data in BIDs format. You may list one
                               or more paths; separate each path with a space, e.g.
                               --path-data some/path/a some/path/b""")
    optional_args.add_argument("-po", "--path-output", required=False, type=str, dest="path_output",
                               help="Path to output directory.")
    optional_args.add_argument('-g', '--gif', required=False, type=int, default=0,
                               help='Number of GIF files to output. Each GIF file corresponds to a 2D slice showing the '
                                    'prediction over epochs (one frame per epoch). The prediction is run on the '
                                    'validation dataset. GIF files are saved in the output path.')
    optional_args.add_argument('-t', '--thr-increment', dest="thr_increment", required=False, type=float,
                               help='A threshold analysis is performed at the end of the training using the trained '
                                    'model and the training+validation sub-datasets to find the optimal binarization '
                                    'threshold. The specified value indicates the increment between 0 and 1 used during '
                                    'the analysis (e.g. 0.1). Plot is saved under "[PATH_OUTPUT]/thr.png" and the '
                                    'optimal threshold in "[PATH_OUTPUT]/config_file.json as "binarize_prediction" '
    optional_args.add_argument('--resume-training', dest="resume_training", required=False, action='store_true',
                               help='Load a saved model ("checkpoint.pth.tar" in the output directory specified either with flag "--path-output" or via the config file "output_path" argument)  '
                                    'for resume training. This training state is saved everytime a new best model is saved in the output directory specified with flag "--path-output"')
    optional_args.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS,
                               help='Shows function documentation.')

    return parser

def create_path_model(context, model_params, ds_train, path_output, train_onehotencoder):
    path_model = Path(path_output, context.get("model_name"))
    if not path_model.is_dir():'Creating model directory: {path_model}')
        if 'film_layers' in model_params and any(model_params['film_layers']):
            joblib.dump(train_onehotencoder, path_model.joinpath("one_hot_encoder.joblib"))
            if 'metadata_dict' in ds_train[0]['input_metadata'][0]:
                metadata_dict = ds_train[0]['input_metadata'][0]['metadata_dict']
                joblib.dump(metadata_dict, path_model.joinpath("metadata_dict.joblib"))

    else:'Model directory already exists: {path_model}')

def check_multiple_raters(is_train, loader_params):
    if any([isinstance(class_suffix, list) for class_suffix in loader_params["target_suffix"]]):
            "Annotations from multiple raters will be used during model training, one annotation from one rater "
            "randomly selected at each iteration.\n")
        if not is_train:
                "Please provide only one annotation per class in 'target_suffix' when not training a model.\n")

def film_normalize_data(context, model_params, ds_train, ds_valid, path_output):
    # Normalize metadata before sending to the FiLM network
    results = imed_film.get_film_metadata_models(ds_train=ds_train,
    ds_train, train_onehotencoder, metadata_clustering_models = results
    ds_valid = imed_film.normalize_metadata(ds_valid, metadata_clustering_models, context["debugging"],
    model_params.update({"film_onehotencoder": train_onehotencoder,
                         "n_metadata": len([ll for l in train_onehotencoder.categories_ for ll in l])})
    joblib.dump(metadata_clustering_models, Path(path_output, "clustering_models.joblib"))
    joblib.dump(train_onehotencoder, Path(path_output + "one_hot_encoder.joblib"))

    return model_params, ds_train, ds_valid, train_onehotencoder

def get_dataset(bids_df, loader_params, data_lst, transform_params, cuda_available, device, ds_type):
    ds = imed_loader.load_dataset(bids_df, **{**loader_params, **{'data_list': data_lst,
                                                                  'transforms_params': transform_params,
                                                                  'dataset_type': ds_type}}, device=device,
    return ds

def save_config_file(context, path_output):
    # Save config file within path_output and path_output/model_name
    # Done after the threshold_analysis to propate this info in the config files
    with Path(path_output, "config_file.json").open(mode='w') as fp:
        json.dump(context, fp, indent=4)
    with Path(path_output, context.get("model_name"), context.get("model_name") + ".json").open(mode='w') as fp:
        json.dump(context, fp, indent=4)

def set_loader_params(context, is_train):
    loader_params = copy.deepcopy(context["loader_parameters"])
    if is_train:
        loader_params["contrast_params"]["contrast_lst"] = loader_params["contrast_params"]["training_validation"]
        loader_params["contrast_params"]["contrast_lst"] = loader_params["contrast_params"]["testing"]
    if "FiLMedUnet" in context and context["FiLMedUnet"]["applied"]:
        loader_params.update({"metadata_type": context["FiLMedUnet"]["metadata"]})

    # Load metadata necessary to balance the loader
    if context['training_parameters']['balance_samples']['applied'] and \
            context['training_parameters']['balance_samples']['type'] != 'gt':
        loader_params.update({"metadata_type": context['training_parameters']['balance_samples']['type']})
    return loader_params

def set_model_params(context, loader_params):
    model_params = copy.deepcopy(context["default_model"])
    model_params["folder_name"] = copy.deepcopy(context["model_name"])
    model_context_list = [model_name for model_name in MODEL_LIST
                          if model_name in context and context[model_name]["applied"]]
    if len(model_context_list) == 1:
        model_params["name"] = model_context_list[0]
    elif 'Modified3DUNet' in model_context_list and 'FiLMedUnet' in model_context_list and len(model_context_list) == 2:
        model_params["name"] = 'Modified3DUNet'
        for i in range(len(model_context_list)):
    elif len(model_context_list) > 1:
        logger.error('ERROR: Several models are selected in the configuration file: {}.'
              'Please select only one (i.e. only one where: "applied": true).'.format(model_context_list))

    model_params['is_2d'] = False if "Modified3DUNet" in model_params['name'] else model_params['is_2d']
    # Get in_channel from contrast_lst
    if loader_params["multichannel"]:
        model_params["in_channel"] = len(loader_params["contrast_params"]["contrast_lst"])
        model_params["in_channel"] = 1
    # Get out_channel from target_suffix
    model_params["out_channel"] = len(loader_params["target_suffix"])
    # If multi-class output, then add background class
    if model_params["out_channel"] > 1:
        model_params.update({"out_channel": model_params["out_channel"] + 1})
    # Display for spec' check
    # Update loader params
    if 'object_detection_params' in context:
        object_detection_params = context['object_detection_params']
        object_detection_params.update({"gpu_ids": context['gpu_ids'][0],
                                        "path_output": context['path_output']})
        loader_params.update({"object_detection_params": object_detection_params})

    loader_params.update({"model_params": model_params})

    return model_params, loader_params

def set_output_path(context):
    path_output = copy.deepcopy(context["path_output"])
    if not Path(path_output).is_dir():'Creating output path: {}'.format(path_output))
    else:'Output path already exists: {}'.format(path_output))

    return path_output

def update_film_model_params(context, ds_test, model_params, path_output):
    clustering_path = Path(path_output, "clustering_models.joblib")
    metadata_clustering_models = joblib.load(clustering_path)
    # Model directory
    ohe_path = Path(path_output, context.get("model_name"), "one_hot_encoder.joblib")
    one_hot_encoder = joblib.load(ohe_path)
    ds_test = imed_film.normalize_metadata(ds_test, metadata_clustering_models, context.get("debugging"),
    model_params.update({"film_onehotencoder": one_hot_encoder,
                         "n_metadata": len([ll for l in one_hot_encoder.categories_ for ll in l])})

    return ds_test, model_params

def run_segment_command(context, model_params):
    # BIDSDataframe of all image files
    # Indexing of derivatives is False for command segment
    bids_df = BidsDataframe(context['loader_parameters'], context['path_output'], derivatives=False)

    # Append subjects filenames into a list
    bids_subjects = sorted(bids_df.df.get('filename').to_list())

    # Add postprocessing to packaged model
    path_model = Path(context.get('path_output'), context.get('model_name'))
    path_model_config = Path(path_model, context.get('model_name') + ".json")
    model_config = imed_config_manager.load_json(str(path_model_config))
    model_config['postprocessing'] = context.get('postprocessing')
    with'w') as fp:
        json.dump(model_config, fp, indent=4)
    options = {}

    # Initialize a list of already seen subject ids for multichannel
    seen_subj_ids = []

    for subject in bids_subjects:
        if context['loader_parameters']['multichannel']:
            # Get subject_id for multichannel
            df_sub = bids_df.df.loc[bids_df.df['filename'] == subject]
            subj_id = re.sub(r'_' + df_sub['suffix'].values[0] + '.*', '', subject)
            if subj_id not in seen_subj_ids:
                # if subj_id has not been seen yet
                fname_img = []
                provided_contrasts = []
                contrasts = context['loader_parameters']['contrast_params']['testing']
                # Keep contrast order
                for c in contrasts:
                    df_tmp = bids_df.df[
                        bids_df.df['filename'].str.contains(subj_id) & bids_df.df['suffix'].str.contains(c)]
                    if ~df_tmp.empty:
                if len(fname_img) != len(contrasts):
                    logger.warning("Missing contrast for subject {}. {} were provided but {} are required. Skipping "
                                   "subject.".format(subj_id, provided_contrasts, contrasts))
                # Returns an empty list for subj_id already seen
                fname_img = []
            fname_img = bids_df.df[bids_df.df['filename'] == subject]['path'].to_list()

        # Add film metadata to options for segment_volume
        if 'film_layers' in model_params and any(model_params['film_layers']) and model_params['metadata']:
            metadata = bids_df.df[bids_df.df['filename'] == subject][model_params['metadata']].values[0]
            options['metadata'] = metadata

        # Add microscopy pixel size metadata to options for segment_volume
        if 'PixelSize' in bids_df.df.columns:
            options['pixel_size'] = bids_df.df.loc[bids_df.df['filename'] == subject]['PixelSize'].values[0]

        if fname_img:
            pred_list, target_list = imed_inference.segment_volume(str(path_model),
            pred_path = Path(context.get('path_output'), "pred_masks")
            if not pred_path.exists():

            for pred, target in zip(pred_list, target_list):
                filename = subject.split('.')[0] + target + "_pred" + ".nii.gz"
      , Path(pred_path, filename))

            # For Microscopy PNG/TIF files (TODO: implement OMETIFF behavior)
            extension = imed_loader_utils.get_file_extension(subject)
            if "nii" not in extension:
                                           str(Path(pred_path, subject)).replace(extension, ''))

[docs]def run_command(context, n_gif=0, thr_increment=None, resume_training=False): """Run main command. This function is central in the ivadomed project as training / testing / evaluation commands are run via this function. All the process parameters are defined in the config. Args: context (dict): Dictionary containing all parameters that are needed for a given process. See :doc:`configuration_file` for more details. n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows predictions of a given slice from the validation sub-dataset. They are saved within the output path. thr_increment (float): A threshold analysis is performed at the end of the training using the trained model and the training + validation sub-dataset to find the optimal binarization threshold. The specified value indicates the increment between 0 and 1 used during the ROC analysis (e.g. 0.1). resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the output directory specified with flag "--path-output" or via the config file "output_path" ' This training state is saved everytime a new best model is saved in the log argument) for resume training directory. Returns: float or pandas.DataFrame or None: * If "train" command: Returns floats: best loss score for both training and validation. * If "test" command: Returns a pandas Dataframe: of metrics computed for each subject of the testing sub-dataset and return the prediction metrics before evaluation. * If "segment" command: No return value. """ command = copy.deepcopy(context.get("command")) path_output = set_output_path(context) path_log = Path(context.get('path_output'), context.get('log_file')) logger.remove() logger.add(str(path_log)) logger.add(sys.stdout) # Create a log with the version of the Ivadomed software and the version of the Annexed dataset (if present) create_dataset_and_ivadomed_version_log(context) cuda_available, device = imed_utils.define_device(context['gpu_ids'][0]) # BACKWARDS COMPATIBILITY: If bids_path is string, assign to list - Do this here so it propagates to all functions context['loader_parameters']['path_data'] = imed_utils.format_path_data(context['loader_parameters']['path_data']) # Loader params loader_params = set_loader_params(context, command == "train") # Get transforms for each subdataset transform_train_params, transform_valid_params, transform_test_params = \ imed_transforms.get_subdatasets_transforms(context["transformation"]) # MODEL PARAMETERS model_params, loader_params = set_model_params(context, loader_params) if command == 'segment': run_segment_command(context, model_params) return # BIDSDataframe of all image files # Indexing of derivatives is True for command train and test bids_df = BidsDataframe(loader_params, path_output, derivatives=True) # Get subject filenames lists. "segment" command uses all participants of data path, hence no need to split train_lst, valid_lst, test_lst = imed_loader_utils.get_subdatasets_subject_files_list(context["split_dataset"], bids_df.df, path_output, context["loader_parameters"] ['subject_selection']) # Generating sha256 for the training files imed_utils.generate_sha_256(context, bids_df.df, train_lst) # TESTING PARAMS # Aleatoric uncertainty if context['uncertainty']['aleatoric'] and context['uncertainty']['n_it'] > 0: transformation_dict = transform_train_params else: transformation_dict = transform_test_params undo_transforms = imed_transforms.UndoCompose(imed_transforms.Compose(transformation_dict, requires_undo=True)) testing_params = copy.deepcopy(context["training_parameters"]) testing_params.update({'uncertainty': context["uncertainty"]}) testing_params.update({'target_suffix': loader_params["target_suffix"], 'undo_transforms': undo_transforms, 'slice_axis': loader_params['slice_axis']}) if command == "train": imed_utils.display_selected_transfoms(transform_train_params, dataset_type=["training"]) imed_utils.display_selected_transfoms(transform_valid_params, dataset_type=["validation"]) elif command == "test": imed_utils.display_selected_transfoms(transformation_dict, dataset_type=["testing"]) # Check if multiple raters check_multiple_raters(command == "train", loader_params) if command == 'train': # Get Validation dataset ds_valid = get_dataset(bids_df, loader_params, valid_lst, transform_valid_params, cuda_available, device, 'validation') # Get Training dataset ds_train = get_dataset(bids_df, loader_params, train_lst, transform_train_params, cuda_available, device, 'training') metric_fns = imed_metrics.get_metric_fns(ds_train.task) # If FiLM, normalize data if 'film_layers' in model_params and any(model_params['film_layers']): model_params, ds_train, ds_valid, train_onehotencoder = \ film_normalize_data(context, model_params, ds_train, ds_valid, path_output) else: train_onehotencoder = None # Model directory create_path_model(context, model_params, ds_train, path_output, train_onehotencoder) save_config_file(context, path_output) # RUN TRAINING best_training_dice, best_training_loss, best_validation_dice, best_validation_loss = imed_training.train( model_params=model_params, dataset_train=ds_train, dataset_val=ds_valid, training_params=context["training_parameters"], path_output=path_output, device=device, cuda_available=cuda_available, metric_fns=metric_fns, n_gif=n_gif, resume_training=resume_training, debugging=context["debugging"]) if thr_increment: # LOAD DATASET if command != 'train': # If command == train, then ds_valid already load # Get Validation dataset ds_valid = get_dataset(bids_df, loader_params, valid_lst, transform_valid_params, cuda_available, device, 'validation') # Get Training dataset with no Data Augmentation ds_train = get_dataset(bids_df, loader_params, train_lst, transform_valid_params, cuda_available, device, 'training') # Choice of optimisation metric metric = "recall_specificity" if model_params["name"] in imed_utils.CLASSIFIER_LIST else "dice" # Model path model_path = Path(path_output, "") # Run analysis thr = imed_testing.threshold_analysis(model_path=str(model_path), ds_lst=[ds_train, ds_valid], model_params=model_params, testing_params=testing_params, metric=metric, increment=thr_increment, fname_out=str(Path(path_output, "roc.png")), cuda_available=cuda_available) # Update threshold in config file context["postprocessing"]["binarize_prediction"] = {"thr": thr} save_config_file(context, path_output) if command == 'train': return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss if command == 'test': # LOAD DATASET # Warn user that the input-level dropout is set during inference if loader_params['is_input_dropout']: logger.warning("Input-level dropout is set during testing. To turn this option off, set 'is_input_dropout'" "to 'false' in the configuration file.") ds_test = imed_loader.load_dataset(bids_df, **{**loader_params, **{'data_list': test_lst, 'transforms_params': transformation_dict, 'dataset_type': 'testing', 'requires_undo': True}}, device=device, cuda_available=cuda_available) metric_fns = imed_metrics.get_metric_fns(ds_test.task) if 'film_layers' in model_params and any(model_params['film_layers']): ds_test, model_params = update_film_model_params(context, ds_test, model_params, path_output) # RUN INFERENCE pred_metrics = imed_testing.test(model_params=model_params, dataset_test=ds_test, testing_params=testing_params, path_output=path_output, device=device, cuda_available=cuda_available, metric_fns=metric_fns, postprocessing=context['postprocessing']) # RUN EVALUATION df_results = imed_evaluation.evaluate(bids_df, path_output=path_output, target_suffix=loader_params["target_suffix"], eval_params=context["evaluation_parameters"]) return df_results, pred_metrics
def create_dataset_and_ivadomed_version_log(context): path_data = context.get('loader_parameters').get('path_data') ivadomed_version = imed_utils._version_string() datasets_version = [] if isinstance(path_data, str): datasets_version = [imed_utils.__get_commit(path_to_git_folder=path_data)] elif isinstance(path_data, list): for Dataset in path_data: datasets_version.append(imed_utils.__get_commit(path_to_git_folder=Dataset)) path_log = Path(context.get('path_output'), 'version_info.log') try: f ="w") except OSError as err: logger.error(f"OS error: {err}") raise Exception("Have you selected a log folder, and do you have write permissions for that folder?") # IVADOMED f.write('IVADOMED TOOLBOX\n----------------\n(' + ivadomed_version + ')') # DATASETS path_data = imed_utils.format_path_data(path_data) f.write('\n\n\nDATASET VERSION\n---------------\n') f.write('The following BIDS dataset(s) were used for training.\n') for i_dataset in range(len(path_data)): if datasets_version[i_dataset] not in ['', '?!?']: f.write(str(i_dataset + 1) + '. ' + path_data[i_dataset] + ' - Dataset Annex version: ' + datasets_version[ i_dataset] + '\n') else: f.write(str(i_dataset + 1) + '. ' + path_data[i_dataset] + ' - Dataset is not Annexed.\n') # SYSTEM INFO f.write('\n\nSYSTEM INFO\n-------------\n') platform_running = sys.platform if platform_running.find('darwin') != -1: os_running = 'osx' elif platform_running.find('linux') != -1: os_running = 'linux' elif platform_running.find('win32') or platform_running.find('win64'): os_running = 'windows' else: os_running = 'NA' f.write('OS: ' + os_running + ' (' + platform.platform() + ')\n') # Display number of CPU cores f.write('CPU cores: Available: {}\n\n\n\n\n'.format(multiprocessing.cpu_count())) # USER INPUTS f.write('CONFIG INPUTS\n-------------\n') if sys.version_info[0] > 2: for k, v in context.items(): f.write(str(k) + ': ' + str(v) + '\n') # Making sure all numbers are converted to strings else: for k, v in context.viewitems(): # Python2 f.write(str(k) + ': ' + str(v) + '\n') f.close() def run_main(): imed_utils.init_ivadomed() parser = get_parser() args = parser.parse_args() # Get context from configuration file path_config_file = args.config context = imed_config_manager.ConfigurationManager(path_config_file).get_config() context["command"] = imed_utils.get_command(args, context) context["path_output"] = imed_utils.get_path_output(args, context) context["loader_parameters"]["path_data"] = imed_utils.get_path_data(args, context) # Run command run_command(context=context, n_gif=args.gif if args.gif is not None else 0, thr_increment=args.thr_increment if args.thr_increment else None, resume_training=bool(args.resume_training)) if __name__ == "__main__": run_main()