# Source code for ivadomed.losses

import torch
import torch.nn as nn
import scipy
import scipy.ndimage
import numpy as np


class MultiClassDiceLoss(nn.Module):
    """Multi-class Dice Loss.

    Averages the per-class Dice loss over the requested class channels.
    Inspired from https://arxiv.org/pdf/1802.10508.

    Args:
        classes_of_interest (list): List containing the index of a class which its dice will be added to the loss.
            If is None all classes are considered.

    Attributes:
        classes_of_interest (list): List containing the index of a class which its dice will be added to the loss.
            If is None all classes are considered.
        dice_loss (DiceLoss): Class computing the Dice loss.
    """

    def __init__(self, classes_of_interest=None):
        super(MultiClassDiceLoss, self).__init__()
        self.classes_of_interest = classes_of_interest
        self.dice_loss = DiceLoss()

    def forward(self, prediction, target):
        """Compute the mean Dice loss over the selected class channels.

        Args:
            prediction (torch.Tensor): model output, class channel on dim 1.
            target (torch.Tensor): ground truth, same layout as ``prediction``.

        Returns:
            torch.Tensor: mean of the per-class Dice losses (negative, see DiceLoss).
        """
        n_classes = prediction.shape[1]
        # Bug fix: resolve the class list locally instead of assigning it back to
        # self.classes_of_interest. The original mutated instance state in forward(),
        # permanently freezing the class range after the first call — stale if a later
        # batch has a different number of channels.
        classes = self.classes_of_interest
        if classes is None:
            classes = range(n_classes)

        dice_per_class = 0
        for i in classes:
            dice_per_class += self.dice_loss(prediction[:, i, ], target[:, i, ])

        return dice_per_class / len(classes)
class DiceLoss(nn.Module):
    """DiceLoss.

    .. seealso::
        Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural
        networks for volumetric medical image segmentation." 2016 fourth international conference
        on 3D vision (3DV). IEEE, 2016.

    Args:
        smooth (float): Value to avoid division by zero when images and predictions are empty.

    Attributes:
        smooth (float): Value to avoid division by zero when images and predictions are empty.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, prediction, target):
        # Flatten both tensors so the Dice statistics are computed over every voxel.
        pred_flat = prediction.reshape(-1)
        true_flat = target.reshape(-1)

        overlap = torch.sum(pred_flat * true_flat)
        numerator = 2.0 * overlap + self.smooth
        denominator = torch.sum(pred_flat) + torch.sum(true_flat) + self.smooth

        # Negated so that minimizing the loss maximizes the Dice coefficient.
        return -numerator / denominator
class BinaryCrossEntropyLoss(nn.Module):
    """(`BinaryCrossEntropyLoss <https://pytorch.org/docs/master/generated/torch.nn.BCELoss.html#bceloss>`__).

    Thin wrapper around torch's BCELoss that casts the target to float first.

    Attributes:
        loss_fct (BCELoss): Binary cross entropy loss function from torch library.
    """

    def __init__(self):
        super(BinaryCrossEntropyLoss, self).__init__()
        self.loss_fct = nn.BCELoss()

    def forward(self, prediction, target):
        # BCELoss rejects integer targets, so cast explicitly before delegating.
        target_as_float = target.float()
        return self.loss_fct(prediction, target_as_float)
class FocalLoss(nn.Module):
    """FocalLoss.

    .. seealso::
        Lin, Tsung-Yi, et al. "Focal loss for dense object detection."
        Proceedings of the IEEE international conference on computer vision. 2017.

    Args:
        gamma (float): Value from 0 to 5, Control between easy background and hard ROI training examples.
            If set to 0, equivalent to cross-entropy.
        alpha (float): Value from 0 to 1, usually corresponding to the inverse of class frequency to address
            class imbalance.
        eps (float): Epsilon to avoid division by zero.

    Attributes:
        gamma (float): Value from 0 to 5, Control between easy background and hard ROI training examples.
            If set to 0, equivalent to cross-entropy.
        alpha (float): Value from 0 to 1, usually corresponding to the inverse of class frequency to address
            class imbalance.
        eps (float): Epsilon to avoid division by zero.
    """

    def __init__(self, gamma=2, alpha=0.25, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.eps = eps

    def forward(self, input, target):
        # Clamp probabilities away from 0 and 1 so the logs below stay finite.
        probs = input.clamp(self.eps, 1. - self.eps)

        # Log-likelihood of the correct label (eq. 1 of the paper, negated cross-entropy).
        logpt = target * torch.log(probs) + (1 - target) * torch.log(1 - probs)
        pt = torch.exp(logpt)  # eq. 2

        # Class-balancing weight: alpha for foreground, (1 - alpha) for background (eq. 3).
        weight = self.alpha * target + (1 - self.alpha) * (1 - target)
        balanced_cross_entropy = -weight * logpt

        # Focal modulation down-weights easy (high-pt) examples (eq. 5).
        focal = balanced_cross_entropy * (1 - pt) ** self.gamma
        return focal.sum()
class FocalDiceLoss(nn.Module):
    """FocalDiceLoss.

    .. seealso::
        Wong, Ken CL, et al. "3D segmentation with exponential logarithmic loss for highly
        unbalanced object sizes." International Conference on Medical Image Computing and
        Computer-Assisted Intervention. Springer, Cham, 2018.

    Args:
        beta (float): Value from 0 to 1, indicating the weight of the dice loss.
        gamma (float): Value from 0 to 5, Control between easy background and hard ROI training examples.
            If set to 0, equivalent to cross-entropy.
        alpha (float): Value from 0 to 1, usually corresponding to the inverse of class frequency to address
            class imbalance.

    Attributes:
        beta (float): Value from 0 to 1, indicating the weight of the dice loss.
        gamma (float): Value from 0 to 5, Control between easy background and hard ROI training examples.
            If set to 0, equivalent to cross-entropy.
        alpha (float): Value from 0 to 1, usually corresponding to the inverse of class frequency to address
            class imbalance.
    """

    def __init__(self, beta=1, gamma=2, alpha=0.25):
        super().__init__()
        self.beta = beta
        self.focal = FocalLoss(gamma, alpha)
        self.dice = DiceLoss()

    def forward(self, input, target):
        # DiceLoss returns a negative value; negate it so the log below sees a positive loss.
        dc_loss = -self.dice(input, target)
        fc_loss = self.focal(input, target)

        # Exponential-logarithmic combination; clamping keeps log() defined near zero.
        log_fc = torch.log(torch.clamp(fc_loss, 1e-7))
        log_dc = torch.log(torch.clamp(dc_loss, 1e-7))
        return log_fc - self.beta * log_dc
class GeneralizedDiceLoss(nn.Module):
    """GeneralizedDiceLoss.

    .. seealso::
        Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss function for
        highly unbalanced segmentations." Deep learning in medical image analysis and multimodal
        learning for clinical decision support. Springer, Cham, 2017. 240-248.

    Args:
        epsilon (float): Epsilon to avoid division by zero.
        include_background (bool): If True, then an extra channel is added, which represents the background class.

    Attributes:
        epsilon (float): Epsilon to avoid division by zero.
        include_background (bool): If True, then an extra channel is added, which represents the background class.
    """

    def __init__(self, epsilon=1e-5, include_background=True):
        super(GeneralizedDiceLoss, self).__init__()
        self.epsilon = epsilon
        self.include_background = include_background

    def forward(self, input, target):
        """Compute the (negated) generalized Dice score between input and target.

        Args:
            input (torch.Tensor): prediction, shape (batch, class, *spatial).
            target (torch.Tensor): ground truth, same shape as ``input``.

        Returns:
            torch.Tensor: scalar loss (negative generalized Dice score).

        Raises:
            ValueError: if input and target shapes differ.
        """
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

        if self.include_background:
            # Synthesize a background channel that is 1 wherever no foreground class is set.
            size_background = [input.size(0), 1] + list(input.size())[2:]
            # Bug fix: allocate directly on the input's device — the original created CPU
            # tensors and boolean-indexed them with a (possibly CUDA) mask, which fails on GPU.
            input_background = torch.zeros(size_background, dtype=input.dtype, device=input.device)
            target_background = torch.zeros(size_background, dtype=target.dtype, device=target.device)
            # keepdim=True generalizes the original hard-coded 4-D indexing
            # (input.sum(1)[:, None, :, :]) to both 2-D and 3-D data.
            input_background[input.sum(1, keepdim=True) == 0] = 1
            target_background[target.sum(1, keepdim=True) == 0] = 1
            input = torch.cat([input, input_background], dim=1)
            target = torch.cat([target, target_background], dim=1)

        # Per-class weights: inverse squared class volume (Sudre et al., eq. for GDL weights).
        target = target.float()
        axes_to_sum = tuple(range(2, len(target.shape)))
        target_sum = target.sum(axis=axes_to_sum)
        # Bug fix: plain tensor instead of nn.Parameter — creating a Parameter inside
        # forward() registered a spurious trainable leaf (requires_grad=True) on every call.
        class_weights = 1. / (target_sum * target_sum).clamp(min=self.epsilon)

        # Weighted intersection and union over the spatial axes.
        intersect = ((input * target).sum(axis=axes_to_sum) * class_weights).sum()
        denominator = ((input + target).sum(axis=axes_to_sum) * class_weights).sum()

        return - 2. * intersect / denominator.clamp(min=self.epsilon)
class TverskyLoss(nn.Module):
    """Tversky Loss.

    .. seealso::
        Salehi, Seyed Sadegh Mohseni, Deniz Erdogmus, and Ali Gholipour. "Tversky loss function
        for image segmentation using 3D fully convolutional deep networks." International Workshop
        on Machine Learning in Medical Imaging. Springer, Cham, 2017.

    Args:
        alpha (float): Weight of false positive voxels.
        beta (float): Weight of false negative voxels.
        smooth (float): Epsilon to avoid division by zero, when both Numerator and Denominator of Tversky are zeros.

    Attributes:
        alpha (float): Weight of false positive voxels.
        beta (float): Weight of false negative voxels.
        smooth (float): Epsilon to avoid division by zero, when both Numerator and Denominator of Tversky are zeros.

    Notes:
        - setting alpha=beta=0.5: Equivalent to DiceLoss.
        - default parameters were suggested by https://arxiv.org/pdf/1706.05721.pdf .
    """

    def __init__(self, alpha=0.7, beta=0.3, smooth=1.0):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def tversky_index(self, y_pred, y_true):
        """Compute Tversky index.

        Args:
            y_pred (torch Tensor): Prediction.
            y_true (torch Tensor): Target.

        Returns:
            float: Tversky index.
        """
        y_true = y_true.float()
        # Soft confusion-matrix terms over all voxels.
        true_pos = torch.sum(y_true * y_pred)
        false_neg = torch.sum(y_true * (1 - y_pred))
        false_pos = torch.sum((1 - y_true) * y_pred)

        # Equation 3 of the original paper, smoothed against empty masks.
        return (true_pos + self.smooth) / (
            true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth)

    def forward(self, input, target):
        # TODO: Add class_of_interest?
        class_count = input.shape[1]
        index_sum = 0.
        for label in range(class_count):
            index_sum += self.tversky_index(input[:, label, ], target[:, label, ])

        # Negated mean index so the optimizer maximizes overlap.
        return -index_sum / class_count
class FocalTverskyLoss(TverskyLoss):
    """Focal Tversky Loss.

    .. seealso::
        Abraham, Nabila, and Naimul Mefraz Khan. "A novel focal tversky loss function with
        improved attention u-net for lesion segmentation." 2019 IEEE 16th International
        Symposium on Biomedical Imaging (ISBI 2019). IEEE, 2019.

    Args:
        alpha (float): Weight of false positive voxels.
        beta (float): Weight of false negative voxels.
        gamma (float): Typically between 1 and 3. Control between easy background and hard ROI training examples.
        smooth (float): Epsilon to avoid division by zero, when both Numerator and Denominator of Tversky are zeros.

    Attributes:
        gamma (float): Typically between 1 and 3. Control between easy background and hard ROI training examples.

    Notes:
        - setting alpha=beta=0.5 and gamma=1: Equivalent to DiceLoss.
        - default parameters were suggested by https://arxiv.org/pdf/1810.07842.pdf .
    """

    def __init__(self, alpha=0.7, beta=0.3, gamma=1.33, smooth=1.0):
        super(FocalTverskyLoss, self).__init__()
        self.gamma = gamma
        # Delegate the index computation to a dedicated TverskyLoss instance.
        self.tversky = TverskyLoss(alpha=alpha, beta=beta, smooth=smooth)

    def forward(self, input, target):
        # TODO: Add class_of_interest?
        class_count = input.shape[1]
        loss_sum = 0.
        for label in range(class_count):
            index = self.tversky.tversky_index(input[:, label, ], target[:, label, ])
            # Focal Tversky loss, Equation 4 of the original paper.
            loss_sum += torch.pow(1 - index, exponent=1 / self.gamma)

        return loss_sum / class_count
class L2loss(nn.Module):
    """Euclidean loss also known as L2 loss.

    Compute half the sum of the squared differences between the two images.
    """

    def __init__(self):
        super(L2loss, self).__init__()

    def forward(self, input, target):
        squared_error = (input - target) ** 2
        return torch.sum(squared_error) / 2
class AdapWingLoss(nn.Module):
    """Adaptive Wing loss. Used for heatmap ground truth.

    .. seealso::
        Wang, Xinyao, Liefeng Bo, and Li Fuxin. "Adaptive wing loss for robust face alignment via
        heatmap regression." Proceedings of the IEEE International Conference on Computer Vision. 2019.

    Args:
        theta (float): Threshold between linear and non linear loss.
        alpha (float): Used to adapt loss shape to input shape and make loss smooth at 0 (background).
            It needs to be slightly above 2 to maintain ideal properties.
        omega (float): Multiplicating factor for non linear part of the loss.
        epsilon (float): factor to avoid gradient explosion. It must not be too small
    """

    def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1):
        # Plain floats are stored before super().__init__(); nn.Module tolerates this
        # because none of them are Parameters/Modules.
        self.theta = theta
        self.alpha = alpha
        self.omega = omega
        self.epsilon = epsilon
        super(AdapWingLoss, self).__init__()

    def forward(self, input, target):
        """Compute the masked Adaptive Wing loss between a predicted and a target heatmap.

        Args:
            input (torch.Tensor): predicted heatmap, shape (batch, channel, *spatial).
            target (torch.Tensor): ground-truth heatmap, same shape as ``input``.

        Returns:
            torch.Tensor: scalar sum of the weighted per-pixel losses.
        """
        eps = self.epsilon
        # Compute adaptative factor (eq. linking the two loss regimes in the paper)
        A = self.omega * (1 / (1 + torch.pow(self.theta / eps, self.alpha - target))) * \
            (self.alpha - target) * torch.pow(self.theta / eps, self.alpha - target - 1) * (1 / eps)

        # Constant term to link linear and non linear part
        C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target)))

        batch_size = target.size()[0]
        hm_num = target.size()[1]  # NOTE(review): number of heatmap channels; currently unused

        mask = torch.zeros_like(target)
        kernel = scipy.ndimage.generate_binary_structure(2, 2)
        # For 3D segmentation tasks
        if len(input.shape) == 5:
            kernel = scipy.ndimage.generate_binary_structure(3, 2)

        # Build a per-pixel weighting mask from the (scaled) target heatmaps.
        for i in range(batch_size):
            img_list = list()
            img_list.append(np.round(target[i].cpu().numpy() * 255))
            img_merge = np.concatenate(img_list)
            # NOTE(review): despite the name, this applies a morphological *opening*.
            img_dilate = scipy.ndimage.binary_opening(img_merge, np.expand_dims(kernel, axis=0))
            # NOTE(review): binary_opening returns a boolean (0/1) array, so every element
            # is < 51 and the entire mask collapses to 1 — the >= 51 branch below looks
            # unreachable, meaning no pixel ever gets the 1 + omega weight. Confirm whether
            # the thresholds were meant to apply before the morphology (on the 0-255 image).
            img_dilate[img_dilate < 51] = 1  # 0*omega+1
            img_dilate[img_dilate >= 51] = 1 + self.omega  # 1*omega+1
            img_dilate = np.array(img_dilate, dtype=int)

            mask[i] = torch.tensor(img_dilate)

        diff_hm = torch.abs(target - input)
        AWingLoss = A * diff_hm - C
        # Non-linear (log) regime where the residual is below theta.
        idx = diff_hm < self.theta
        AWingLoss[idx] = self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target))[idx]

        AWingLoss *= mask
        sum_loss = torch.sum(AWingLoss)
        all_pixel = torch.sum(mask)  # NOTE(review): computed but unused — see commented mean below
        # NOTE(review): normalization by the mask size is deliberately disabled; the
        # returned "mean_loss" is actually the sum.
        mean_loss = sum_loss  # / all_pixel

        return mean_loss
class LossCombination(nn.Module):
    """Loss that sums other implemented losses.

    Args:
        losses_list (list): list of losses that will be summed. Elements should be string.
        params_list (list): list of params for the losses, contain None or dictionnary
            definition of params for the loss at same index. If no params list is given
            all default parameter will be used.
            (e.g., losses_list = ["L2loss","DiceLoss"] params_list = [None,{"param1:0.5"}])

    returns:
        tensor: sum of losses computed on (input,target) with the params
    """

    def __init__(self, losses_list, params_list=None):
        self.losses_list = losses_list
        self.params_list = params_list
        super(LossCombination, self).__init__()

    def forward(self, input, target):
        terms = []
        params = self.params_list
        for idx, loss_name in enumerate(self.losses_list):
            # NOTE(review): eval() on config-provided strings executes arbitrary code —
            # only use with trusted configuration files.
            loss_cls = eval(loss_name)
            # Fall back to the loss defaults when no params are supplied for this entry.
            kwargs = {} if params is None or params[idx] is None else params[idx]
            terms.append(loss_cls(**kwargs)(input, target).unsqueeze(0))

        return torch.sum(torch.cat(terms))