Source code for ivadomed.scripts.visualize_transforms

#!/usr/bin/env python

import argparse
import nibabel as nib
import numpy as np
import random
import torch

from pathlib import Path
from loguru import logger
from ivadomed import config_manager as imed_config_manager
from ivadomed.loader import utils as imed_loader_utils
from ivadomed.loader.sample_meta_data import SampleMetadata
from ivadomed import transforms as imed_transforms
from ivadomed import utils as imed_utils
from ivadomed import maths as imed_maths
from ivadomed.keywords import ConfigKW, TransformationKW, LoaderParamsKW, MetadataKW

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", required=True,
                        help="Input image filename.",
    parser.add_argument("-c", "--config", required=True,
                        help="Config filename.",
    parser.add_argument("-n", "--number", required=False, default=1,
                        help="Number of random slices to visualize.",
    parser.add_argument("-o", "--output", required=False, default="./",
                        help="Output folder.",
    parser.add_argument("-r", "--roi", required=False, metavar=imed_utils.Metavar.file,
                        help="ROI filename. Only required if ROICrop is part of the transformations.")
    return parser

def get_data(fname_in, axis):
    """Get data from fname along an axis.

         fname_in string: image fname
         axis int:

        nibabel, ndarray
    # Load image
    input_img = nib.load(fname_in)
    # Reorient as canonical
    input_img = nib.as_closest_canonical(input_img)
    # Get input data
    input_data = input_img.get_fdata(dtype=np.float32)
    # Reorient data
    input_data = imed_loader_utils.orient_img_hwd(input_data, slice_axis=axis)
    return input_img, input_data

[docs] def run_visualization(input, config, number, output, roi): """Utility function to visualize Data Augmentation transformations. Data augmentation is a key part of the Deep Learning training scheme. This script aims at facilitating the fine-tuning of data augmentation parameters. To do so, this script provides a step-by-step visualization of the transformations that are applied on data. This function applies a series of transformations (defined in a configuration file ``-c``) to ``-n`` 2D slices randomly extracted from an input image (``-i``), and save as png the resulting sample after each transform. For example:: ivadomed_visualize_transforms -i t2s.nii.gz -n 1 -c config.json -r t2s_seg.nii.gz Provides a visualization of a series of three transformation on a randomly selected slice: .. image:: :width: 600px :align: center And on a binary mask:: ivadomed_visualize_transforms -i t2s_gmseg.nii.gz -n 1 -c config.json -r t2s_seg.nii.gz Gives: .. image:: :width: 600px :align: center Args: input (string): Image filename. Flag: ``--input``, ``-i`` config (string): Configuration file filename. Flag: ``--config``, ``-c`` number (int): Number of slices randomly extracted. Flag: ``--number``, ``-n`` output (string): Folder path where the results are saved. Flag: ``--ofolder``, ``-o`` roi (string): Filename of the region of interest. Only needed if ROICrop is part of the transformations. Flag: ``--roi``, ``-r`` """ # Load context context = imed_config_manager.ConfigurationManager(config).get_config() # Create output folder if not Path(output).is_dir(): Path(output).mkdir(parents=True) # Slice extracted according to below axis axis = imed_utils.AXIS_DCT[context[ConfigKW.LOADER_PARAMETERS][LoaderParamsKW.SLICE_AXIS]] # Get data input_img, input_data = get_data(input, axis) # Image or Mask is_mask = np.array_equal(input_data, input_data.astype(bool)) # Get zooms zooms = imed_loader_utils.orient_shapes_hwd(input_img.header.get_zooms(), slice_axis=axis) # Get indexes indexes = random.sample(range(0, input_data.shape[2]), number) # Get training transforms training_transforms, _, _ = imed_transforms.get_subdatasets_transforms(context[ConfigKW.TRANSFORMATION]) if TransformationKW.ROICROP in training_transforms: if roi and Path(roi).is_file(): roi_img, roi_data = get_data(roi, axis) else: raise ValueError("\nPlease provide ROI image (-r) in order to apply ROICrop transformation.") # Compose transforms dict_transforms = {} stg_transforms = "" for transform_name in training_transforms: # We skip NumpyToTensor transform since that s only a change of data type if transform_name == "NumpyToTensor": continue # Update stg_transforms stg_transforms += transform_name + "_" # Add new transform to Compose dict_transforms.update({transform_name: training_transforms[transform_name]}) composed_transforms = imed_transforms.Compose(dict_transforms) # Loop across slices for i in indexes: data = [input_data[:, :, i]] # Init metadata metadata = SampleMetadata({MetadataKW.ZOOMS: zooms, MetadataKW.DATA_TYPE: "gt" if is_mask else "im"}) # Apply transformations to ROI if TransformationKW.CENTERCROP in training_transforms or \ (TransformationKW.ROICROP in training_transforms and Path(roi).is_file()): metadata.__setitem__(MetadataKW.CROP_PARAMS, {}) # Apply transformations to image stack_im, _ = composed_transforms(sample=data, metadata=[metadata for _ in range(number)], data_type="im") # Plot before / after transformation fname_out = str(Path(output, stg_transforms + "slice" + str(i) + ".png")) logger.debug(f"Fname out: {fname_out}.") logger.debug(f"\t{dict(metadata)}") # rescale intensities if len(stg_transforms[:-1].split("_")) == 1: before = np.rot90(imed_maths.rescale_values_array(data[0], 0.0, 1.0)) else: before = after if isinstance(stack_im[0], torch.Tensor): after = np.rot90(imed_maths.rescale_values_array(stack_im[0].numpy(), 0.0, 1.0)) else: after = np.rot90(imed_maths.rescale_values_array(stack_im[0], 0.0, 1.0)) # Plot imed_utils.plot_transformed_sample(before, after, list_title=["\n".join(stg_transforms[:-1].split("_")[:-1]), "\n".join(stg_transforms[:-1].split("_"))], fname_out=fname_out, cmap="jet" if is_mask else "gray")
def main(args=None): imed_utils.init_ivadomed() parser = get_parser() args = imed_utils.get_arguments(parser, args) run_visualization(input=args.input, config=args.config, number=int(args.number), output=args.output, roi=args.roi) if __name__ == '__main__': main()