Source code for ivadomed.transforms

import copy
import functools
import math
import numbers
import random

import numpy as np
import torch
from scipy.ndimage import zoom
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates, affine_transform
from scipy.ndimage.measurements import label, center_of_mass
from scipy.ndimage.morphology import binary_dilation, binary_fill_holes, binary_closing
from skimage.exposure import equalize_adapthist
from torchvision import transforms as torchvision_transforms

from ivadomed.loader import utils as imed_loader_utils


[docs]def multichannel_capable(wrapped): """Decorator to make a given function compatible multichannel images. Args: wrapped: Given function. Returns: Functions' return. """ @functools.wraps(wrapped) def wrapper(self, sample, metadata): if isinstance(sample, list): list_data, list_metadata = [], [] for s_cur, m_cur in zip(sample, metadata): if len(list_metadata) > 0: imed_loader_utils.update_metadata([list_metadata[-1]], [m_cur]) # Run function for each sample of the list data_cur, metadata_cur = wrapped(self, s_cur, m_cur) list_data.append(data_cur) list_metadata.append(metadata_cur) return list_data, list_metadata # If sample is None, then return a pair (None, None) if sample is None: return None, None else: return wrapped(self, sample, metadata) return wrapper
[docs]def two_dim_compatible(wrapped): """Decorator to make a given function compatible 2D or 3D images. Args: wrapped: Given function. Returns: Functions' return. """ @functools.wraps(wrapped) def wrapper(self, sample, metadata): # Check if sample is 2D if len(sample.shape) == 2: # Add one dimension sample = np.expand_dims(sample, axis=-1) # Run transform result_sample, result_metadata = wrapped(self, sample, metadata) # Remove last dimension return np.squeeze(result_sample, axis=-1), result_metadata else: return wrapped(self, sample, metadata) return wrapper
[docs]class ImedTransform(object): """Base class for transforamtions."""
[docs] def __call__(self, sample, metadata=None): raise NotImplementedError("You need to implement the transform() method.")
[docs]class Compose(object): """Composes transforms together. Composes transforms together and split between images, GT and ROI. self.transform is a dict: - keys: "im", "gt" and "roi" - values torchvision_transform.Compose objects. Attributes: dict_transforms (dict): Dictionary where the keys are the transform names and the value their parameters. requires_undo (bool): If True, does not include transforms which do not have an undo_transform implemented yet. Args: transform (dict): Keys are "im", "gt", "roi" and values are torchvision_transforms.Compose of the transformations of interest. """
[docs] def __init__(self, dict_transforms, requires_undo=False): list_tr_im, list_tr_gt, list_tr_roi = [], [], [] for transform in dict_transforms.keys(): parameters = dict_transforms[transform] # Get list of data type if "applied_to" in parameters: list_applied_to = parameters["applied_to"] else: list_applied_to = ["im", "gt", "roi"] # call transform if transform in globals(): params_cur = {k: parameters[k] for k in parameters if k != "applied_to" and k != "preprocessing"} transform_obj = globals()[transform](**params_cur) else: raise ValueError('ERROR: {} transform is not available. ' 'Please check its compatibility with your model json file.'.format(transform)) # check if undo_transform method is implemented if requires_undo: if not hasattr(transform_obj, 'undo_transform'): print('{} transform not included since no undo_transform available for it.'.format(transform)) continue if "im" in list_applied_to: list_tr_im.append(transform_obj) if "roi" in list_applied_to: list_tr_roi.append(transform_obj) if "gt" in list_applied_to: list_tr_gt.append(transform_obj) self.transform = { "im": torchvision_transforms.Compose(list_tr_im), "gt": torchvision_transforms.Compose(list_tr_gt), "roi": torchvision_transforms.Compose(list_tr_roi)}
[docs] def __call__(self, sample, metadata, data_type='im'): if self.transform[data_type] is None or len(metadata) == 0: # In case self.transform[data_type] is None return None, None else: for tr in self.transform[data_type].transforms: sample, metadata = tr(sample, metadata) return sample, metadata
[docs]class UndoCompose(object): """Undo the Compose transformations. Call the undo transformations in the inverse order than the "do transformations". Attributes: compose (torchvision_transforms.Compose): Args: transforms (torchvision_transforms.Compose): """
[docs] def __init__(self, compose): self.transforms = compose
[docs] def __call__(self, sample, metadata, data_type='gt'): if self.transforms.transform[data_type] is None: # In case self.transforms.transform[data_type] is None return None, None else: for tr in self.transforms.transform[data_type].transforms[::-1]: sample, metadata = tr.undo_transform(sample, metadata) return sample, metadata
[docs]class UndoTransform(object): """Call undo transformation. Attributes: transform (ImedTransform): Args: transform (ImedTransform): """
[docs] def __init__(self, transform): self.transform = transform
[docs] def __call__(self, sample): return self.transform.undo_transform(sample)
[docs]class NumpyToTensor(ImedTransform): """Converts nd array to tensor object."""
[docs] def undo_transform(self, sample, metadata=None): """Converts Tensor to nd array.""" return list(sample.numpy()), metadata
[docs] def __call__(self, sample, metadata=None): """Converts nd array to Tensor.""" sample = np.array(sample) # Use np.ascontiguousarray to avoid axes permutations issues arr_contig = np.ascontiguousarray(sample, dtype=sample.dtype) return torch.from_numpy(arr_contig), metadata
[docs]class Resample(ImedTransform): """ Resample image to a given resolution. Args: hspace (float): Resolution along the first axis, in mm. wspace (float): Resolution along the second axis, in mm. dspace (float): Resolution along the third axis, in mm. interpolation_order (int): Order of spline interpolation. Set to 0 for label data. Default=2. """
[docs] def __init__(self, hspace, wspace, dspace=1.): self.hspace = hspace self.wspace = wspace self.dspace = dspace
[docs] @multichannel_capable @two_dim_compatible def undo_transform(self, sample, metadata=None): """Resample to original resolution.""" assert "data_shape" in metadata is_2d = sample.shape[-1] == 1 # Get params original_shape = metadata["preresample_shape"] current_shape = sample.shape params_undo = [x / y for x, y in zip(original_shape, current_shape)] if is_2d: params_undo[-1] = 1.0 # Undo resampling data_out = zoom(sample, zoom=params_undo, order=1 if metadata['data_type'] == 'gt' else 2) # Data type data_out = data_out.astype(sample.dtype) return data_out, metadata
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata=None): """Resample to a given resolution, in millimeters.""" # Get params # Voxel dimension in mm is_2d = sample.shape[-1] == 1 metadata['preresample_shape'] = sample.shape zooms = list(metadata["zooms"]) if len(zooms) == 2: zooms += [1.0] hfactor = zooms[0] / self.hspace wfactor = zooms[1] / self.wspace dfactor = zooms[2] / self.dspace params_resample = (hfactor, wfactor, dfactor) if not is_2d else (hfactor, wfactor, 1.0) # Run resampling data_out = zoom(sample, zoom=params_resample, order=1 if metadata['data_type'] == 'gt' else 2) # Data type data_out = data_out.astype(sample.dtype) return data_out, metadata
[docs]class NormalizeInstance(ImedTransform): """Normalize a tensor or an array image with mean and standard deviation estimated from the sample itself.""" @multichannel_capable def undo_transform(self, sample, metadata=None): # Nothing return sample, metadata
[docs] @multichannel_capable def __call__(self, sample, metadata=None): data_out = (sample - sample.mean()) / sample.std() return data_out, metadata
[docs]class CroppableArray(np.ndarray): """Zero padding slice past end of array in numpy. Adapted From: https://stackoverflow.com/a/41155020/13306686 """
[docs] def __getitem__(self, item): all_in_slices = [] pad = [] for dim in range(self.ndim): # If the slice has no length then it's a single argument. # If it's just an integer then we just return, this is # needed for the representation to work properly # If it's not then create a list containing None-slices # for dim>=1 and continue down the loop try: len(item) except TypeError: if isinstance(item, int): return super().__getitem__(item) newitem = [slice(None)] * self.ndim newitem[0] = item item = newitem # We're out of items, just append noop slices if dim >= len(item): all_in_slices.append(slice(0, self.shape[dim])) pad.append((0, 0)) # We're dealing with an integer (no padding even if it's # out of bounds) if isinstance(item[dim], int): all_in_slices.append(slice(item[dim], item[dim] + 1)) pad.append((0, 0)) # Dealing with a slice, here it get's complicated, we need # to correctly deal with None start/stop as well as with # out-of-bound values and correct padding elif isinstance(item[dim], slice): # Placeholders for values start, stop = 0, self.shape[dim] this_pad = [0, 0] if item[dim].start is None: start = 0 else: if item[dim].start < 0: this_pad[0] = -item[dim].start start = 0 else: start = item[dim].start if item[dim].stop is None: stop = self.shape[dim] else: if item[dim].stop > self.shape[dim]: this_pad[1] = item[dim].stop - self.shape[dim] stop = self.shape[dim] else: stop = item[dim].stop all_in_slices.append(slice(start, stop)) pad.append(tuple(this_pad)) # Let numpy deal with slicing ret = super().__getitem__(tuple(all_in_slices)) # and padding ret = np.pad(ret, tuple(pad), mode='constant', constant_values=0) return ret
[docs]class Crop(ImedTransform): """Crop data. Args: size (tuple of int): Size of the output sample. Tuple of size 2 if dealing with 2D samples, 3 with 3D samples. Attributes: size (tuple of int): Size of the output sample. Tuple of size 3. """
[docs] def __init__(self, size): self.size = size if len(size) == 3 else size + [0]
@staticmethod def _adjust_padding(npad, sample): npad_out_tuple = [] for idx_dim, tuple_pad in enumerate(npad): pad_start, pad_end = tuple_pad if pad_start < 0 or pad_end < 0: # Move axis of interest sample_reorient = np.swapaxes(sample, 0, idx_dim) # Adjust pad and crop if pad_start < 0 and pad_end < 0: sample_crop = sample_reorient[abs(pad_start):pad_end, ] pad_end, pad_start = 0, 0 elif pad_start < 0: sample_crop = sample_reorient[abs(pad_start):, ] pad_start = 0 else: # i.e. pad_end < 0: sample_crop = sample_reorient[:pad_end, ] pad_end = 0 # Reorient sample = np.swapaxes(sample_crop, 0, idx_dim) npad_out_tuple.append((pad_start, pad_end)) return npad_out_tuple, sample
[docs] @multichannel_capable def __call__(self, sample, metadata): # Get params is_2d = sample.shape[-1] == 1 th, tw, td = self.size fh, fw, fd, h, w, d = metadata['crop_params'][self.__class__.__name__] # Crop data # Note we use here CroppableArray in order to deal with "out of boundaries" crop # e.g. if fh is negative or fh+th out of bounds, then it will pad if is_2d: data_out = sample.view(CroppableArray)[fh:fh + th, fw:fw + tw, :] else: data_out = sample.view(CroppableArray)[fh:fh + th, fw:fw + tw, fd:fd + td] return data_out, metadata
@multichannel_capable @two_dim_compatible def undo_transform(self, sample, metadata=None): # Get crop params is_2d = sample.shape[-1] == 1 th, tw, td = self.size fh, fw, fd, h, w, d = metadata["crop_params"][self.__class__.__name__] # Compute params to undo transform pad_left = fw pad_right = w - pad_left - tw pad_top = fh pad_bottom = h - pad_top - th pad_front = fd if not is_2d else 0 pad_back = d - pad_front - td if not is_2d else 0 npad = [(pad_top, pad_bottom), (pad_left, pad_right), (pad_front, pad_back)] # Check and adjust npad if needed, i.e. if crop out of boundaries npad_adj, sample_adj = self._adjust_padding(npad, sample.copy()) # Apply padding data_out = np.pad(sample_adj, npad_adj, mode='constant', constant_values=0).astype(sample.dtype) return data_out, metadata
[docs]class CenterCrop(Crop): """Make a centered crop of a specified size."""
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata=None): # Crop parameters th, tw, td = self.size h, w, d = sample.shape fh = int(round((h - th) / 2.)) fw = int(round((w - tw) / 2.)) fd = int(round((d - td) / 2.)) params = (fh, fw, fd, h, w, d) metadata['crop_params'][self.__class__.__name__] = params # Call base method return super().__call__(sample, metadata)
[docs]class ROICrop(Crop): """Make a crop of a specified size around a Region of Interest (ROI)."""
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata=None): # If crop_params are not in metadata, # then we are here dealing with ROI data to determine crop params if self.__class__.__name__ not in metadata['crop_params']: # Compute center of mass of the ROI h_roi, w_roi, d_roi = center_of_mass(sample.astype(np.int)) h_roi, w_roi, d_roi = int(round(h_roi)), int(round(w_roi)), int(round(d_roi)) th, tw, td = self.size th_half, tw_half, td_half = int(round(th / 2.)), int(round(tw / 2.)), int(round(td / 2.)) # compute top left corner of the crop area fh = h_roi - th_half fw = w_roi - tw_half fd = d_roi - td_half # Crop params h, w, d = sample.shape params = (fh, fw, fd, h, w, d) metadata['crop_params'][self.__class__.__name__] = params # Call base method return super().__call__(sample, metadata)
[docs]class DilateGT(ImedTransform): """Randomly dilate a ground-truth tensor. .. image:: ../../images/dilate-gt.png :width: 600px :align: center Args: dilation_factor (float): Controls the number of dilation iterations. For each individual lesion, the number of dilation iterations is computed as follows: nb_it = int(round(dilation_factor * sqrt(lesion_area))) If dilation_factor <= 0, then no dilation will be performed. """
[docs] def __init__(self, dilation_factor): self.dil_factor = dilation_factor
@staticmethod def dilate_lesion(arr_bin, arr_soft, label_values): for lb in label_values: # binary dilation with 1 iteration arr_dilated = binary_dilation(arr_bin, iterations=1) # isolate new voxels, i.e. the ones from the dilation new_voxels = np.logical_xor(arr_dilated, arr_bin).astype(np.int) # assign a soft value (]0, 1[) to the new voxels soft_new_voxels = lb * new_voxels # add the new voxels to the input mask arr_soft += soft_new_voxels arr_bin = (arr_soft > 0).astype(np.int) return arr_bin, arr_soft def dilate_arr(self, arr, dil_factor): # identify each object arr_labeled, lb_nb = label(arr.astype(np.int)) # loop across each object arr_bin_lst, arr_soft_lst = [], [] for obj_idx in range(1, lb_nb + 1): arr_bin_obj = (arr_labeled == obj_idx).astype(np.int) arr_soft_obj = np.copy(arr_bin_obj).astype(np.float) # compute the number of dilation iterations depending on the size of the lesion nb_it = int(round(dil_factor * math.sqrt(arr_bin_obj.sum()))) # values of the voxels added to the input mask soft_label_values = [x / (nb_it + 1) for x in range(nb_it, 0, -1)] # dilate lesion arr_bin_dil, arr_soft_dil = self.dilate_lesion(arr_bin_obj, arr_soft_obj, soft_label_values) arr_bin_lst.append(arr_bin_dil) arr_soft_lst.append(arr_soft_dil) # sum dilated objects arr_bin_idx = np.sum(np.array(arr_bin_lst), axis=0) arr_soft_idx = np.sum(np.array(arr_soft_lst), axis=0) # clip values in case dilated voxels overlap arr_bin_clip, arr_soft_clip = np.clip(arr_bin_idx, 0, 1), np.clip(arr_soft_idx, 0.0, 1.0) return arr_soft_clip.astype(np.float), arr_bin_clip.astype(np.int) @staticmethod def random_holes(arr_in, arr_soft, arr_bin): arr_soft_out = np.copy(arr_soft) # coordinates of the new voxels, i.e. the ones from the dilation new_voxels_xx, new_voxels_yy, new_voxels_zz = np.where(np.logical_xor(arr_bin, arr_in)) nb_new_voxels = new_voxels_xx.shape[0] # ratio of voxels added to the input mask from the dilated mask new_voxel_ratio = random.random() # randomly select new voxel indexes to remove idx_to_remove = random.sample(range(nb_new_voxels), int(round(nb_new_voxels * (1 - new_voxel_ratio)))) # set to zero the here-above randomly selected new voxels arr_soft_out[new_voxels_xx[idx_to_remove], new_voxels_yy[idx_to_remove], new_voxels_zz[idx_to_remove]] = 0.0 arr_bin_out = (arr_soft_out > 0).astype(np.int) return arr_soft_out, arr_bin_out @staticmethod def post_processing(arr_in, arr_soft, arr_bin, arr_dil): # remove new object that are not connected to the input mask arr_labeled, lb_nb = label(arr_bin) connected_to_in = arr_labeled * arr_in for lb in range(1, lb_nb + 1): if np.sum(connected_to_in == lb) == 0: arr_soft[arr_labeled == lb] = 0 struct = np.ones((3, 3, 1) if arr_soft.shape[2] == 1 else (3, 3, 3)) # binary closing arr_bin_closed = binary_closing((arr_soft > 0).astype(np.int), structure=struct) # fill binary holes arr_bin_filled = binary_fill_holes(arr_bin_closed) # recover the soft-value assigned to the filled-holes arr_soft_out = arr_bin_filled * arr_dil return arr_soft_out
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata=None): # binarize for processing gt_data_np = (sample > 0.5).astype(np.int_) if self.dil_factor > 0 and np.sum(sample): # dilation gt_dil, gt_dil_bin = self.dilate_arr(gt_data_np, self.dil_factor) # random holes in dilated area # gt_holes, gt_holes_bin = self.random_holes(gt_data_np, gt_dil, gt_dil_bin) # post-processing # gt_pp = self.post_processing(gt_data_np, gt_holes, gt_holes_bin, gt_dil) # return gt_pp.astype(np.float32), metadata return gt_dil.astype(np.float32), metadata else: return sample, metadata
[docs]class BoundingBoxCrop(Crop): """Crops image according to given bounding box."""
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata): assert 'bounding_box' in metadata x_min, x_max, y_min, y_max, z_min, z_max = metadata['bounding_box'] x, y, z = sample.shape metadata['crop_params'][self.__class__.__name__] = (x_min, y_min, z_min, x, y, z) # Call base method return super().__call__(sample, metadata)
[docs]class RandomAffine(ImedTransform): """Apply Random Affine transformation. Args: degrees (float): Positive float or list (or tuple) of length two. Angles in degrees. If only a float is provided, then rotation angle is selected within the range [-degrees, degrees]. Otherwise, the list / tuple defines this range. translate (list of float): List of floats between 0 and 1, of length 2 or 3 depending on the sample shape (2D or 3D). These floats defines the maximum range of translation along each axis. scale (list of float): List of floats between 0 and 1, of length 2 or 3 depending on the sample shape (2D or 3D). These floats defines the maximum range of scaling along each axis. Attributes: degrees (tuple of floats): translate (list of float): scale (list of float): """
[docs] def __init__(self, degrees=0, translate=None, scale=None): # Rotation if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError("If degrees is a single number, it must be positive.") self.degrees = (-degrees, degrees) else: assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ "degrees should be a list or tuple and it must be of length 2." self.degrees = degrees # Scale if scale is not None: assert isinstance(scale, (tuple, list)) and (len(scale) == 2 or len(scale) == 3), \ "scale should be a list or tuple and it must be of length 2 or 3." for s in scale: if not (0.0 <= s <= 1.0): raise ValueError("scale values should be between 0 and 1") if len(scale) == 2: scale.append(0.0) self.scale = scale else: self.scale = [0., 0., 0.] # Translation if translate is not None: assert isinstance(translate, (tuple, list)) and (len(translate) == 2 or len(translate) == 3), \ "translate should be a list or tuple and it must be of length 2 or 3." for t in translate: if not (0.0 <= t <= 1.0): raise ValueError("translation values should be between 0 and 1") if len(translate) == 2: translate.append(0.0) self.translate = translate
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata=None): # Rotation # If angle and metadata have been already defined for this sample, then use them if 'rotation' in metadata: angle, axes = metadata['rotation'] # Otherwise, get random ones else: # Get the random angle angle = math.radians(np.random.uniform(self.degrees[0], self.degrees[1])) # Get the two axes that define the plane of rotation axes = list(random.sample(range(3 if sample.shape[2] > 1 else 2), 2)) axes.sort() # Save params metadata['rotation'] = [angle, axes] # Scale if "scale" in metadata: scale_x, scale_y, scale_z = metadata['scale'] else: scale_x = random.uniform(1 - self.scale[0], 1 + self.scale[0]) scale_y = random.uniform(1 - self.scale[1], 1 + self.scale[1]) scale_z = random.uniform(1 - self.scale[2], 1 + self.scale[2]) metadata['scale'] = [scale_x, scale_y, scale_z] # Get params if 'translation' in metadata: translations = metadata['translation'] else: self.data_shape = sample.shape if self.translate is not None: max_dx = self.translate[0] * self.data_shape[0] max_dy = self.translate[1] * self.data_shape[1] max_dz = self.translate[2] * self.data_shape[2] translations = (np.round(np.random.uniform(-max_dx, max_dx)), np.round(np.random.uniform(-max_dy, max_dy)), np.round(np.random.uniform(-max_dz, max_dz))) else: translations = (0, 0, 0) metadata['translation'] = translations # Do rotation shape = 0.5 * np.array(sample.shape) if axes == [0, 1]: rotate = np.array([[math.cos(angle), -math.sin(angle), 0], [math.sin(angle), math.cos(angle), 0], [0, 0, 1]]) elif axes == [0, 2]: rotate = np.array([[math.cos(angle), 0, math.sin(angle)], [0, 1, 0], [-math.sin(angle), 0, math.cos(angle)]]) elif axes == [1, 2]: rotate = np.array([[1, 0, 0], [0, math.cos(angle), -math.sin(angle)], [0, math.sin(angle), math.cos(angle)]]) else: raise ValueError("Unknown axes value") scale = np.array([[1 / scale_x, 0, 0], [0, 1 / scale_y, 0], [0, 0, 1 / scale_z]]) if "undo" in metadata and metadata["undo"]: transforms = scale.dot(rotate) else: transforms = rotate.dot(scale) offset = shape - shape.dot(transforms) + translations data_out = affine_transform(sample, transforms.T, order=1, offset=offset, output_shape=sample.shape).astype(sample.dtype) return data_out, metadata
@multichannel_capable @two_dim_compatible def undo_transform(self, sample, metadata=None): assert "rotation" in metadata assert "scale" in metadata assert "translation" in metadata # Opposite rotation, same axes angle, axes = - metadata['rotation'][0], metadata['rotation'][1] scale = 1 / np.array(metadata['scale']) translation = - np.array(metadata['translation']) # Undo rotation dict_params = {"rotation": [angle, axes], "scale": scale, "translation": [0, 0, 0], "undo": True} data_out, _ = self.__call__(sample, dict_params) data_out = affine_transform(data_out, np.identity(3), order=1, offset=translation, output_shape=sample.shape).astype(sample.dtype) return data_out, metadata
[docs]class RandomReverse(ImedTransform): """Make a randomized symmetric inversion of the different values of each dimensions."""
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata=None): if 'reverse' in metadata: flip_axes = metadata['reverse'] else: # Flip axis booleans flip_axes = [np.random.randint(2) == 1 for _ in [0, 1, 2]] # Save in metadata metadata['reverse'] = flip_axes # Run flip for idx_axis, flip_bool in enumerate(flip_axes): if flip_axes: sample = np.flip(sample, axis=idx_axis).copy() return sample, metadata
@multichannel_capable @two_dim_compatible def undo_transform(self, sample, metadata=None): assert "reverse" in metadata return self.__call__(sample, metadata)
[docs]class RandomShiftIntensity(ImedTransform): """Add a random intensity offset. Args: shift_range (tuple of floats): Tuple of length two. Specifies the range where the offset that is applied is randomly selected from. prob (float): Between 0 and 1. Probability of occurence of this transformation. """
[docs] def __init__(self, shift_range, prob=0.1): self.shift_range = shift_range self.prob = prob
[docs] @multichannel_capable def __call__(self, sample, metadata=None): if np.random.random() < self.prob: # Get random offset offset = np.random.uniform(self.shift_range[0], self.shift_range[1]) else: offset = 0.0 # Update metadata metadata['offset'] = offset # Shift intensity data = (sample + offset).astype(sample.dtype) return data, metadata
@multichannel_capable def undo_transform(self, sample, metadata=None): assert 'offset' in metadata # Get offset offset = metadata['offset'] # Substract offset data = (sample - offset).astype(sample.dtype) return data, metadata
[docs]class ElasticTransform(ImedTransform): """Applies elastic transformation. .. seealso:: Simard, Patrice Y., David Steinkraus, and John C. Platt. "Best practices for convolutional neural networks applied to visual document analysis." Icdar. Vol. 3. No. 2003. 2003. Args: alpha_range (tuple of floats): Deformation coefficient. Length equals 2. sigma_range (tuple of floats): Standard deviation. Length equals 2. """
[docs] def __init__(self, alpha_range, sigma_range, p=0.1): self.alpha_range = alpha_range self.sigma_range = sigma_range self.p = p
[docs] @multichannel_capable @two_dim_compatible def __call__(self, sample, metadata=None): # if params already defined, i.e. sample is GT if "elastic" in metadata: alpha, sigma = metadata["elastic"] elif np.random.random() < self.p: # Get params alpha = np.random.uniform(self.alpha_range[0], self.alpha_range[1]) sigma = np.random.uniform(self.sigma_range[0], self.sigma_range[1]) # Save params metadata["elastic"] = [alpha, sigma] else: metadata["elastic"] = [None, None] if any(metadata["elastic"]): # Get shape shape = sample.shape # Compute random deformation dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha dz = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha if shape[2] == 1: dz = 0 # No deformation along the last dimension x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij') indices = np.reshape(x + dx, (-1, 1)), \ np.reshape(y + dy, (-1, 1)), \ np.reshape(z + dz, (-1, 1)) # Apply deformation data_out = map_coordinates(sample, indices, order=1, mode='reflect') # Keep input shape data_out = data_out.reshape(shape) # Keep data type data_out = data_out.astype(sample.dtype) return data_out, metadata else: return sample, metadata
[docs]class AdditiveGaussianNoise(ImedTransform): """Adds Gaussian Noise to images. Args: mean (float): Gaussian noise mean. std (float): Gaussian noise standard deviation. """
[docs] def __init__(self, mean=0.0, std=0.01): self.mean = mean self.std = std
[docs] @multichannel_capable def __call__(self, sample, metadata=None): if "gaussian_noise" in metadata: noise = metadata["gaussian_noise"] else: # Get random noise noise = np.random.normal(self.mean, self.std, sample.shape) noise = noise.astype(np.float32) # Apply noise data_out = sample + noise return data_out.astype(sample.dtype), metadata
[docs]class Clahe(ImedTransform): """ Applies Contrast Limited Adaptive Histogram Equalization for enhancing the local image contrast. .. seealso:: Zuiderveld, Karel. "Contrast limited adaptive histogram equalization." Graphics gems (1994): 474-485. Default values are based on: .. seealso:: Zheng, Qiao, et al. "3-D consistent and robust segmentation of cardiac images by deep learning with spatial propagation." IEEE transactions on medical imaging 37.9 (2018): 2137-2148. Args: clip_limit (float): Clipping limit, normalized between 0 and 1. kernel_size (tuple of int): Defines the shape of contextual regions used in the algorithm. Length equals image dimension (ie 2 or 3 for 2D or 3D, respectively). """
[docs] def __init__(self, clip_limit=3.0, kernel_size=(8, 8)): self.clip_limit = clip_limit self.kernel_size = kernel_size
[docs] @multichannel_capable def __call__(self, sample, metadata=None): assert len(self.kernel_size) == len(sample.shape) # Run equalization data_out = equalize_adapthist(sample, kernel_size=self.kernel_size, clip_limit=self.clip_limit).astype(sample.dtype) return data_out, metadata
[docs]class HistogramClipping(ImedTransform): """Performs intensity clipping based on percentiles. Args: min_percentile (float): Between 0 and 100. Lower clipping limit. max_percentile (float): Between 0 and 100. Higher clipping limit. """
[docs] def __init__(self, min_percentile=5.0, max_percentile=95.0): self.min_percentile = min_percentile self.max_percentile = max_percentile
[docs] @multichannel_capable def __call__(self, sample, metadata=None): data = np.copy(sample) # Run clipping percentile1 = np.percentile(sample, self.min_percentile) percentile2 = np.percentile(sample, self.max_percentile) data[sample <= percentile1] = percentile1 data[sample >= percentile2] = percentile2 return data, metadata
[docs]def get_subdatasets_transforms(transform_params): """Get transformation parameters for each subdataset: training, validation and testing. Args: transform_params (dict): Returns: dict, dict, dict: Training, Validation and Testing transformations. """ transform_params = copy.deepcopy(transform_params) train, valid, test = {}, {}, {} subdataset_default = ["training", "validation", "testing"] # Loop across transformations for transform_name in transform_params: subdataset_list = ["training", "validation", "testing"] # Only consider subdatasets listed in dataset_type if "dataset_type" in transform_params[transform_name]: subdataset_list = transform_params[transform_name]["dataset_type"] # Add current transformation to the relevant subdataset transformation dictionaries for subds_name, subds_dict in zip(subdataset_default, [train, valid, test]): if subds_name in subdataset_list: subds_dict[transform_name] = transform_params[transform_name] if "dataset_type" in subds_dict[transform_name]: del subds_dict[transform_name]["dataset_type"] return train, valid, test
[docs]def get_preprocessing_transforms(transforms): """Checks the transformations parameters and selects the transformations which are done during preprocessing only. Args: transforms (dict): Transformation dictionary. Returns: dict: Preprocessing transforms. """ original_transforms = copy.deepcopy(transforms) preprocessing_transforms = copy.deepcopy(transforms) for idx, tr in enumerate(original_transforms): if tr == "Resample" or tr == "CenterCrop" or tr == "ROICrop": del transforms[tr] else: del preprocessing_transforms[tr] return preprocessing_transforms
[docs]def apply_preprocessing_transforms(transforms, seg_pair, roi_pair=None): """ Applies preprocessing transforms to segmentation pair (input, gt and metadata). Args: transforms (Compose): Preprocessing transforms. seg_pair (dict): Segmentation pair containing input and gt. roi_pair (dict): Segementation pair containing input and roi. Returns: tuple: Segmentation pair and roi pair. """ if transforms is None: return (seg_pair, roi_pair) metadata_input = seg_pair['input_metadata'] if roi_pair is not None: stack_roi, metadata_roi = transforms(sample=roi_pair["gt"], metadata=roi_pair['gt_metadata'], data_type="roi") metadata_input = imed_loader_utils.update_metadata(metadata_roi, metadata_input) # Run transforms on images stack_input, metadata_input = transforms(sample=seg_pair["input"], metadata=metadata_input, data_type="im") # Run transforms on images metadata_gt = imed_loader_utils.update_metadata(metadata_input, seg_pair['gt_metadata']) stack_gt, metadata_gt = transforms(sample=seg_pair["gt"], metadata=metadata_gt, data_type="gt") seg_pair = { 'input': stack_input, 'gt': stack_gt, 'input_metadata': metadata_input, 'gt_metadata': metadata_gt } if roi_pair is not None and len(roi_pair['gt']): roi_pair = { 'input': stack_input, 'gt': stack_roi, 'input_metadata': metadata_input, 'gt_metadata': metadata_roi } return (seg_pair, roi_pair)
[docs]def prepare_transforms(transform_dict, requires_undo=True): """ This function separates the preprocessing transforms from the others and generates the undo transforms related. Args: transform_dict (dict): Dictionary containing the transforms and there parameters. requires_undo (bool): Boolean indicating if transforms can be undone. Returns: list, UndoCompose: transform lst containing the preprocessing transforms and regular transforms, UndoCompose object containing the transform to undo. """ training_undo_transform = None if requires_undo: training_undo_transform = UndoCompose(Compose(transform_dict.copy())) preprocessing_transforms = get_preprocessing_transforms(transform_dict) prepro_transforms = Compose(preprocessing_transforms, requires_undo=requires_undo) transforms = Compose(transform_dict, requires_undo=requires_undo) tranform_lst = [prepro_transforms if len(preprocessing_transforms) else None, transforms] return tranform_lst, training_undo_transform