import copy
from typing import List, Optional, Union

import torch
from loguru import logger

from ivadomed import transforms as imed_transforms
from ivadomed import utils as imed_utils
from ivadomed.keywords import ROIParamsKW, TransformationKW, ModelParamsKW, ConfigKW
from ivadomed.loader.bids3d_dataset import Bids3DDataset
from ivadomed.loader.bids_dataframe import BidsDataframe
from ivadomed.loader.bids_dataset import BidsDataset
from ivadomed.loader.patch_filter import PatchFilter
from ivadomed.loader.slice_filter import SliceFilter
def load_dataset(bids_df: BidsDataframe,
                 data_list: List[str],
                 transforms_params: dict,
                 model_params: dict,
                 target_suffix: List[str],
                 roi_params: dict,
                 contrast_params: dict,
                 slice_filter_params: dict,
                 patch_filter_params: dict,
                 slice_axis: str,
                 multichannel: bool,
                 dataset_type: str = "training",
                 requires_undo: bool = False,
                 metadata_type: Optional[str] = None,
                 object_detection_params: Optional[dict] = None,
                 soft_gt: bool = False,
                 device: Optional[torch.device] = None,
                 cuda_available: Optional[bool] = None,
                 is_input_dropout: bool = False,
                 **kwargs) -> Union[Bids3DDataset, BidsDataset]:
    """Build the dataset appropriate for the model type and load its filenames.

    A ``Bids3DDataset`` is built for 3D models, i.e. when the model name is
    ``ConfigKW.MODIFIED_3D_UNET`` or when ``ModelParamsKW.IS_2D`` is present in
    ``model_params`` and falsy. Otherwise a 2D ``BidsDataset`` is built. In both cases
    ``load_filenames()`` is called on the dataset before it is returned.

    Args:
        bids_df (BidsDataframe): Object containing dataframe with all BIDS image files and their metadata.
        data_list (list): Subject names list.
        transforms_params (dict): Dictionary containing transformations for "training", "validation", "testing"
            (keys), e.g. output of imed_transforms.get_subdatasets_transforms.
        model_params (dict): Dictionary containing model parameters.
        target_suffix (list of str): List of suffixes for target masks.
        roi_params (dict): Contains ROI related parameters. Modified in place, see Note below.
        contrast_params (dict): Contains image contrasts related parameters.
        slice_filter_params (dict): Contains slice_filter_params, see :doc:`configuration_file` for more details.
            Only used for 2D datasets.
        patch_filter_params (dict): Contains patch_filter_params, see :doc:`configuration_file` for more details.
        slice_axis (str): Choice between "axial", "sagittal", "coronal"; controls the axis used to extract the 2D
            data from 3D NifTI files. 2D PNG/TIF/JPG files use default "axial".
        multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise,
            each contrast is processed individually (i.e. different sample / tensor).
        dataset_type (str): Choice between "training", "validation" or "testing". The patch filter is built with
            ``is_train=False`` only for "testing".
        requires_undo (bool): If True, the transformations without undo_transform will be discarded.
        metadata_type (str): Choice between None, "mri_params", "contrasts".
        object_detection_params (dict): Object detection parameters.
        soft_gt (bool): If True, ground truths are not binarized before being fed to the network. Otherwise, ground
            truths are thresholded (0.5) after the data augmentation operations.
        device (torch.device): Device to use for the model training. Only used by the 2D slice filter.
        cuda_available (bool): If True, cuda is available. Only used by the 2D slice filter.
        is_input_dropout (bool): Return input with missing modalities.
        **kwargs: Accepted and ignored, so that a larger loader-parameters dict can be unpacked into this call.

    Returns:
        Bids3DDataset or BidsDataset: The dataset, with filenames already loaded.

    Note:
        If ``TransformationKW.ROICROP`` is not a key of ``transforms_params``, then
        ``roi_params[ROIParamsKW.SLICE_FILTER_ROI]`` is set to ``None`` **in place**, so the caller's
        ``roi_params`` dict is modified.

        For more details on the parameters transform_params, target_suffix, roi_params, contrast_params,
        slice_filter_params, patch_filter_params and object_detection_params see :doc:`configuration_file`.
    """
    # Compose transforms from a deep copy so the caller's transforms_params is not handed to prepare_transforms.
    transform_lst, _ = imed_transforms.prepare_transforms(copy.deepcopy(transforms_params), requires_undo)

    # If ROICrop is not part of the transforms, then enforce no slice filtering based on ROI data.
    # NOTE: this writes into the caller's roi_params dict (kept for backward compatibility).
    if TransformationKW.ROICROP not in transforms_params:
        roi_params[ROIParamsKW.SLICE_FILTER_ROI] = None

    model_name = model_params[ModelParamsKW.NAME]
    # One flag drives both the dataset class choice and the log message, so they cannot disagree.
    is_3d = (model_name == ConfigKW.MODIFIED_3D_UNET
             or (ModelParamsKW.IS_2D in model_params and not model_params[ModelParamsKW.IS_2D]))
    is_train = dataset_type != "testing"

    if is_3d:
        dataset = Bids3DDataset(bids_df=bids_df,
                                subject_file_lst=data_list,
                                target_suffix=target_suffix,
                                roi_params=roi_params,
                                contrast_params=contrast_params,
                                metadata_choice=metadata_type,
                                slice_axis=imed_utils.AXIS_DCT[slice_axis],
                                transform=transform_lst,
                                multichannel=multichannel,
                                subvolume_filter_fn=PatchFilter(**patch_filter_params, is_train=is_train),
                                model_params=model_params,
                                object_detection_params=object_detection_params,
                                soft_gt=soft_gt,
                                is_input_dropout=is_input_dropout)
    else:
        # Task selection
        task = imed_utils.get_task(model_name)
        dataset = BidsDataset(bids_df=bids_df,
                              subject_file_lst=data_list,
                              target_suffix=target_suffix,
                              roi_params=roi_params,
                              contrast_params=contrast_params,
                              model_params=model_params,
                              metadata_choice=metadata_type,
                              slice_axis=imed_utils.AXIS_DCT[slice_axis],
                              transform=transform_lst,
                              multichannel=multichannel,
                              slice_filter_fn=SliceFilter(**slice_filter_params, device=device,
                                                          cuda_available=cuda_available),
                              patch_filter_fn=PatchFilter(**patch_filter_params, is_train=is_train),
                              soft_gt=soft_gt,
                              object_detection_params=object_detection_params,
                              task=task,
                              is_input_dropout=is_input_dropout)
    dataset.load_filenames()

    if is_3d:
        logger.info(f"Loaded {len(dataset)} volumes of shape {dataset.length} for the {dataset_type} set.")
    elif model_name != ConfigKW.HEMIS_UNET and dataset.length:
        logger.info(f"Loaded {len(dataset)} {slice_axis} patches of shape {dataset.length} for the {dataset_type} set.")
    else:
        logger.info(f"Loaded {len(dataset)} {slice_axis} slices for the {dataset_type} set.")
    return dataset