Source code for ivadomed.testing

import copy
from pathlib import Path
import nibabel as nib
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from loguru import logger
from import DataLoader, ConcatDataset
from tqdm import tqdm
from pathlib import Path

from ivadomed import metrics as imed_metrics
from ivadomed import utils as imed_utils
from ivadomed import visualize as imed_visualize
from ivadomed import inference as imed_inference
from ivadomed import uncertainty as imed_uncertainty
from ivadomed.loader import utils as imed_loader_utils
from ivadomed.object_detection import utils as imed_obj_detect
from import store_film_params, save_film_params
from import get_metadata
from ivadomed.postprocessing import threshold_predictions
from ivadomed.keywords import ConfigKW, ModelParamsKW, MetadataKW

cudnn.benchmark = True

def test(model_params, dataset_test, testing_params, path_output, device, cuda_available=True,
         metric_fns=None, postprocessing=None):
    """Main command to test the network.

    Loads the trained model stored in ``path_output``, runs inference on the testing dataset
    (once, or ``n_it + 1`` times when uncertainty estimation is requested) and feeds the
    predictions of every pass to the metric manager.

    Note:
        ``testing_params['uncertainty']['applied']`` is written by this function.

    Args:
        model_params (dict): Model's parameters.
        dataset_test (imed_loader): Testing dataset.
        testing_params (dict): Testing parameters.
        path_output (str): Folder where predictions are saved.
        device (torch.device): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        postprocessing (dict): Contains postprocessing steps.

    Returns:
        dict: result metrics.
    """
    # DATA LOADER
    test_loader = DataLoader(dataset_test, batch_size=testing_params["batch_size"],
                             shuffle=False, pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # LOAD TRAIN MODEL
    # NOTE(review): the checkpoint file name was missing from the reviewed listing;
    # "best_model.pt" is assumed to be the name written at training time -- confirm.
    fname_model = Path(path_output, "best_model.pt")
    logger.info('Loading model: {}'.format(fname_model))
    model = torch.load(fname_model, map_location=device)
    if cuda_available:
        model.cuda()
    model.eval()

    # CREATE OUTPUT FOLDER
    path_3Dpred = Path(path_output, 'pred_masks')
    path_3Dpred.mkdir(parents=True, exist_ok=True)

    # METRIC MANAGER
    metric_mgr = imed_metrics.MetricManager(metric_fns)

    # UNCERTAINTY SETTINGS
    # Same dict object as testing_params['uncertainty']: writes below are visible to callers.
    uncertainty = testing_params['uncertainty']
    if (uncertainty['epistemic'] or uncertainty['aleatoric']) and uncertainty['n_it'] > 0:
        # n_it Monte Carlo passes, plus one final pass that produces the actual prediction
        n_monteCarlo = uncertainty['n_it'] + 1
        uncertainty['applied'] = True
        logger.info('Computing model uncertainty over {} iterations.'.format(n_monteCarlo - 1))
    else:
        uncertainty['applied'] = False
        n_monteCarlo = 1

    for i_monteCarlo in range(n_monteCarlo):
        preds_npy, gt_npy = run_inference(test_loader, model, model_params, testing_params, str(path_3Dpred),
                                          cuda_available, i_monteCarlo, postprocessing)
        metric_mgr(preds_npy, gt_npy)
        # If uncertainty computation, don't apply it on last iteration for prediction
        if uncertainty['applied'] and (n_monteCarlo - 2 == i_monteCarlo):
            uncertainty['applied'] = False
            # COMPUTE UNCERTAINTY MAPS
            imed_uncertainty.run_uncertainty(image_folder=str(path_3Dpred))

    metrics_dict = metric_mgr.get_results()
    metric_mgr.reset()
    return metrics_dict
# Logged in place of the colour-coded multi-class export, which is disabled (see #720).
# NOTE(review): the URL literal was empty in the reviewed listing; it is rebuilt here from
# the "#720" reference in the accompanying TODO -- confirm the target.
_COLOR_LABELS_WARNING = ('No color labels saved due to a temporary issue. For more details see:'
                         'https://github.com/ivadomed/ivadomed/issues/720')


def _prediction_filename(ofolder, fname_ref, testing_params, i_monte_carlo):
    """Build the path of the NIfTI file a prediction is saved to (None when ``ofolder`` is empty)."""
    if not ofolder:
        return None
    fname_pred = str(Path(ofolder, Path(fname_ref).name))
    fname_pred = fname_pred.split(testing_params['target_suffix'][0])[0] + '_pred.nii.gz'
    # If uncertainty running, then each simulation result gets its own numbered file
    if testing_params['uncertainty']['applied']:
        fname_pred = fname_pred.split('.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + '.nii.gz'
    return fname_pred


def run_inference(test_loader, model, model_params, testing_params, ofolder, cuda_available,
                  i_monte_carlo=None, postprocessing=None):
    """Run inference on the test data and save results as nibabel files.

    Args:
        test_loader (torch DataLoader): Loader yielding the batches to run the model on.
        model (nn.Module): Model used for the predictions.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters.
        ofolder (str): Folder where predictions are saved. When empty or None nothing is
            written to disk (predictions are only returned).
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.
        postprocessing (dict): Indicates postprocessing steps.

    Returns:
        list, list: Predictions and ground-truths. For segmentation tasks there is one array
        per reconstructed volume, with the label axis first; for the classification task
        there is one array per batch.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list, filenames = [], [], []
    pred_tmp_lst, z_tmp_lst = [], []
    image = None
    volume = None
    weight_matrix = None

    # Values that do not change from one batch to the next
    model_name = model_params[ModelParamsKW.NAME]
    is_hemis = model_name == ConfigKW.HEMIS_UNET
    film_enabled = ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])
    task = imed_utils.get_task(model_name)
    slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

    # Create dict containing gammas and betas after each FiLM layer.
    if film_enabled:
        # 2 * model_params["depth"] + 2 is the number of FiLM layers. 1 is added since the range starts at one.
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        metadata_values_lst = []

    for i, batch in enumerate(tqdm(test_loader, desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if is_hemis:
                input_samples = imed_utils.cuda(imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            # Dropout layers are switched back to training mode so that they stay stochastic
            if testing_params['uncertainty']['applied'] and testing_params['uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if is_hemis or film_enabled:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if is_hemis:
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_name == ConfigKW.MODIFIED_3D_UNET and model_params[ModelParamsKW.ATTENTION] and ofolder:
            imed_visualize.save_feature_map(batch, "attentionblock2", str(Path(ofolder).parent), model,
                                            input_samples, slice_axis=test_loader.dataset.slice_axis)

        if film_enabled:
            # Store the values of gammas and betas after the last epoch for each batch
            gammas_dict, betas_dict, metadata_values_lst = store_film_params(gammas_dict, betas_dict,
                                                                             metadata_values_lst,
                                                                             batch[MetadataKW.INPUT_METADATA], model,
                                                                             model_params[ModelParamsKW.FILM_LAYERS],
                                                                             model_params[ModelParamsKW.DEPTH],
                                                                             model_params[ModelParamsKW.METADATA])

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.detach().numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch[MetadataKW.INPUT_METADATA][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(testing_params["undo_transforms"].transforms, batch, smp_idx)

            if model_params[ModelParamsKW.IS_2D]:
                preds_idx_arr = None
                idx_slice = batch[MetadataKW.INPUT_METADATA][smp_idx][0]['slice_index']
                n_slices = batch[MetadataKW.INPUT_METADATA][smp_idx][0]['data_shape'][-1]
                last_slice_bool = (idx_slice + 1 == n_slices)
                last_sample_bool = (last_batch_bool and smp_idx == len(preds_cpu) - 1)

                length_2D = model_params.get(ModelParamsKW.LENGTH_2D, [])
                if length_2D:
                    # undo transformations for patch and reconstruct slice
                    preds_idx_undo, metadata_idx, last_patch_bool, image, weight_matrix = \
                        imed_inference.image_reconstruction(batch, preds_cpu, testing_params['undo_transforms'],
                                                            smp_idx, image, weight_matrix)
                else:
                    # Set last_patch_bool to True (only one patch per slice)
                    last_patch_bool = True
                    # undo transformations for slice
                    preds_idx_undo, metadata_idx = testing_params["undo_transforms"](preds_cpu[smp_idx],
                                                                                     batch['gt_metadata'][smp_idx],
                                                                                     data_type='gt')

                if last_patch_bool:
                    # preds_idx_undo is a list n_label arrays
                    preds_idx_arr = np.array(preds_idx_undo)

                    # TODO: gt_filenames should not be a list
                    fname_ref = list(filter(None, metadata_idx[0][MetadataKW.GT_FILENAMES]))[0]

                if preds_idx_arr is not None:
                    # add new sample to pred_tmp_lst, of size n_label X h X w ...
                    pred_tmp_lst.append(preds_idx_arr)

                    # TODO: slice_index should be stored in gt_metadata as well
                    z_tmp_lst.append(int(idx_slice))
                    filenames = metadata_idx[0][MetadataKW.GT_FILENAMES]

                # NEW COMPLETE VOLUME
                if (pred_tmp_lst and ((last_patch_bool and last_slice_bool) or last_sample_bool)
                        and task != "classification"):
                    # save the completely processed file as a NifTI file
                    fname_pred = _prediction_filename(ofolder, fname_ref, testing_params, i_monte_carlo)
                    # Monte Carlo samples are saved raw: postprocessing stays disabled for the rest of this call
                    if fname_pred and testing_params['uncertainty']['applied']:
                        postprocessing = None

                    output_nii = imed_inference.pred_to_nib(data_lst=pred_tmp_lst,
                                                            z_lst=z_tmp_lst,
                                                            fname_ref=fname_ref,
                                                            fname_out=fname_pred,
                                                            slice_axis=slice_axis,
                                                            kernel_dim='2d',
                                                            bin_thr=-1,
                                                            postprocessing=postprocessing)
                    output_array = output_nii.get_fdata()
                    # Move the last axis (labels) to the front
                    preds_npy_list.append(output_array.transpose(3, 0, 1, 2))

                    gt_npy_list.append(get_gt(filenames))

                    output_nii_shape = output_array.shape
                    if len(output_nii_shape) == 4 and output_nii_shape[-1] > 1 and ofolder:
                        logger.warning(_COLOR_LABELS_WARNING)
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(np.stack(pred_tmp_lst, -1),
                        #                                  False,
                        #                                  fname_ref,
                        #                                  fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                                  imed_utils.AXIS_DCT[testing_params['slice_axis']])

                    # For Microscopy PNG/TIF files (TODO: implement OMETIFF behavior)
                    extension = imed_loader_utils.get_file_extension(fname_ref)
                    if "nii" not in extension and fname_pred:
                        output_list = imed_inference.split_classes(output_nii)
                        # Reformat target list to include class index and be compatible with multiple raters
                        target_list = ["_class-%d" % i_class
                                       for i_class in range(len(testing_params['target_suffix']))]
                        imed_inference.pred_to_png(output_list,
                                                   target_list,
                                                   fname_pred.split("_pred.nii.gz")[0],
                                                   suffix="_pred.png")

                    # re-init pred_stack_lst and last_slice_bool
                    pred_tmp_lst, z_tmp_lst = [], []
                    last_slice_bool = False

            else:
                pred_undo, gt_metadata, last_sample_bool, volume, weight_matrix = \
                    imed_inference.volume_reconstruction(batch, preds_cpu, testing_params['undo_transforms'],
                                                         smp_idx, volume, weight_matrix)
                # Indicator of last batch
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    fname_ref = gt_metadata[0][MetadataKW.GT_FILENAMES][0]
                    fname_pred = _prediction_filename(ofolder, fname_ref, testing_params, i_monte_carlo)
                    # Monte Carlo samples are saved raw: postprocessing stays disabled for the rest of this call
                    if fname_pred and testing_params['uncertainty']['applied']:
                        postprocessing = None

                    # Choose only one modality
                    output_nii = imed_inference.pred_to_nib(data_lst=[pred_undo],
                                                            z_lst=[],
                                                            fname_ref=fname_ref,
                                                            fname_out=fname_pred,
                                                            slice_axis=slice_axis,
                                                            kernel_dim='3d',
                                                            bin_thr=-1,
                                                            postprocessing=postprocessing)
                    # Move the last axis (labels) to the front
                    preds_npy_list.append(output_nii.get_fdata().transpose(3, 0, 1, 2))

                    gt_npy_list.append(get_gt(gt_metadata[0][MetadataKW.GT_FILENAMES]))

                    # Save merged labels with color
                    if pred_undo.shape[0] > 1 and ofolder:
                        logger.warning(_COLOR_LABELS_WARNING)
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(pred_undo,
                        #                                  False,
                        #                                  batch[MetadataKW.INPUT_METADATA][smp_idx][0]['input_filenames'],
                        #                                  fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                                  slice_axis)

    # FiLM parameters are written next to the prediction folder. Without an output folder
    # (e.g. ofolder=None) there is nowhere to write them, so the step is skipped.
    if film_enabled and ofolder:
        save_film_params(gammas_dict, betas_dict, metadata_values_lst, model_params[ModelParamsKW.DEPTH],
                         ofolder.replace("pred_masks", ""))

    return preds_npy_list, gt_npy_list
def threshold_analysis(model_path, ds_lst, model_params, testing_params, metric="dice", increment=0.1,
                       fname_out="thr.png", cuda_available=True):
    """Run a threshold analysis to find the optimal threshold on a sub-dataset.

    Inference is run once; the predictions are then binarized with every tested threshold
    and scored against the binarized ground truth. The threshold with the best score is
    returned (the largest one in case of a tie) and a plot is saved to ``fname_out``.

    Note:
        ``testing_params["uncertainty"]["applied"]`` is set to False by this function.

    Args:
        model_path (str): Model path.
        ds_lst (list): List of loaders.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters
        metric (str): Choice between "dice" and "recall_specificity". If "recall_specificity", then a ROC analysis
            is performed.
        increment (float): Increment between tested thresholds. Must be strictly between 0 and 1.
        fname_out (str): Plot output filename.
        cuda_available (bool): If True, CUDA is available.

    Returns:
        float: optimal threshold.

    Raises:
        ValueError: If ``metric`` is not supported, if ``increment`` yields no threshold to
            test, or if the score is undefined (NaN) for every threshold.
    """
    if metric not in ["dice", "recall_specificity"]:
        raise ValueError('\nChoice of metric for threshold analysis: dice, recall_specificity.')
    if not 0 < increment < 1:
        raise ValueError('\nIncrement for threshold analysis must be strictly between 0 and 1.')

    # Adjust some testing parameters
    testing_params["uncertainty"]["applied"] = False

    # Load model on the device the inputs will be sent to by run_inference
    model = torch.load(model_path, map_location=None if cuda_available else "cpu")
    if cuda_available:
        model.cuda()
    # Eval mode
    model.eval()

    # List of thresholds (0.0 itself is not tested)
    thr_list = list(np.arange(0.0, 1.0, increment))[1:]

    # Init metric manager for each thr
    metric_fns = [imed_metrics.recall_score,
                  imed_metrics.dice_score,
                  imed_metrics.specificity_score]
    metric_dict = {thr: imed_metrics.MetricManager(metric_fns) for thr in thr_list}

    # Load
    loader = DataLoader(ConcatDataset(ds_lst), batch_size=testing_params["batch_size"],
                        shuffle=False, pin_memory=True, sampler=None,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # Run inference
    preds_npy, gt_npy = run_inference(loader, model, model_params,
                                      testing_params,
                                      ofolder=None,
                                      cuda_available=cuda_available)

    logger.info('Running threshold analysis to find optimal threshold')
    # Make sure the GT is binarized
    gt_npy = [threshold_predictions(gt, thr=0.5) for gt in gt_npy]

    # Move threshold
    for thr in tqdm(thr_list, desc="Search"):
        # Each threshold works on its own copy so that preds_npy is left untouched
        preds_thr = [threshold_predictions(copy.deepcopy(pred), thr=thr) for pred in preds_npy]
        metric_dict[thr](preds_thr, gt_npy)

    # Get results
    tpr_list, fpr_list, dice_list = [], [], []
    for thr in thr_list:
        result_thr = metric_dict[thr].get_results()
        tpr_list.append(result_thr["recall_score"])
        fpr_list.append(1 - result_thr["specificity_score"])
        dice_list.append(result_thr["dice_score"])

    # Get optimal threshold
    if metric == "dice":
        diff_list = dice_list
    else:
        diff_list = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]

    scores = np.asarray(diff_list, dtype=float)
    if np.all(np.isnan(scores)):
        raise ValueError('\nThreshold analysis failed: the metric is undefined for every threshold.')
    # NaN scores are ignored; ties are resolved in favour of the largest threshold
    optimal_idx = int(np.flatnonzero(scores == np.nanmax(scores))[-1])
    optimal_threshold = thr_list[optimal_idx]
    logger.info('\tOptimal threshold: {}'.format(optimal_threshold))

    # Save plot
    logger.info('\tSaving plot: {}'.format(fname_out))
    if metric == "dice":
        # Run plot
        imed_metrics.plot_dice_thr(thr_list, dice_list, optimal_idx, fname_out)
    else:
        # Add 0 and 1 as extrema
        tpr_list = [0.0] + tpr_list + [1.0]
        fpr_list = [0.0] + fpr_list + [1.0]
        # The optimum moves by one position because of the extremum prepended above
        optimal_idx += 1
        # Run plot
        imed_metrics.plot_roc_curve(tpr_list, fpr_list, optimal_idx, fname_out)

    return optimal_threshold
def get_gt(filenames):
    """Get ground truth data as numpy array.

    Labels without a ground-truth file (``None`` entries) are replaced by an all-zero
    array with the shape of the first available ground truth.

    Args:
        filenames (list): List of ground truth filenames, one per class.

    Returns:
        ndarray: Ground truths stacked along a new first axis, one entry per class.

    Raises:
        IndexError: If a label is missing and no other ground-truth file is available
            to take the shape from.
    """
    # Check filenames extentions and update paths if not NifTI
    filenames = [imed_loader_utils.update_filename_to_nifti(fname) for fname in filenames]

    # Shape used for missing labels; read at most once, and only if a label is missing
    empty_shape = None
    gt_lst = []
    for fname in filenames:
        if fname is not None:
            gt_lst.append(nib.load(fname).get_fdata())
            continue
        # For multi-label, if all labels are not in every image
        if empty_shape is None:
            fname_ref = list(filter(None, filenames))[0]
            # The image shape is available without reading the voxel data
            empty_shape = nib.load(fname_ref).shape
        gt_lst.append(np.zeros(empty_shape))
    return np.array(gt_lst)