import copy
import functools
import math
import numbers
import random
import numpy as np
import torch
from scipy.ndimage import (affine_transform, binary_closing, binary_dilation, binary_fill_holes,
                           center_of_mass, gaussian_filter, label, map_coordinates, zoom)
from skimage.exposure import equalize_adapthist
from torchvision import transforms as torchvision_transforms
from ivadomed.loader import utils as imed_loader_utils
def multichannel_capable(wrapped):
    """Decorator to make a given function compatible with multichannel images.

    Args:
        wrapped: Given function.

    Returns:
        The wrapped function's return value.
    """
@functools.wraps(wrapped)
def wrapper(self, sample, metadata):
if isinstance(sample, list):
list_data, list_metadata = [], []
for s_cur, m_cur in zip(sample, metadata):
if len(list_metadata) > 0:
if not isinstance(list_metadata[-1], list):
imed_loader_utils.update_metadata([list_metadata[-1]], [m_cur])
else:
imed_loader_utils.update_metadata(list_metadata[-1], [m_cur])
# Run function for each sample of the list
data_cur, metadata_cur = wrapped(self, s_cur, m_cur)
list_data.append(data_cur)
list_metadata.append(metadata_cur)
return list_data, list_metadata
# If sample is None, then return a pair (None, None)
if sample is None:
return None, None
else:
return wrapped(self, sample, metadata)
return wrapper
def two_dim_compatible(wrapped):
    """Decorator to make a given function compatible with 2D or 3D images.

    Args:
        wrapped: Given function.

    Returns:
        The wrapped function's return value.
    """
@functools.wraps(wrapped)
def wrapper(self, sample, metadata):
# Check if sample is 2D
if len(sample.shape) == 2:
# Add one dimension
sample = np.expand_dims(sample, axis=-1)
# Run transform
result_sample, result_metadata = wrapped(self, sample, metadata)
# Remove last dimension
return np.squeeze(result_sample, axis=-1), result_metadata
else:
return wrapped(self, sample, metadata)
return wrapper
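# Illustrative sketch (the Toy transform below is hypothetical, not part of this module):
# stacking both decorators on __call__ lets a transform written for a single 3D array
# transparently accept a list of channels and/or 2D samples:
#
#     class Toy:
#         @multichannel_capable
#         @two_dim_compatible
#         def __call__(self, sample, metadata=None):
#             return sample * 2, metadata
#
#     Toy()([np.ones((4, 4)), np.zeros((4, 4))], [{}, {}])  # two 2D channels in one call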
class Compose(object):
    """Composes transforms together.

    Composes transforms together and splits them between images, GT and ROI.

    self.transform is a dict:
        - keys: "im", "gt" and "roi"
        - values: torchvision_transforms.Compose objects.

    Attributes:
        transform (dict): Keys are "im", "gt" and "roi"; values are the torchvision_transforms.Compose
            objects applied to that data type.

    Args:
        dict_transforms (dict): Dictionary where the keys are the transform names and the values their
            parameters, including an optional "applied_to" list among "im", "gt" and "roi".
        requires_undo (bool): If True, does not include transforms which do not have an undo_transform
            implemented yet.
    """

    def __init__(self, dict_transforms, requires_undo=False):
list_tr_im, list_tr_gt, list_tr_roi = [], [], []
for transform in dict_transforms.keys():
parameters = dict_transforms[transform]
# Get list of data type
if "applied_to" in parameters:
list_applied_to = parameters["applied_to"]
else:
list_applied_to = ["im", "gt", "roi"]
# call transform
if transform in globals():
params_cur = {k: parameters[k] for k in parameters if k != "applied_to" and k != "preprocessing"}
transform_obj = globals()[transform](**params_cur)
else:
raise ValueError('ERROR: {} transform is not available. '
'Please check its compatibility with your model json file.'.format(transform))
# check if undo_transform method is implemented
if requires_undo:
if not hasattr(transform_obj, 'undo_transform'):
print('{} transform not included since no undo_transform available for it.'.format(transform))
continue
if "im" in list_applied_to:
list_tr_im.append(transform_obj)
if "roi" in list_applied_to:
list_tr_roi.append(transform_obj)
if "gt" in list_applied_to:
list_tr_gt.append(transform_obj)
self.transform = {
"im": torchvision_transforms.Compose(list_tr_im),
"gt": torchvision_transforms.Compose(list_tr_gt),
"roi": torchvision_transforms.Compose(list_tr_roi)}
    def __call__(self, sample, metadata, data_type='im'):
if self.transform[data_type] is None or len(metadata) == 0:
# In case self.transform[data_type] is None
return None, None
else:
for tr in self.transform[data_type].transforms:
sample, metadata = tr(sample, metadata)
return sample, metadata
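# Minimal usage sketch (transform names and parameter values are illustrative):
#
#     training_transforms = {
#         "Resample": {"hspace": 0.75, "wspace": 0.75},
#         "RandomAffine": {"degrees": 5, "applied_to": ["im", "gt"]},
#         "NumpyToTensor": {},
#     }
#     composed = Compose(training_transforms)
#     img_tr, metadata_tr = composed(img, metadata, data_type="im")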
class UndoCompose(object):
    """Undo the Compose transformations.

    Call the undo transformations in the reverse order of the "do" transformations.

    Attributes:
        transforms (Compose): The Compose object whose transformations will be undone.

    Args:
        compose (Compose): The Compose object whose transformations will be undone.
    """

    def __init__(self, compose):
        self.transforms = compose

    def __call__(self, sample, metadata, data_type='gt'):
if self.transforms.transform[data_type] is None:
# In case self.transforms.transform[data_type] is None
return None, None
else:
for tr in self.transforms.transform[data_type].transforms[::-1]:
sample, metadata = tr.undo_transform(sample, metadata)
return sample, metadata
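# The transforms below subclass ImedTransform, which is defined in ivadomed's full transforms
# module; a minimal stand-in is included here so this excerpt is self-contained:
class ImedTransform(object):
    """Base class for transformations."""

    def __call__(self, sample, metadata=None):
        raise NotImplementedError("You need to implement the transform() method.")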
class NumpyToTensor(ImedTransform):
    """Converts a numpy array to a torch tensor."""

    def __call__(self, sample, metadata=None):
        """Converts a numpy array to a torch Tensor."""
sample = np.array(sample)
# Use np.ascontiguousarray to avoid axes permutations issues
arr_contig = np.ascontiguousarray(sample, dtype=sample.dtype)
return torch.from_numpy(arr_contig), metadata
class Resample(ImedTransform):
    """Resample image to a given resolution.

    The spline interpolation order is set in __call__: 1 for ground-truth data, 2 otherwise.

    Args:
        hspace (float): Resolution along the first axis, in mm.
        wspace (float): Resolution along the second axis, in mm.
        dspace (float): Resolution along the third axis, in mm.
    """

    def __init__(self, hspace, wspace, dspace=1.):
self.hspace = hspace
self.wspace = wspace
self.dspace = dspace
    @multichannel_capable
@multichannel_capable # for multiple raters during training/preprocessing
@two_dim_compatible
def __call__(self, sample, metadata=None):
"""Resample to a given resolution, in millimeters."""
# Get params
# Voxel dimension in mm
is_2d = sample.shape[-1] == 1
metadata['preresample_shape'] = sample.shape
zooms = list(metadata["zooms"])
if len(zooms) == 2:
zooms += [1.0]
hfactor = zooms[0] / self.hspace
wfactor = zooms[1] / self.wspace
dfactor = zooms[2] / self.dspace
params_resample = (hfactor, wfactor, dfactor) if not is_2d else (hfactor, wfactor, 1.0)
# Run resampling
data_out = zoom(sample,
zoom=params_resample,
order=1 if metadata['data_type'] == 'gt' else 2)
# Data type
data_out = data_out.astype(sample.dtype)
return data_out, metadata
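# Worked example (values are illustrative): with native zooms of (0.5, 0.5, 3.0) mm and
# hspace=wspace=1.0, dspace=3.0, the zoom factors are (0.5/1.0, 0.5/1.0, 3.0/3.0) =
# (0.5, 0.5, 1.0): the first two axes are downsampled by half, the third is left unchanged.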
class NormalizeInstance(ImedTransform):
    """Normalize a tensor or an array image with mean and standard deviation estimated from the sample itself."""

    @multichannel_capable
    def undo_transform(self, sample, metadata=None):
        # Nothing to undo: the normalization is not inverted
        return sample, metadata

    @multichannel_capable
def __call__(self, sample, metadata=None):
data_out = (sample - sample.mean()) / sample.std()
return data_out, metadata
class CroppableArray(np.ndarray):
    """Zero-pads slices that extend past the end of the array.

    Adapted from: https://stackoverflow.com/a/41155020/13306686
    """

    def __getitem__(self, item):
all_in_slices = []
pad = []
for dim in range(self.ndim):
# If the slice has no length then it's a single argument.
# If it's just an integer then we just return, this is
# needed for the representation to work properly
# If it's not then create a list containing None-slices
# for dim>=1 and continue down the loop
try:
len(item)
except TypeError:
if isinstance(item, int):
return super().__getitem__(item)
newitem = [slice(None)] * self.ndim
newitem[0] = item
item = newitem
# We're out of items, just append noop slices
if dim >= len(item):
all_in_slices.append(slice(0, self.shape[dim]))
pad.append((0, 0))
# We're dealing with an integer (no padding even if it's
# out of bounds)
if isinstance(item[dim], int):
all_in_slices.append(slice(item[dim], item[dim] + 1))
pad.append((0, 0))
            # Dealing with a slice; here it gets complicated: we need
            # to correctly deal with None start/stop as well as with
            # out-of-bound values and correct padding
elif isinstance(item[dim], slice):
# Placeholders for values
start, stop = 0, self.shape[dim]
this_pad = [0, 0]
if item[dim].start is None:
start = 0
else:
if item[dim].start < 0:
this_pad[0] = -item[dim].start
start = 0
else:
start = item[dim].start
if item[dim].stop is None:
stop = self.shape[dim]
else:
if item[dim].stop > self.shape[dim]:
this_pad[1] = item[dim].stop - self.shape[dim]
stop = self.shape[dim]
else:
stop = item[dim].stop
all_in_slices.append(slice(start, stop))
pad.append(tuple(this_pad))
# Let numpy deal with slicing
ret = super().__getitem__(tuple(all_in_slices))
# and padding
ret = np.pad(ret, tuple(pad), mode='constant', constant_values=0)
return ret
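# Out-of-bounds slicing demo (illustrative): a CroppableArray view pads with zeros instead of
# truncating the slice:
#
#     arr = np.arange(4).view(CroppableArray)  # array([0, 1, 2, 3])
#     arr[-2:6]  # -> array([0, 0, 0, 1, 2, 3, 0, 0]): two zeros padded on each side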
class Crop(ImedTransform):
    """Crop data.

    Args:
        size (tuple of int): Size of the output sample. Tuple of size 2 if dealing with 2D samples, 3 with 3D samples.

    Attributes:
        size (list of int): Size of the output sample. List of size 3.
    """

    def __init__(self, size):
        # Cast to list so that 2D sizes (tuple or list) can be extended with a third dimension of 0
        self.size = list(size) if len(size) == 3 else list(size) + [0]
@staticmethod
def _adjust_padding(npad, sample):
npad_out_tuple = []
for idx_dim, tuple_pad in enumerate(npad):
pad_start, pad_end = tuple_pad
if pad_start < 0 or pad_end < 0:
# Move axis of interest
sample_reorient = np.swapaxes(sample, 0, idx_dim)
# Adjust pad and crop
if pad_start < 0 and pad_end < 0:
sample_crop = sample_reorient[abs(pad_start):pad_end, ]
pad_end, pad_start = 0, 0
elif pad_start < 0:
sample_crop = sample_reorient[abs(pad_start):, ]
pad_start = 0
else: # i.e. pad_end < 0:
sample_crop = sample_reorient[:pad_end, ]
pad_end = 0
# Reorient
sample = np.swapaxes(sample_crop, 0, idx_dim)
npad_out_tuple.append((pad_start, pad_end))
return npad_out_tuple, sample
    @multichannel_capable
@multichannel_capable # for multiple raters during training/preprocessing
def __call__(self, sample, metadata):
# Get params
is_2d = sample.shape[-1] == 1
th, tw, td = self.size
fh, fw, fd, h, w, d = metadata['crop_params'][self.__class__.__name__]
# Crop data
# Note we use here CroppableArray in order to deal with "out of boundaries" crop
# e.g. if fh is negative or fh+th out of bounds, then it will pad
if is_2d:
data_out = sample.view(CroppableArray)[fh:fh + th, fw:fw + tw, :]
else:
data_out = sample.view(CroppableArray)[fh:fh + th, fw:fw + tw, fd:fd + td]
return data_out, metadata
@multichannel_capable
@two_dim_compatible
def undo_transform(self, sample, metadata=None):
# Get crop params
is_2d = sample.shape[-1] == 1
th, tw, td = self.size
fh, fw, fd, h, w, d = metadata["crop_params"][self.__class__.__name__]
# Compute params to undo transform
pad_left = fw
pad_right = w - pad_left - tw
pad_top = fh
pad_bottom = h - pad_top - th
pad_front = fd if not is_2d else 0
pad_back = d - pad_front - td if not is_2d else 0
npad = [(pad_top, pad_bottom), (pad_left, pad_right), (pad_front, pad_back)]
# Check and adjust npad if needed, i.e. if crop out of boundaries
npad_adj, sample_adj = self._adjust_padding(npad, sample.copy())
# Apply padding
data_out = np.pad(sample_adj,
npad_adj,
mode='constant',
constant_values=0).astype(sample.dtype)
return data_out, metadata
class CenterCrop(Crop):
    """Make a centered crop of a specified size."""

    @multichannel_capable
    @multichannel_capable  # for multiple raters during training/preprocessing
    @two_dim_compatible
    def __call__(self, sample, metadata=None):
# Crop parameters
th, tw, td = self.size
h, w, d = sample.shape
fh = int(round((h - th) / 2.))
fw = int(round((w - tw) / 2.))
fd = int(round((d - td) / 2.))
params = (fh, fw, fd, h, w, d)
metadata['crop_params'][self.__class__.__name__] = params
# Call base method
return super().__call__(sample, metadata)
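# Worked example (illustrative): center-cropping a (200, 200, 1) sample to size (128, 128)
# gives fh = fw = round((200 - 128) / 2) = 36, so the crop spans indices 36..163 along both
# in-plane axes.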
class ROICrop(Crop):
    """Make a crop of a specified size around a Region of Interest (ROI)."""

    @multichannel_capable
    @multichannel_capable  # for multiple raters during training/preprocessing
    @two_dim_compatible
    def __call__(self, sample, metadata=None):
# If crop_params are not in metadata,
# then we are here dealing with ROI data to determine crop params
if self.__class__.__name__ not in metadata['crop_params']:
# Compute center of mass of the ROI
            h_roi, w_roi, d_roi = center_of_mass(sample.astype(int))
h_roi, w_roi, d_roi = int(round(h_roi)), int(round(w_roi)), int(round(d_roi))
th, tw, td = self.size
th_half, tw_half, td_half = int(round(th / 2.)), int(round(tw / 2.)), int(round(td / 2.))
# compute top left corner of the crop area
fh = h_roi - th_half
fw = w_roi - tw_half
fd = d_roi - td_half
# Crop params
h, w, d = sample.shape
params = (fh, fw, fd, h, w, d)
metadata['crop_params'][self.__class__.__name__] = params
# Call base method
return super().__call__(sample, metadata)
class DilateGT(ImedTransform):
    """Randomly dilate a ground-truth tensor.

    .. image:: https://raw.githubusercontent.com/ivadomed/doc-figures/main/technical_features/dilate-gt.png
        :width: 500px
        :align: center

    Args:
        dilation_factor (float): Controls the number of dilation iterations. For each individual lesion, the
            number of dilation iterations is computed as follows:
                nb_it = int(round(dilation_factor * sqrt(lesion_area)))
            If dilation_factor <= 0, then no dilation will be performed.
    """

    def __init__(self, dilation_factor):
self.dil_factor = dilation_factor
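    # Worked example (illustrative): with dilation_factor=0.3 and a lesion of 100 voxels,
    # nb_it = round(0.3 * sqrt(100)) = 3, and the successive dilation rings receive the soft
    # values 3/4, 2/4 and 1/4 (see dilate_lesion / dilate_arr below).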
@staticmethod
def dilate_lesion(arr_bin, arr_soft, label_values):
for lb in label_values:
# binary dilation with 1 iteration
arr_dilated = binary_dilation(arr_bin, iterations=1)
            # isolate new voxels, i.e. the ones from the dilation
            new_voxels = np.logical_xor(arr_dilated, arr_bin).astype(int)
            # assign a soft value (strictly between 0 and 1) to the new voxels
            soft_new_voxels = lb * new_voxels
            # add the new voxels to the input mask
            arr_soft += soft_new_voxels
            arr_bin = (arr_soft > 0).astype(int)
return arr_bin, arr_soft
def dilate_arr(self, arr, dil_factor):
# identify each object
        arr_labeled, lb_nb = label(arr.astype(int))
# loop across each object
arr_bin_lst, arr_soft_lst = [], []
for obj_idx in range(1, lb_nb + 1):
            arr_bin_obj = (arr_labeled == obj_idx).astype(int)
            arr_soft_obj = np.copy(arr_bin_obj).astype(float)
# compute the number of dilation iterations depending on the size of the lesion
nb_it = int(round(dil_factor * math.sqrt(arr_bin_obj.sum())))
# values of the voxels added to the input mask
soft_label_values = [x / (nb_it + 1) for x in range(nb_it, 0, -1)]
# dilate lesion
arr_bin_dil, arr_soft_dil = self.dilate_lesion(arr_bin_obj, arr_soft_obj, soft_label_values)
arr_bin_lst.append(arr_bin_dil)
arr_soft_lst.append(arr_soft_dil)
# sum dilated objects
arr_bin_idx = np.sum(np.array(arr_bin_lst), axis=0)
arr_soft_idx = np.sum(np.array(arr_soft_lst), axis=0)
# clip values in case dilated voxels overlap
arr_bin_clip, arr_soft_clip = np.clip(arr_bin_idx, 0, 1), np.clip(arr_soft_idx, 0.0, 1.0)
        return arr_soft_clip.astype(float), arr_bin_clip.astype(int)
@staticmethod
def random_holes(arr_in, arr_soft, arr_bin):
arr_soft_out = np.copy(arr_soft)
# coordinates of the new voxels, i.e. the ones from the dilation
new_voxels_xx, new_voxels_yy, new_voxels_zz = np.where(np.logical_xor(arr_bin, arr_in))
nb_new_voxels = new_voxels_xx.shape[0]
# ratio of voxels added to the input mask from the dilated mask
new_voxel_ratio = random.random()
# randomly select new voxel indexes to remove
idx_to_remove = random.sample(range(nb_new_voxels),
int(round(nb_new_voxels * (1 - new_voxel_ratio))))
# set to zero the here-above randomly selected new voxels
arr_soft_out[new_voxels_xx[idx_to_remove],
new_voxels_yy[idx_to_remove],
new_voxels_zz[idx_to_remove]] = 0.0
        arr_bin_out = (arr_soft_out > 0).astype(int)
return arr_soft_out, arr_bin_out
@staticmethod
def post_processing(arr_in, arr_soft, arr_bin, arr_dil):
        # remove new objects that are not connected to the input mask
arr_labeled, lb_nb = label(arr_bin)
connected_to_in = arr_labeled * arr_in
for lb in range(1, lb_nb + 1):
if np.sum(connected_to_in == lb) == 0:
arr_soft[arr_labeled == lb] = 0
struct = np.ones((3, 3, 1) if arr_soft.shape[2] == 1 else (3, 3, 3))
# binary closing
        arr_bin_closed = binary_closing((arr_soft > 0).astype(int), structure=struct)
# fill binary holes
arr_bin_filled = binary_fill_holes(arr_bin_closed)
# recover the soft-value assigned to the filled-holes
arr_soft_out = arr_bin_filled * arr_dil
return arr_soft_out
    @multichannel_capable
@two_dim_compatible
def __call__(self, sample, metadata=None):
# binarize for processing
gt_data_np = (sample > 0.5).astype(np.int_)
if self.dil_factor > 0 and np.sum(sample):
# dilation
gt_dil, gt_dil_bin = self.dilate_arr(gt_data_np, self.dil_factor)
# random holes in dilated area
# gt_holes, gt_holes_bin = self.random_holes(gt_data_np, gt_dil, gt_dil_bin)
# post-processing
# gt_pp = self.post_processing(gt_data_np, gt_holes, gt_holes_bin, gt_dil)
# return gt_pp.astype(np.float32), metadata
return gt_dil.astype(np.float32), metadata
else:
return sample, metadata
class BoundingBoxCrop(Crop):
    """Crops image according to a given bounding box."""

    @multichannel_capable
    @two_dim_compatible
    def __call__(self, sample, metadata):
assert 'bounding_box' in metadata
x_min, x_max, y_min, y_max, z_min, z_max = metadata['bounding_box']
x, y, z = sample.shape
metadata['crop_params'][self.__class__.__name__] = (x_min, y_min, z_min, x, y, z)
# Call base method
return super().__call__(sample, metadata)
class RandomAffine(ImedTransform):
    """Apply a random affine transformation.

    Args:
        degrees (float): Positive float or list (or tuple) of length two. Angles in degrees. If only a float is
            provided, then the rotation angle is selected within the range [-degrees, degrees]. Otherwise, the
            list / tuple defines this range.
        translate (list of float): List of floats between 0 and 1, of length 2 or 3 depending on the sample shape
            (2D or 3D). These floats define the maximum range of translation along each axis.
        scale (list of float): List of floats between 0 and 1, of length 2 or 3 depending on the sample shape (2D
            or 3D). These floats define the maximum range of scaling along each axis.

    Attributes:
        degrees (tuple of float): Rotation range, in degrees.
        translate (list of float): Maximum translation range along each axis.
        scale (list of float): Maximum scaling range along each axis.
    """

    def __init__(self, degrees=0, translate=None, scale=None):
# Rotation
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
"degrees should be a list or tuple and it must be of length 2."
self.degrees = degrees
# Scale
if scale is not None:
assert isinstance(scale, (tuple, list)) and (len(scale) == 2 or len(scale) == 3), \
"scale should be a list or tuple and it must be of length 2 or 3."
for s in scale:
if not (0.0 <= s <= 1.0):
raise ValueError("scale values should be between 0 and 1")
            if len(scale) == 2:
                # Cast to list so that tuple inputs can be extended with a third dimension
                scale = list(scale) + [0.0]
            self.scale = scale
else:
self.scale = [0., 0., 0.]
# Translation
        if translate is not None:
            assert isinstance(translate, (tuple, list)) and (len(translate) == 2 or len(translate) == 3), \
                "translate should be a list or tuple and it must be of length 2 or 3."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
            if len(translate) == 2:
                translate = list(translate) + [0.0]
        # Assign outside the if-block so that self.translate exists even when translate is None
        self.translate = translate
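    # Minimal usage sketch (parameter values are illustrative):
    #
    #     tr = RandomAffine(degrees=10, translate=[0.03, 0.03], scale=[0.1, 0.1])
    #     augmented, meta = tr(sample, {})
    #     restored, _ = tr.undo_transform(augmented, meta)  # approximate inverse (interpolation)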
    @multichannel_capable
@two_dim_compatible
def __call__(self, sample, metadata=None):
# Rotation
# If angle and metadata have been already defined for this sample, then use them
if 'rotation' in metadata:
angle, axes = metadata['rotation']
# Otherwise, get random ones
else:
# Get the random angle
angle = math.radians(np.random.uniform(self.degrees[0], self.degrees[1]))
# Get the two axes that define the plane of rotation
axes = list(random.sample(range(3 if sample.shape[2] > 1 else 2), 2))
axes.sort()
# Save params
metadata['rotation'] = [angle, axes]
# Scale
if "scale" in metadata:
scale_x, scale_y, scale_z = metadata['scale']
else:
scale_x = random.uniform(1 - self.scale[0], 1 + self.scale[0])
scale_y = random.uniform(1 - self.scale[1], 1 + self.scale[1])
scale_z = random.uniform(1 - self.scale[2], 1 + self.scale[2])
metadata['scale'] = [scale_x, scale_y, scale_z]
# Get params
if 'translation' in metadata:
translations = metadata['translation']
else:
self.data_shape = sample.shape
if self.translate is not None:
max_dx = self.translate[0] * self.data_shape[0]
max_dy = self.translate[1] * self.data_shape[1]
max_dz = self.translate[2] * self.data_shape[2]
translations = (np.round(np.random.uniform(-max_dx, max_dx)),
np.round(np.random.uniform(-max_dy, max_dy)),
np.round(np.random.uniform(-max_dz, max_dz)))
else:
translations = (0, 0, 0)
metadata['translation'] = translations
# Do rotation
shape = 0.5 * np.array(sample.shape)
if axes == [0, 1]:
rotate = np.array([[math.cos(angle), -math.sin(angle), 0],
[math.sin(angle), math.cos(angle), 0],
[0, 0, 1]])
elif axes == [0, 2]:
rotate = np.array([[math.cos(angle), 0, math.sin(angle)],
[0, 1, 0],
[-math.sin(angle), 0, math.cos(angle)]])
elif axes == [1, 2]:
rotate = np.array([[1, 0, 0],
[0, math.cos(angle), -math.sin(angle)],
[0, math.sin(angle), math.cos(angle)]])
else:
raise ValueError("Unknown axes value")
scale = np.array([[1 / scale_x, 0, 0], [0, 1 / scale_y, 0], [0, 0, 1 / scale_z]])
if "undo" in metadata and metadata["undo"]:
transforms = scale.dot(rotate)
else:
transforms = rotate.dot(scale)
offset = shape - shape.dot(transforms) + translations
data_out = affine_transform(sample, transforms.T, order=1, offset=offset,
output_shape=sample.shape).astype(sample.dtype)
return data_out, metadata
@multichannel_capable
@two_dim_compatible
def undo_transform(self, sample, metadata=None):
assert "rotation" in metadata
assert "scale" in metadata
assert "translation" in metadata
# Opposite rotation, same axes
angle, axes = - metadata['rotation'][0], metadata['rotation'][1]
scale = 1 / np.array(metadata['scale'])
translation = - np.array(metadata['translation'])
# Undo rotation
dict_params = {"rotation": [angle, axes], "scale": scale, "translation": [0, 0, 0], "undo": True}
data_out, _ = self.__call__(sample, dict_params)
data_out = affine_transform(data_out, np.identity(3), order=1, offset=translation,
output_shape=sample.shape).astype(sample.dtype)
return data_out, metadata
class RandomReverse(ImedTransform):
    """Randomly flip the sample along each of its axes."""

    @multichannel_capable
    @two_dim_compatible
    def __call__(self, sample, metadata=None):
if 'reverse' in metadata:
flip_axes = metadata['reverse']
else:
# Flip axis booleans
flip_axes = [np.random.randint(2) == 1 for _ in [0, 1, 2]]
# Save in metadata
metadata['reverse'] = flip_axes
# Run flip
        for idx_axis, flip_bool in enumerate(flip_axes):
            if flip_bool:
                sample = np.flip(sample, axis=idx_axis).copy()
return sample, metadata
@multichannel_capable
@two_dim_compatible
def undo_transform(self, sample, metadata=None):
assert "reverse" in metadata
return self.__call__(sample, metadata)
class RandomShiftIntensity(ImedTransform):
    """Add a random intensity offset.

    Args:
        shift_range (tuple of floats): Tuple of length two. Range from which the applied offset is randomly
            selected.
        prob (float): Between 0 and 1. Probability of occurrence of this transformation.
    """

    def __init__(self, shift_range, prob=0.1):
        self.shift_range = shift_range
        self.prob = prob

    @multichannel_capable
def __call__(self, sample, metadata=None):
if np.random.random() < self.prob:
# Get random offset
offset = np.random.uniform(self.shift_range[0], self.shift_range[1])
else:
offset = 0.0
# Update metadata
metadata['offset'] = offset
# Shift intensity
data = (sample + offset).astype(sample.dtype)
return data, metadata
@multichannel_capable
def undo_transform(self, sample, metadata=None):
assert 'offset' in metadata
# Get offset
offset = metadata['offset']
        # Subtract the offset
data = (sample - offset).astype(sample.dtype)
return data, metadata
class AdditiveGaussianNoise(ImedTransform):
    """Adds Gaussian noise to images.

    Args:
        mean (float): Gaussian noise mean.
        std (float): Gaussian noise standard deviation.
    """

    def __init__(self, mean=0.0, std=0.01):
        self.mean = mean
        self.std = std

    @multichannel_capable
def __call__(self, sample, metadata=None):
if "gaussian_noise" in metadata:
noise = metadata["gaussian_noise"]
else:
# Get random noise
noise = np.random.normal(self.mean, self.std, sample.shape)
noise = noise.astype(np.float32)
# Apply noise
data_out = sample + noise
return data_out.astype(sample.dtype), metadata
class Clahe(ImedTransform):
    """Applies Contrast Limited Adaptive Histogram Equalization for enhancing the local image contrast.

    .. seealso::
        Zuiderveld, Karel. "Contrast limited adaptive histogram equalization." Graphics gems (1994): 474-485.

    Default values are based on:

    .. seealso::
        Zheng, Qiao, et al. "3-D consistent and robust segmentation of cardiac images by deep learning with spatial
        propagation." IEEE transactions on medical imaging 37.9 (2018): 2137-2148.

    Args:
        clip_limit (float): Clipping limit, normalized between 0 and 1.
        kernel_size (tuple of int): Defines the shape of contextual regions used in the algorithm. Length equals
            image dimension (i.e. 2 or 3 for 2D or 3D, respectively).
    """

    def __init__(self, clip_limit=3.0, kernel_size=(8, 8)):
        self.clip_limit = clip_limit
        self.kernel_size = kernel_size

    @multichannel_capable
def __call__(self, sample, metadata=None):
assert len(self.kernel_size) == len(sample.shape)
# Run equalization
data_out = equalize_adapthist(sample,
kernel_size=self.kernel_size,
clip_limit=self.clip_limit).astype(sample.dtype)
return data_out, metadata
class HistogramClipping(ImedTransform):
    """Performs intensity clipping based on percentiles.

    Args:
        min_percentile (float): Between 0 and 100. Lower clipping limit.
        max_percentile (float): Between 0 and 100. Upper clipping limit.
    """

    def __init__(self, min_percentile=5.0, max_percentile=95.0):
        self.min_percentile = min_percentile
        self.max_percentile = max_percentile

    @multichannel_capable
def __call__(self, sample, metadata=None):
data = np.copy(sample)
# Run clipping
percentile1 = np.percentile(sample, self.min_percentile)
percentile2 = np.percentile(sample, self.max_percentile)
data[sample <= percentile1] = percentile1
data[sample >= percentile2] = percentile2
return data, metadata
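# Worked example (using the class defaults): intensities below the 5th percentile are set to
# the 5th-percentile value, and intensities above the 95th percentile to the 95th-percentile
# value:
#
#     clip = HistogramClipping()
#     clipped, _ = clip(np.random.randn(64, 64), {})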