Source code for ivadomed.models

from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch.nn import init
from pathlib import Path
import torchvision.models

# Modified from torchvision.models.resnet.ResNet
class ResNet(nn.Module):
    """Residual classifier for single-channel 2D images.

    Architecture from "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385). The stem convolution takes one input
    channel, and the forward pass returns softmax probabilities with the first
    (background) class removed.

    Args:
        block (nn.Module): Residual block class used to build each stage. It must
            expose an ``expansion`` attribute.
        layers (int list): Number of residual blocks in each of the four stages.
        num_classes (int): Number of GT classes (size of the final linear layer).
        zero_init_residual (bool): If True, the last normalization layer of every
            torchvision ``Bottleneck``/``BasicBlock`` gets a zero weight so that
            each residual block starts as an identity mapping.
        groups (int): Number of groups forwarded to every block.
        width_per_group (int): Base width forwarded to every block.
        replace_stride_with_dilation (tuple): Three flags (stages 2 to 4) telling
            whether the stride-2 downsampling is replaced by a dilation.
        norm_layer (layer): Normalization layer factory. ``nn.BatchNorm2d`` is
            used when None.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, block, layers, num_classes=2, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        norm_layer = self._norm_layer

        # Running state updated by _make_layer while the stages are stacked.
        self.inplanes = 64
        self.dilation = 1

        if replace_stride_with_dilation is None:
            # Default: keep the stride-2 downsampling in every stage.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.groups = groups
        self.base_width = width_per_group

        # Stem (single input channel).
        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; stages 2-4 downsample (or dilate) by 2.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if zero_init_residual:
            # Zero the last normalization layer of each residual branch so that the
            # branch initially outputs zeros and the block behaves like an identity.
            resnet_blocks = torchvision.models.resnet
            for module in self.modules():
                if isinstance(module, resnet_blocks.Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, resnet_blocks.BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Stack ``blocks`` residual blocks producing ``planes * block.expansion`` channels.

        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate`` is True).
        """
        norm_layer = self._norm_layer
        out_planes = planes * block.expansion
        entry_dilation = self.dilation

        if dilate:
            # Trade the spatial stride for a larger dilation.
            self.dilation *= stride
            stride = 1

        # A projection shortcut is needed whenever the residual and the main
        # branch would otherwise disagree on resolution or channel count.
        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            shortcut = nn.Sequential(
                torchvision.models.resnet.conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut, self.groups,
                       self.base_width, entry_dilation, norm_layer)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, groups=self.groups,
                               base_width=self.base_width, dilation=self.dilation,
                               norm_layer=norm_layer))
        return nn.Sequential(*stage)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        deep = self.layer1(stem)
        deep = self.layer2(deep)
        deep = self.layer3(deep)
        deep = self.layer4(deep)

        pooled = torch.flatten(self.avgpool(deep), 1)
        probabilities = F.softmax(self.fc(pooled), dim=1)

        # Remove background class
        return probabilities[:, 1:]

    def forward(self, x):
        """Return class probabilities for ``x`` with the background class removed."""
        return self._forward_impl(x)
class resnet18(ResNet):
    """ResNet built from torchvision ``BasicBlock`` with two blocks per stage.

    Args:
        **kwargs: Accepted for call-site compatibility and ignored; every other
            ``ResNet`` argument keeps its default value.
    """

    def __init__(self, **kwargs):
        blocks_per_stage = [2, 2, 2, 2]
        super().__init__(torchvision.models.resnet.BasicBlock, blocks_per_stage)
class DenseNet(nn.Module):
    """DenseNet-BC classifier for single-channel 2D images.

    Based on "Densely Connected Convolutional Networks". The first convolution
    takes one input channel and the forward pass returns softmax probabilities
    with the first (background) class removed.

    Args:
        growth_rate (int): How many filters to add each layer (`k` in paper).
        block_config (list of 4 ints): How many layers in each pooling block.
        num_init_features (int): Number of filters to learn in the first convolution layer.
        bn_size (int): Multiplicative factor for number of bottle neck layers
            (i.e. bn_size * k features in the bottleneck layer).
        dropout_rate (float): Dropout rate after each dense layer.
        num_classes (int): Number of classification classes.
        memory_efficient (bool): Forwarded to torchvision's ``_DenseBlock``
            (checkpointing: more memory efficient, but slower). Default: False.
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4,
                 dropout_rate=0.3, num_classes=2, memory_efficient=False):
        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(1, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock, followed by a channel-halving transition except after the last one.
        num_features = num_init_features
        last_block = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            block = torchvision.models.densenet._DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=dropout_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != last_block:
                trans = torchvision.models.densenet._Transition(num_input_features=num_features,
                                                                num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class probabilities for ``x`` with the background class removed."""
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        preds = F.softmax(out, dim=1)
        # Remove background class
        return preds[:, 1:]


class densenet121(DenseNet):
    """DenseNet with growth rate 32, block configuration (6, 12, 24, 16) and 64 initial features.

    Args:
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__(32, (6, 12, 24, 16), 64)


class DownConv(Module):
    """Two successive series of convolution, ReLU, normalization and dropout.
    Used in U-Net's encoder.

    Args:
        in_feat (int): Number of channels in the input image.
        out_feat (int): Number of channels in the output image.
        dropout_rate (float): Probability of dropout.
        bn_momentum (float): Normalization momentum.
        is_2d (bool): Indicates dimensionality of model: True for 2D convolutions, False for 3D convolutions.

    Attributes:
        conv1 (Conv2d or Conv3d): First convolution with kernel size 3 and padding of 1.
        conv1_bn (BatchNorm2d or InstanceNorm3d): First normalization.
        conv1_drop (Dropout2d or Dropout3d): First dropout.
        conv2 (Conv2d or Conv3d): Second convolution with kernel size 3 and padding of 1.
        conv2_bn (BatchNorm2d or InstanceNorm3d): Second normalization.
        conv2_drop (Dropout2d or Dropout3d): Second dropout.
    """

    def __init__(self, in_feat, out_feat, dropout_rate=0.3, bn_momentum=0.1, is_2d=True):
        super(DownConv, self).__init__()
        if is_2d:
            conv = nn.Conv2d
            bn = nn.BatchNorm2d
            dropout = nn.Dropout2d
        else:
            # The 3D variant normalizes per instance rather than per batch.
            conv = nn.Conv3d
            bn = nn.InstanceNorm3d
            dropout = nn.Dropout3d

        self.conv1 = conv(in_feat, out_feat, kernel_size=3, padding=1)
        self.conv1_bn = bn(out_feat, momentum=bn_momentum)
        self.conv1_drop = dropout(dropout_rate)

        self.conv2 = conv(out_feat, out_feat, kernel_size=3, padding=1)
        self.conv2_bn = bn(out_feat, momentum=bn_momentum)
        self.conv2_drop = dropout(dropout_rate)

    def forward(self, x):
        # Order within each series: convolution -> ReLU -> normalization -> dropout.
        x = F.relu(self.conv1(x))
        x = self.conv1_bn(x)
        x = self.conv1_drop(x)

        x = F.relu(self.conv2(x))
        x = self.conv2_bn(x)
        x = self.conv2_drop(x)
        return x


class UpConv(Module):
    """Up convolution: upsample, concatenate with the skip connection, then apply a
    ``DownConv``. Used in U-Net's decoder.

    Args:
        in_feat (int): Number of channels fed to the inner ``DownConv``, i.e. the
            channels of the upsampled input plus those of the skip connection.
        out_feat (int): Number of channels in the output image.
        dropout_rate (float): Probability of dropout.
        bn_momentum (float): Normalization momentum.
        is_2d (bool): Indicates dimensionality of model: True for 2D convolutions, False for 3D convolutions.

    Attributes:
        downconv (DownConv): Convolution block applied after the concatenation.
    """

    def __init__(self, in_feat, out_feat, dropout_rate=0.3, bn_momentum=0.1, is_2d=True):
        super(UpConv, self).__init__()
        self.is_2d = is_2d
        self.downconv = DownConv(in_feat, out_feat, dropout_rate, bn_momentum, is_2d)

    def forward(self, x, y):
        """Upsample ``x`` to the spatial size of ``y``, concatenate both on the
        channel axis and convolve the result."""
        # For retrocompatibility purposes
        if not hasattr(self, "is_2d"):
            self.is_2d = True
        mode = 'bilinear' if self.is_2d else 'trilinear'
        dims = -2 if self.is_2d else -3
        # Resize to the exact skip-connection size so odd input sizes still concatenate.
        x = F.interpolate(x, size=y.size()[dims:], mode=mode, align_corners=True)
        x = torch.cat([x, y], dim=1)
        x = self.downconv(x)
        return x


class Encoder(Module):
    """Encoding part of the U-Net model.
    It returns the features map for the skip connections

    Args:
        in_channel (int): Number of channels in the input image.
        depth (int): Number of down convolutions minus bottom down convolution.
        dropout_rate (float): Probability of dropout.
        bn_momentum (float): Batch normalization momentum.
        n_metadata: FiLM metadata, forwarded unchanged to ``FiLMlayer``.
        film_layers (list): List of 0 or 1 indicating on which layer FiLM is applied.
        is_2d (bool): Indicates dimensionality of model: True for 2D convolutions, False for 3D convolutions.
        n_filters (int): Number of base filters in the U-Net.

    Attributes:
        depth (int): Number of down convolutions minus bottom down convolution.
        down_path (ModuleList): Per level, three consecutive entries: the
            ``DownConv``, its FiLM layer (or None) and the max pooling.
        conv_bottom (DownConv): Bottom down convolution.
        film_bottom (FiLMlayer): FiLM layer applied to bottom convolution.
    """

    def __init__(self, in_channel=1, depth=3, dropout_rate=0.3, bn_momentum=0.1, n_metadata=None, film_layers=None,
                 is_2d=True, n_filters=64):
        super(Encoder, self).__init__()
        self.depth = depth
        self.down_path = nn.ModuleList()
        max_pool = nn.MaxPool2d if is_2d else nn.MaxPool3d

        # None placeholders keep every level at exactly three entries, which
        # forward() relies on for indexing (and keeps parameter names stable).
        # first block
        self.down_path.append(DownConv(in_channel, n_filters, dropout_rate, bn_momentum, is_2d))
        self.down_path.append(FiLMlayer(n_metadata, n_filters) if film_layers and film_layers[0] else None)
        self.down_path.append(max_pool(2))

        for i in range(depth - 1):
            self.down_path.append(DownConv(n_filters, n_filters * 2, dropout_rate, bn_momentum, is_2d))
            self.down_path.append(FiLMlayer(n_metadata, n_filters * 2) if film_layers and film_layers[i + 1] else None)
            self.down_path.append(max_pool(2))
            n_filters = n_filters * 2

        # Bottom
        self.conv_bottom = DownConv(n_filters, n_filters, dropout_rate, bn_momentum, is_2d)
        self.film_bottom = FiLMlayer(n_metadata, n_filters) if film_layers and film_layers[self.depth] else None

    def forward(self, x, context=None):
        """Encode ``x``.

        Returns:
            tuple: ``(features, w_film)`` where ``features`` lists the output of
            every level (before pooling) followed by the bottom output, and
            ``w_film`` is the shared weights returned by the last FiLM layer
            applied, or None when no FiLM layer is active.
        """
        features = []
        # Shared FiLM generator weights, threaded from one FiLM layer to the next.
        w_film = None

        # Down-sampling path
        for level in range(self.depth):
            conv = self.down_path[3 * level]
            film = self.down_path[3 * level + 1]
            pool = self.down_path[3 * level + 2]

            x = conv(x)
            if film is not None:
                x, w_film = film(x, context, w_film)
            features.append(x)
            x = pool(x)

        # Bottom level
        x = self.conv_bottom(x)
        if self.film_bottom is not None:
            x, w_film = self.film_bottom(x, context, w_film)
        features.append(x)

        return features, w_film


class Decoder(Module):
    """Decoding part of the U-Net model.

    Args:
        out_channel (int): Number of channels in the output image.
        depth (int): Number of down convolutions minus bottom down convolution.
        dropout_rate (float): Probability of dropout.
        bn_momentum (float): Batch normalization momentum.
        n_metadata: FiLM metadata, forwarded unchanged to ``FiLMlayer``.
        film_layers (list): List of 0 or 1 indicating on which layer FiLM is applied.
        hemis (bool): Boolean indicating if HeMIS is on or not. When True the
            skip connections are expected to carry twice as many channels.
        final_activation (str): Choice of final activation between "sigmoid", "relu" and "softmax".
        is_2d (bool): Indicates dimensionality of model: True for 2D convolutions, False for 3D convolutions.
        n_filters (int): Number of base filters in the U-Net.

    Attributes:
        depth (int): Number of down convolutions minus bottom down convolution.
        out_channel (int): Number of channels in the output image.
        up_path (ModuleList): Per level, two consecutive entries: the ``UpConv``
            and its FiLM layer (or None).
        last_conv (Conv2d or Conv3d): Last convolution.
        last_film (FiLMlayer): FiLM layer applied to last convolution.
        softmax (Softmax): Softmax layer that can be applied as last layer.
    """

    def __init__(self, out_channel=1, depth=3, dropout_rate=0.3, bn_momentum=0.1,
                 n_metadata=None, film_layers=None, hemis=False, final_activation="sigmoid", is_2d=True, n_filters=64):
        super(Decoder, self).__init__()
        self.depth = depth
        self.out_channel = out_channel
        self.final_activation = final_activation

        def film_or_none(layer_idx, n_channels):
            # None placeholders keep every level at exactly two entries in up_path.
            if film_layers and film_layers[layer_idx]:
                return FiLMlayer(n_metadata, n_channels)
            return None

        # Up-Sampling path
        self.up_path = nn.ModuleList()
        in_channel = n_filters * 2 ** self.depth
        # With HeMIS both the bottom features and the skip connection are doubled
        # (mean and variance), hence twice as many input channels at the first level.
        first_in = in_channel * 2 if hemis else in_channel
        out_feat = n_filters * 2 ** (self.depth - 1)
        self.up_path.append(UpConv(first_in, out_feat, dropout_rate, bn_momentum, is_2d))
        self.up_path.append(film_or_none(self.depth + 1, out_feat))

        for i in range(1, depth):
            in_channel //= 2
            skip_feat = n_filters * 2 ** (self.depth - i - 1 + int(hemis))
            out_feat = n_filters * 2 ** (self.depth - i - 1)
            self.up_path.append(UpConv(in_channel + skip_feat, out_feat, dropout_rate, bn_momentum, is_2d))
            self.up_path.append(film_or_none(self.depth + i + 1, out_feat))

        # Last Convolution
        conv = nn.Conv2d if is_2d else nn.Conv3d
        self.last_conv = conv(in_channel // 2, out_channel, kernel_size=3, padding=1)
        self.last_film = FiLMlayer(n_metadata, self.out_channel) if film_layers and film_layers[-1] else None
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, context=None, w_film=None):
        """Decode the encoder ``features`` (deepest last) into predictions.

        Raises:
            ValueError: If ``final_activation`` is not "softmax", "relu" or "sigmoid".
        """
        x = features[-1]

        for i in reversed(range(self.depth)):
            up_idx = -(i + 1) * 2
            x = self.up_path[up_idx](x, features[i])
            film = self.up_path[up_idx + 1]
            if film is not None:
                x, w_film = film(x, context, w_film)

        # Last convolution
        x = self.last_conv(x)
        if self.last_film is not None:
            x, w_film = self.last_film(x, context, w_film)

        # Instances restored without the attribute fall back to sigmoid.
        activation = getattr(self, "final_activation", "sigmoid")
        if activation not in ["softmax", "relu", "sigmoid"]:
            raise ValueError("final_activation value has to be either softmax, relu, or sigmoid")

        if activation == "softmax":
            preds = self.softmax(x)
        elif activation == "relu":
            rectified = F.relu(x)
            preds = rectified / rectified.max()
            # If the maximum is 0 then every rectified value is 0 as well, so any
            # zero division is 0/0 = nan; those entries are set back to 0.
            preds[torch.isnan(preds)] = 0.
            # If model multiclass
            if self.out_channel > 1:
                class_sum = preds.sum(dim=1).unsqueeze(1)
                # Avoid division by zero
                class_sum[class_sum == 0] = 1
                preds /= class_sum
        else:
            preds = torch.sigmoid(x)

        # If model multiclass
        if self.out_channel > 1:
            # Remove background class
            preds = preds[:, 1:, ]

        return preds
class Unet(Module):
    """A reference U-Net model: one ``Encoder`` followed by one ``Decoder``.

    .. seealso::
        Ronneberger, O., et al (2015). U-Net: Convolutional
        Networks for Biomedical Image Segmentation

    Args:
        in_channel (int): Number of channels in the input image.
        out_channel (int): Number of channels in the output image.
        depth (int): Number of down convolutions minus bottom down convolution.
        dropout_rate (float): Probability of dropout.
        bn_momentum (float): Batch normalization momentum.
        final_activation (str): Choice of final activation between "sigmoid", "relu" and "softmax".
        is_2d (bool): Indicates dimensionality of model: True for 2D convolutions, False for 3D convolutions.
        n_filters (int): Number of base filters in the U-Net.
        **kwargs: Accepted and ignored.

    Attributes:
        encoder (Encoder): U-Net encoder.
        decoder (Decoder): U-net decoder.
    """

    def __init__(self, in_channel=1, out_channel=1, depth=3, dropout_rate=0.3, bn_momentum=0.1,
                 final_activation='sigmoid', is_2d=True, n_filters=64, **kwargs):
        super(Unet, self).__init__()

        # Settings common to both halves of the network.
        common = dict(depth=depth,
                      dropout_rate=dropout_rate,
                      bn_momentum=bn_momentum,
                      is_2d=is_2d,
                      n_filters=n_filters)

        # Encoder path
        self.encoder = Encoder(in_channel=in_channel, **common)

        # Decoder path
        self.decoder = Decoder(out_channel=out_channel, final_activation=final_activation, **common)

    def forward(self, x):
        """Run the encoder, then decode its skip-connection features into predictions."""
        skip_features, _unused_film_weights = self.encoder(x)
        return self.decoder(skip_features)
class FiLMedUnet(Unet):
    """U-Net network containing FiLM layers to condition the model with another data type (i.e. not an image).

    Args:
        in_channel (int): Number of channels in the input image.
        out_channel (int): Number of channels in the output image.
        depth (int): Number of down convolutions minus bottom down convolution.
        dropout_rate (float): Probability of dropout.
        bn_momentum (float): Batch normalization momentum.
        n_metadata: FiLM metadata, forwarded unchanged to the encoder and decoder.
        film_layers (list): List of 0 or 1 indicating on which layer FiLM is applied.
            Must hold ``2 * depth + 2`` entries; an empty or None value disables
            FiLM on every layer.
        is_2d (bool): Indicates dimensionality of model.
        n_filters (int): Number of base filters in the U-Net.
        **kwargs: Accepted and ignored.

    Attributes:
        encoder (Encoder): U-Net encoder.
        decoder (Decoder): U-net decoder.

    Raises:
        ValueError: If ``film_layers`` is non-empty and its length differs from
            ``2 * depth + 2``.
    """

    def __init__(self, in_channel=1, out_channel=1, depth=3, dropout_rate=0.3,
                 bn_momentum=0.1, n_metadata=None, film_layers=None, is_2d=True, n_filters=64, **kwargs):
        # The parent builds a plain encoder/decoder pair first; both are replaced
        # below by FiLM-aware versions.
        super().__init__(in_channel=in_channel, out_channel=out_channel, depth=depth,
                         dropout_rate=dropout_rate, bn_momentum=bn_momentum)

        # Verify if the length of boolean FiLM layers corresponds to the depth
        n_film_slots = 2 * depth + 2
        if not film_layers:
            film_layers = [0] * n_film_slots
        elif len(film_layers) != n_film_slots:
            raise ValueError("The number of FiLM layers {} entered does not correspond to the "
                             "UNet depth. There should 2 * depth + 2 layers.".format(len(film_layers)))

        conditioned = dict(depth=depth,
                           dropout_rate=dropout_rate,
                           bn_momentum=bn_momentum,
                           n_metadata=n_metadata,
                           film_layers=film_layers,
                           is_2d=is_2d,
                           n_filters=n_filters)

        # Encoder path
        self.encoder = Encoder(in_channel=in_channel, **conditioned)
        # Decoder path
        self.decoder = Decoder(out_channel=out_channel, **conditioned)

    def forward(self, x, context=None):
        """Encode ``x`` conditioned on ``context``, then decode with the same
        context and the FiLM weights returned by the encoder."""
        skip_features, shared_film_weights = self.encoder(x, context)
        return self.decoder(skip_features, context, shared_film_weights)
class FiLMgenerator(Module):
    """The FiLM generator processes the conditioning information
    and produces parameters that describe how the target network should alter its computation.

    Here, the FiLM generator is a multi-layer perceptron.

    Args:
        n_features (int): Number of input channels.
        n_channels (int): Number of output channels.
        n_hid (int): Number of hidden units in layer.

    Attributes:
        linear1 (Linear): Input linear layer.
        sig (Sigmoid): Sigmoid function.
        linear2 (Linear): Hidden linear layer.
        linear3 (Linear): Output linear layer, ``2 * n_channels`` wide.
    """

    def __init__(self, n_features, n_channels, n_hid=64):
        super(FiLMgenerator, self).__init__()
        self.linear1 = nn.Linear(n_features, n_hid)
        self.sig = nn.Sigmoid()
        self.linear2 = nn.Linear(n_hid, n_hid // 4)
        self.linear3 = nn.Linear(n_hid // 4, n_channels * 2)

    def forward(self, x, shared_weights=None):
        """Return the FiLM parameters for ``x`` and the weights to share downstream.

        Args:
            x (Tensor): Conditioning input.
            shared_weights (list): Optional ``[linear1.weight, linear2.weight]``
                of a previous generator; when given, this generator adopts them.

        Returns:
            tuple: ``(out, [linear1.weight, linear2.weight])``.
        """
        if shared_weights is not None:  # weight sharing
            self.linear1.weight = shared_weights[0]
            self.linear2.weight = shared_weights[1]

        x = self.sig(self.linear1(x))
        x = self.sig(self.linear2(x))
        out = self.sig(self.linear3(x))
        return out, [self.linear1.weight, self.linear2.weight]


class FiLMlayer(Module):
    """Applies Feature-wise Linear Modulation to the incoming data

    .. seealso::
        Perez, Ethan, et al. "Film: Visual reasoning with a general conditioning layer."
        Thirty-Second AAAI Conference on Artificial Intelligence. 2018.

    Args:
        n_metadata (int): Size of the conditioning vector; it is used as the
            input width of the generator's first linear layer.
        n_channels (int): Number of output channels.

    Attributes:
        batch_size (int): Batch size (never assigned by this class).
        height (int): Size of the first spatial axis of the last input.
        width (int): Size of the second spatial axis of the last input.
        depth (int): Size of the third spatial axis of the last 5D input.
        feature_size (int): Number of features in data.
        generator (FiLMgenerator): FiLM network.
        gammas (Tensor): Multiplicative term of the FiLM linear modulation.
        betas (Tensor): Additive term of the FiLM linear modulation.
    """

    def __init__(self, n_metadata, n_channels):
        super(FiLMlayer, self).__init__()

        self.batch_size = None
        self.height = None
        self.width = None
        self.depth = None
        self.feature_size = None
        self.generator = FiLMgenerator(n_metadata, n_channels)
        # Add the parameters gammas and betas to access them out of the class.
        self.gammas = None
        self.betas = None

    def forward(self, feature_maps, context, w_shared):
        """Modulate ``feature_maps`` with parameters generated from ``context``.

        Args:
            feature_maps (Tensor): 4D (2D data) or 5D (3D data) tensor with the
                channels on axis 1.
            context: Conditioning metadata, converted to a tensor here.
            w_shared (list): Generator weights shared by a previous FiLM layer, or None.

        Returns:
            tuple: ``(output, new_w_shared)``.

        Raises:
            ValueError: If ``feature_maps`` is neither 4D nor 5D.
        """
        data_shape = feature_maps.shape
        if len(data_shape) == 4:
            _, self.feature_size, self.height, self.width = data_shape
        elif len(data_shape) == 5:
            _, self.feature_size, self.height, self.width, self.depth = data_shape
        else:
            raise ValueError("Data should be either 2D (tensor length: 4) or 3D (tensor length: 5), found shape: {}".format(data_shape))

        # Keep the conditioning vector on the same device as the feature maps, so
        # a CPU model still works on a machine where CUDA happens to be available.
        context = torch.as_tensor(context, dtype=torch.get_default_dtype(), device=feature_maps.device)

        # Estimate the FiLM parameters using a FiLM generator from the conditioning metadata
        film_params, new_w_shared = self.generator(context, w_shared)

        # FiLM applies a different affine transformation to each channel,
        # consistent across spatial locations
        if len(data_shape) == 4:
            film_params = film_params.unsqueeze(-1).unsqueeze(-1)
            film_params = film_params.repeat(1, 1, self.height, self.width)
        else:
            film_params = film_params.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            film_params = film_params.repeat(1, 1, self.height, self.width, self.depth)

        # First half of the generated channels scales, second half shifts.
        self.gammas = film_params[:, :self.feature_size, ]
        self.betas = film_params[:, self.feature_size:, ]

        # Apply the linear modulation
        output = self.gammas * feature_maps + self.betas

        return output, new_w_shared
class HeMISUnet(Module):
    """A U-Net model inspired by HeMIS to deal with missing contrasts.
        1) It has as many encoders as contrasts but only one decoder.
        2) Skip connections are the concatenations of the means and var of all encoders skip connections.

    Param:
    contrasts: list of all the possible contrasts. ['T1', 'T2', 'T2S', 'F']

    .. seealso::
        Havaei, M., Guizard, N., Chapados, N., Bengio, Y.:
        Hemis: Hetero-modal image segmentation.

        Reuben Dorent and Samuel Joutard and Marc Modat and Sébastien Ourselin and Tom Vercauteren
        Hetero-Modal Variational Encoder-Decoder for Joint Modality Completion and Segmentation

    Args:
        contrasts (list): List of contrasts.
        out_channel (int): Number of output channels.
        depth (int): Number of down convolutions minus bottom down convolution.
        dropout_rate (float): Probability of dropout.
        bn_momentum (float): Batch normalization momentum.
        **kwargs: Accepted and ignored.

    Attributes:
        depth (int): Number of down convolutions minus bottom down convolution.
        contrasts (list): List of contrasts.
        Encoder_mod (ModuleDict): Contains encoder for each modality, keyed
            ``'Encoder_<contrast>'``.
        decoder (Decoder): U-Net decoder.
    """

    def __init__(self, contrasts, out_channel=1, depth=3, dropout_rate=0.3, bn_momentum=0.1, **kwargs):
        super(HeMISUnet, self).__init__()
        self.depth = depth
        self.contrasts = contrasts

        # Encoder path: one single-channel encoder per contrast.
        self.Encoder_mod = nn.ModuleDict(
            [['Encoder_{}'.format(Mod), Encoder(in_channel=1, depth=depth, dropout_rate=dropout_rate,
                                                bn_momentum=bn_momentum)] for Mod in self.contrasts])

        # Decoder path
        self.decoder = Decoder(out_channel=out_channel, depth=depth, dropout_rate=dropout_rate,
                               bn_momentum=bn_momentum, hemis=True)

    def forward(self, x_mods, indexes_mod):
        """
        X is list like X = [x_T1, x_T2, x_T2S, x_F]
        indexes_mod: one selector per batch sample; selector ``i`` picks, along
        the contrast axis, the contrasts available for sample ``i``.
        NOTE(review): the original docstring described the transposed layout
        (one array per contrast) -- confirm against the caller.
        """
        features_mod = [[] for _ in range(self.depth + 1)]

        # Down-sampling: encode every contrast independently.
        for i, Mod in enumerate(self.contrasts):
            features, _ = self.Encoder_mod['Encoder_{}'.format(Mod)](x_mods[i])

            for j in range(self.depth + 1):
                # New leading axis so the contrasts can be concatenated on it.
                features_mod[j].append(features[j].unsqueeze(0))

        # Abstraction: per level, fuse the contrasts into their mean and variance.
        for j in range(self.depth + 1):
            # Contrast-first stack, transposed so that axis 0 indexes the samples.
            features_cat = torch.cat(features_mod[j], 0).transpose(0, 1)

            fused = []
            for i, selector in enumerate(indexes_mod):
                available = features_cat[i][selector].squeeze(1)
                stats = torch.cat([available.mean(0), available.var(0)], 0)
                fused.append(stats.unsqueeze(0))
            features_mod[j] = torch.cat(fused, 0)

        # Up-sampling
        preds = self.decoder(features_mod)

        return preds
[docs] class Modified3DUNet(nn.Module): """Code from the following repository: The main differences with the original UNet resides in the use of LeakyReLU instead of ReLU, InstanceNormalisation instead of BatchNorm due to small batch size in 3D and the addition of segmentation layers in the decoder. If attention=True, attention gates are added in the decoder to help focus attention on important features for a given task. Code related to the attentions gates is inspired from: Args: in_channel (int): Number of channels in the input image. out_channel (int): Number of channels in the output image. n_filters (int): Number of base filters in the U-Net. attention (bool): Boolean indicating whether the attention module is on or not. dropout_rate (float): Probability of dropout. bn_momentum (float): Batch normalization momentum. final_activation (str): Choice of final activation between "sigmoid", "relu" and "softmax". **kwargs: Attributes: in_channels (int): Number of channels in the input image. n_classes (int): Number of channels in the output image. base_n_filter (int): Number of base filters in the U-Net. attention (bool): Boolean indicating whether the attention module is on or not. momentum (float): Batch normalization momentum. final_activation (str): Choice of final activation between "sigmoid", "relu" and "softmax". Note: All layers are defined as attributes and used in the forward method. """
def __init__(self, in_channel, out_channel, n_filters=16, attention=False, dropout_rate=0.3, bn_momentum=0.1,
             final_activation="sigmoid", n_metadata=None, film_layers=None, **kwargs):
    """Create every layer of the network as an attribute.

    NOTE(review): the block structure below (extent of the ``if``/``else``
    bodies) was reconstructed from a whitespace-collapsed listing -- confirm
    against the canonical source.

    Args:
        in_channel (int): Number of channels in the input image.
        out_channel (int): Number of channels in the output image.
        n_filters (int): Number of base filters; deeper levels use multiples of it.
        attention (bool): If True, the gating signal and attention blocks are created.
        dropout_rate (float): Probability used by the shared ``Dropout3d``.
        bn_momentum (float): Momentum given to every ``InstanceNorm3d``.
        final_activation (str): Stored on the instance as ``final_activation``.
        n_metadata: Forwarded unchanged to every ``FiLMlayer``.
        film_layers (list): Flags read at indices 0 to 9; a truthy flag creates the
            matching ``film_layer<N>``.
        **kwargs: Accepted and ignored.
    """
    super(Modified3DUNet, self).__init__()
    self.in_channels = in_channel
    self.n_classes = out_channel
    self.base_n_filter = n_filters
    self.attention = attention
    self.momentum = bn_momentum
    self.final_activation = final_activation

    # Parameter-free layers.
    self.lrelu = nn.LeakyReLU()
    self.dropout3d = nn.Dropout3d(p=dropout_rate)
    self.upsacle = nn.Upsample(scale_factor=2, mode='nearest')
    self.softmax = nn.Softmax(dim=1)

    # Level 1 context pathway
    self.conv3d_c1_1 = nn.Conv3d(
        self.in_channels, self.base_n_filter,
        kernel_size=3, stride=1, padding=1, bias=False
    )
    self.film_layer1 = FiLMlayer(n_metadata, self.base_n_filter) if film_layers and film_layers[0] else None

    self.conv3d_c1_2 = nn.Conv3d(
        self.base_n_filter, self.base_n_filter,
        kernel_size=3, stride=1, padding=1, bias=False
    )
    self.lrelu_conv_c1 = self.lrelu_conv(
        self.base_n_filter, self.base_n_filter)
    self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter, momentum=self.momentum)

    # Level 2 context pathway (the stride-2 convolution performs the downsampling)
    self.conv3d_c2 = nn.Conv3d(
        self.base_n_filter, self.base_n_filter * 2,
        kernel_size=3, stride=2, padding=1, bias=False
    )
    self.film_layer2 = FiLMlayer(n_metadata, self.base_n_filter * 2) if film_layers and film_layers[1] else None

    self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(
        self.base_n_filter * 2, self.base_n_filter * 2)
    self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2, momentum=self.momentum)

    # Level 3 context pathway
    self.conv3d_c3 = nn.Conv3d(
        self.base_n_filter * 2, self.base_n_filter * 4,
        kernel_size=3, stride=2, padding=1, bias=False
    )
    self.film_layer3 = FiLMlayer(n_metadata, self.base_n_filter * 4) if film_layers and film_layers[2] else None

    self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(
        self.base_n_filter * 4, self.base_n_filter * 4)
    self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4, momentum=self.momentum)

    # Level 4 context pathway
    self.conv3d_c4 = nn.Conv3d(
        self.base_n_filter * 4, self.base_n_filter * 8,
        kernel_size=3, stride=2, padding=1, bias=False
    )
    self.film_layer4 = FiLMlayer(n_metadata, self.base_n_filter * 8) if film_layers and film_layers[3] else None

    self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(
        self.base_n_filter * 8, self.base_n_filter * 8)
    self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8, momentum=self.momentum)

    # Level 5 context pathway, level 0 localization pathway
    self.conv3d_c5 = nn.Conv3d(
        self.base_n_filter * 8, self.base_n_filter * 16,
        kernel_size=3, stride=2, padding=1, bias=False
    )
    self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(
        self.base_n_filter * 16, self.base_n_filter * 16)
    # With FiLM at this level the fused norm/upscale block is split in two so a
    # FiLM layer can sit in between. Only one of the two attribute sets exists
    # on a given instance (``film_layer5`` is not even set to None otherwise).
    if film_layers and film_layers[4]:
        self.norm_lrelu_0 = self.norm_lrelu(self.base_n_filter * 16)
        self.film_layer5 = FiLMlayer(n_metadata, self.base_n_filter * 16)
        self.upscale_conv_norm_lrelu_0 = self.upscale_conv_norm_lrelu(self.base_n_filter * 16,
                                                                      self.base_n_filter * 8)
    else:
        self.norm_lrelu_upscale_conv_norm_lrelu_l0 = \
            self.norm_lrelu_upscale_conv_norm_lrelu(
                self.base_n_filter * 16, self.base_n_filter * 8)

    self.conv3d_l0 = nn.Conv3d(
        self.base_n_filter * 8, self.base_n_filter * 8,
        kernel_size=1, stride=1, padding=0, bias=False
    )
    self.film_layer6 = FiLMlayer(n_metadata, self.base_n_filter * 8) if film_layers and film_layers[5] else None

    self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 8, momentum=self.momentum)

    # Attention UNet
    # UnetGridGatingSignal3 and GridAttentionBlockND are not defined in this
    # part of the file -- presumably declared further down; verify.
    if self.attention:
        self.gating = UnetGridGatingSignal3(self.base_n_filter * 16, self.base_n_filter * 8, kernel_size=(1, 1, 1),
                                            is_batchnorm=True)

        # attention blocks
        self.attentionblock2 = GridAttentionBlockND(in_channels=self.base_n_filter * 2,
                                                    gating_channels=self.base_n_filter * 8,
                                                    inter_channels=self.base_n_filter * 2,
                                                    sub_sample_factor=(2, 2, 2),
                                                    )
        self.attentionblock3 = GridAttentionBlockND(in_channels=self.base_n_filter * 4,
                                                    gating_channels=self.base_n_filter * 8,
                                                    inter_channels=self.base_n_filter * 4,
                                                    sub_sample_factor=(2, 2, 2),
                                                    )
        self.attentionblock4 = GridAttentionBlockND(in_channels=self.base_n_filter * 8,
                                                    gating_channels=self.base_n_filter * 8,
                                                    inter_channels=self.base_n_filter * 8,
                                                    sub_sample_factor=(2, 2, 2),
                                                    )
        # Overwrites the 8x-filter ``inorm3d_l0`` assigned just above.
        self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 16, momentum=self.momentum)

    # Level 1 localization pathway
    self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(
        self.base_n_filter * 16, self.base_n_filter * 16)
    self.conv3d_l1 = nn.Conv3d(
        self.base_n_filter * 16, self.base_n_filter * 8,
        kernel_size=1, stride=1, padding=0, bias=False
    )
    self.film_layer7 = FiLMlayer(n_metadata, self.base_n_filter * 4) if film_layers and film_layers[6] else None
    self.norm_lrelu_upscale_conv_norm_lrelu_l1 = \
        self.norm_lrelu_upscale_conv_norm_lrelu(
            self.base_n_filter * 8, self.base_n_filter * 4)

    # Level 2 localization pathway
    self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(
        self.base_n_filter * 8, self.base_n_filter * 8)
    self.conv3d_l2 = nn.Conv3d(
        self.base_n_filter * 8, self.base_n_filter * 4,
        kernel_size=1, stride=1, padding=0, bias=False
    )
    self.film_layer8 = FiLMlayer(n_metadata, self.base_n_filter * 2) if film_layers and film_layers[7] else None
    self.norm_lrelu_upscale_conv_norm_lrelu_l2 = \
        self.norm_lrelu_upscale_conv_norm_lrelu(
            self.base_n_filter * 4, self.base_n_filter * 2)

    # Level 3 localization pathway
    self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(
        self.base_n_filter * 4, self.base_n_filter * 4)
    self.conv3d_l3 = nn.Conv3d(
        self.base_n_filter * 4, self.base_n_filter * 2,
        kernel_size=1, stride=1, padding=0, bias=False
    )
    self.film_layer9 = FiLMlayer(n_metadata, self.base_n_filter) if film_layers and film_layers[8] else None
    self.norm_lrelu_upscale_conv_norm_lrelu_l3 = \
        self.norm_lrelu_upscale_conv_norm_lrelu(
            self.base_n_filter * 2, self.base_n_filter)

    # Level 4 localization pathway
    self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(
        self.base_n_filter * 2, self.base_n_filter * 2)
    self.conv3d_l4 = nn.Conv3d(
        self.base_n_filter * 2, self.n_classes,
        kernel_size=1, stride=1, padding=0, bias=False
    )
    # self.film_layer10 = FiLMlayer(n_metadata, ) if film_layers and film_layers[9] else None

    # 1x1x1 convolutions mapping intermediate decoder features (8x and 4x the
    # base filters) to ``n_classes`` channels.
    self.ds2_1x1_conv3d = nn.Conv3d(
        self.base_n_filter * 8, self.n_classes,
        kernel_size=1, stride=1, padding=0, bias=False
    )
    self.ds3_1x1_conv3d = nn.Conv3d(
        self.base_n_filter * 4, self.n_classes,
        kernel_size=1, stride=1, padding=0, bias=False
    )
    self.film_layer10 = FiLMlayer(n_metadata, self.n_classes) if film_layers and film_layers[9] else None
def conv_norm_lrelu(self, feat_in, feat_out):
    """Build a ``Conv3d -> InstanceNorm3d -> LeakyReLU`` stack.

    Args:
        feat_in (int): Number of input channels of the convolution.
        feat_out (int): Number of output channels of the convolution.

    Returns:
        nn.Sequential: The three layers, in the order listed above.
    """
    stack = [
        nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
        nn.InstanceNorm3d(feat_out, momentum=self.momentum),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*stack)


def norm_lrelu_conv(self, feat_in, feat_out):
    """Build an ``InstanceNorm3d -> LeakyReLU -> Conv3d`` stack.

    Args:
        feat_in (int): Number of channels that are normalised and fed to the convolution.
        feat_out (int): Number of output channels of the convolution.

    Returns:
        nn.Sequential: The three layers, in the order listed above.
    """
    stack = [
        nn.InstanceNorm3d(feat_in, momentum=self.momentum),
        nn.LeakyReLU(),
        nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
    ]
    return nn.Sequential(*stack)


def lrelu_conv(self, feat_in, feat_out):
    """Build a ``LeakyReLU -> Conv3d`` stack (no normalisation).

    Args:
        feat_in (int): Number of input channels of the convolution.
        feat_out (int): Number of output channels of the convolution.

    Returns:
        nn.Sequential: The two layers, in the order listed above.
    """
    stack = [
        nn.LeakyReLU(),
        nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
    ]
    return nn.Sequential(*stack)


def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
    """Build ``InstanceNorm3d -> LeakyReLU -> Upsample(x2) -> Conv3d -> InstanceNorm3d -> LeakyReLU``.

    Args:
        feat_in (int): Number of channels entering the stack.
        feat_out (int): Number of channels leaving the stack.

    Returns:
        nn.Sequential: The six layers, in the order listed above.
    """
    stack = [
        nn.InstanceNorm3d(feat_in, momentum=self.momentum),
        nn.LeakyReLU(),
        nn.Upsample(scale_factor=2, mode='nearest'),
        # NOTE(review): the original author left an open question here on whether the
        # convolution input should be feat_in*2 or feat_in; feat_in is what is used.
        nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
        nn.InstanceNorm3d(feat_out, momentum=self.momentum),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*stack)


def norm_lrelu(self, feat_in):
    """Build an ``InstanceNorm3d -> LeakyReLU`` stack.

    Args:
        feat_in (int): Number of channels to normalise.

    Returns:
        nn.Sequential: The two layers, in the order listed above.
    """
    stack = [
        nn.InstanceNorm3d(feat_in, momentum=self.momentum),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*stack)


def upscale_conv_norm_lrelu(self, feat_in, feat_out):
    """Build ``Upsample(x2) -> Conv3d -> InstanceNorm3d -> LeakyReLU``.

    Args:
        feat_in (int): Number of channels entering the stack.
        feat_out (int): Number of channels leaving the stack.

    Returns:
        nn.Sequential: The four layers, in the order listed above.
    """
    stack = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        # NOTE(review): the original author left an open question here on whether the
        # convolution input should be feat_in*2 or feat_in; feat_in is what is used.
        nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
        nn.InstanceNorm3d(feat_out, momentum=self.momentum),
        nn.LeakyReLU(),
    ]
    return nn.Sequential(*stack)
def forward(self, x, context=None, w_film=None):
    """Run the context pathway, the localization pathway and the deep-supervision head.

    The encoder produces four skip tensors (``context_1`` .. ``context_4``).
    The decoder concatenates them back along the channel axis (``dim=1``),
    and two intermediate decoder outputs (``ds2``, ``ds3``) are projected,
    upscaled and summed with the final prediction before the activation.

    Args:
        x (Tensor): Input batch. Presumably ``(batch, channel, D, H, W)`` since the
            layers are 3D ones -- TODO confirm against the constructor.
        context: Passed unchanged to every enabled FiLM layer.
        w_film: Passed to each enabled FiLM layer and replaced by the value it returns.

    Returns:
        Tensor: Activated prediction. When ``self.n_classes > 1`` the first channel
        (background class) is removed.

    Raises:
        ValueError: If ``self.final_activation`` exists and is not one of
            ``"softmax"``, ``"relu"`` or ``"sigmoid"``.
    """

    def film(layer_name, tensor, weights):
        # A FiLM layer is applied only when the attribute exists and is truthy
        # (the constructor stores ``None`` for disabled layers).
        layer = getattr(self, layer_name, None)
        if layer:
            return layer(tensor, context, weights)
        return tensor, weights

    # Level 1 context pathway
    out = self.conv3d_c1_1(x)
    residual_1 = out
    out = self.lrelu(out)
    out = self.conv3d_c1_2(out)
    out = self.dropout3d(out)
    out = self.lrelu_conv_c1(out)
    # Element Wise Summation
    out += residual_1
    out = self.lrelu(out)
    # Unlike the deeper levels, the level-1 skip tensor is taken BEFORE normalisation.
    context_1 = out
    out = self.inorm3d_c1(out)
    out = self.lrelu(out)
    out, w_film = film('film_layer1', out, w_film)

    # Level 2 context pathway
    out = self.conv3d_c2(out)
    residual_2 = out
    out = self.norm_lrelu_conv_c2(out)
    out = self.dropout3d(out)
    out = self.norm_lrelu_conv_c2(out)
    out += residual_2
    out = self.inorm3d_c2(out)
    out = self.lrelu(out)
    out, w_film = film('film_layer2', out, w_film)
    context_2 = out

    # Level 3 context pathway
    out = self.conv3d_c3(out)
    residual_3 = out
    out = self.norm_lrelu_conv_c3(out)
    out = self.dropout3d(out)
    out = self.norm_lrelu_conv_c3(out)
    out += residual_3
    out = self.inorm3d_c3(out)
    out = self.lrelu(out)
    out, w_film = film('film_layer3', out, w_film)
    context_3 = out

    # Level 4 context pathway
    out = self.conv3d_c4(out)
    residual_4 = out
    out = self.norm_lrelu_conv_c4(out)
    out = self.dropout3d(out)
    out = self.norm_lrelu_conv_c4(out)
    out += residual_4
    out = self.inorm3d_c4(out)
    out = self.lrelu(out)
    out, w_film = film('film_layer4', out, w_film)
    context_4 = out

    # Level 5
    out = self.conv3d_c5(out)
    residual_5 = out
    out = self.norm_lrelu_conv_c5(out)
    out = self.dropout3d(out)
    out = self.norm_lrelu_conv_c5(out)
    out += residual_5

    if self.attention:
        out = self.inorm3d_l0(out)
        out = self.lrelu(out)
        gating = self.gating(out)
        # Each attention block also returns its attention map, which is not used here.
        context_4, _ = self.attentionblock4(context_4, gating)
        context_3, _ = self.attentionblock3(context_3, gating)
        context_2, _ = self.attentionblock2(context_2, gating)

    if getattr(self, 'film_layer5', None):
        # With FiLM at the bottleneck, the fused norm/upscale stack is split so
        # that the modulation happens between normalisation and upscaling.
        out = self.norm_lrelu_0(out)
        out, w_film = self.film_layer5(out, context, w_film)
        out = self.upscale_conv_norm_lrelu_0(out)
    else:
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)

    out = self.conv3d_l0(out)
    out = self.inorm3d_l0(out)
    out = self.lrelu(out)
    out, w_film = film('film_layer6', out, w_film)

    # Level 1 localization pathway
    out = torch.cat([out, context_4], dim=1)
    out = self.conv_norm_lrelu_l1(out)
    out = self.conv3d_l1(out)
    out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)
    out, w_film = film('film_layer7', out, w_film)

    # Level 2 localization pathway
    out = torch.cat([out, context_3], dim=1)
    out = self.conv_norm_lrelu_l2(out)
    ds2 = out
    out = self.conv3d_l2(out)
    out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)
    out, w_film = film('film_layer8', out, w_film)

    # Level 3 localization pathway
    out = torch.cat([out, context_2], dim=1)
    out = self.conv_norm_lrelu_l3(out)
    ds3 = out
    out = self.conv3d_l3(out)
    out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)
    out, w_film = film('film_layer9', out, w_film)

    # Level 4 localization pathway (note the operand order: skip tensor first)
    out = torch.cat([context_1, out], dim=1)
    out = self.conv_norm_lrelu_l4(out)
    out_pred = self.conv3d_l4(out)

    # Deep supervision: project ds2, upscale, add projected ds3, upscale again,
    # then add to the full-resolution prediction.
    # ``upsacle`` (sic) is the attribute name used by this class.
    ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
    ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv)
    ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
    ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
    ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum)

    out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale
    out, w_film = film('film_layer10', out, w_film)

    # A missing ``final_activation`` attribute falls back to sigmoid.
    final_activation = getattr(self, "final_activation", "sigmoid")
    if final_activation not in ["softmax", "relu", "sigmoid"]:
        raise ValueError("final_activation value has to be either softmax, relu, or sigmoid")
    if final_activation == "softmax":
        out = self.softmax(out)
    elif final_activation == "relu":
        # ReLU rescaled by its maximum; left unscaled when that maximum is zero.
        rectified = torch.relu(out)
        peak = rectified.max()
        out = rectified / peak if bool(peak) else rectified
    else:
        out = torch.sigmoid(out)

    if self.n_classes > 1:
        # Remove background class
        out = out[:, 1:, ]

    return out
class UNet3D(Modified3DUNet):
    """To ensure retrocompatibility, when calling UNet3D (old model name), Modified3DUNet will be called.
    see Modified3DUNet to learn more about parameters.
    """

    def __init__(self, in_channel, out_channel, n_filters=16, attention=False, dropout_rate=0.3, bn_momentum=0.1,
                 final_activation="sigmoid", n_metadata=None, film_layers=None, **kwargs):
        # The arguments must reach the parent constructor. Calling it empty and then
        # building a separate, discarded Modified3DUNet leaves this instance unconfigured.
        super(UNet3D, self).__init__(in_channel=in_channel, out_channel=out_channel, n_filters=n_filters,
                                     attention=attention, dropout_rate=dropout_rate, bn_momentum=bn_momentum,
                                     final_activation=final_activation, n_metadata=n_metadata,
                                     film_layers=film_layers, **kwargs)


class GridAttentionBlockND(nn.Module):
    """Attention module to focus on important features passed through U-Net's decoder
    Specific to Attention UNet

    .. seealso::
        Oktay, Ozan, et al. "Attention u-net: Learning where to look for the pancreas."
        arXiv preprint arXiv:1804.03999 (2018).

    Args:
        in_channels (int): Number of channels in the input image.
        gating_channels (int): Number of channels in the gating step.
        inter_channels (int): Number of channels in the intermediate gating step.
            Defaults to ``in_channels // 2`` (at least 1) when None.
        dimension (int): Value of 2 or 3 to indicating whether it is used in a 2D or 3D model.
        sub_sample_factor (tuple, list or int): Kernel size AND stride of the ``theta``
            convolution. A scalar is repeated ``dimension`` times.

    Attributes:
        in_channels (int): Number of channels in the input image.
        gating_channels (int): Number of channels in the gating step.
        inter_channels (int): Number of channels in the intermediate gating step.
        dimension (int): Value of 2 or 3 to indicating whether it is used in a 2D or 3D model.
        sub_sample_factor (tuple): Stride of the ``theta`` convolution.
        sub_sample_kernel_size (tuple): Kernel size of the ``theta`` convolution (same value).
        upsample_mode (str): 'bilinear' or 'trilinear' related to the use of 2D or 3D models.
        W (Sequential): Sequence of convolution and batch normalization layers.
        theta (Conv2d or Conv3d): Convolution layer for gating operation.
        phi (Conv2d or Conv3d): Convolution layer for gating operation.
        psi (Conv2d or Conv3d): Convolution layer for gating operation.
    """

    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3,
                 sub_sample_factor=(2, 2, 2)):
        super(GridAttentionBlockND, self).__init__()

        assert dimension in [2, 3]

        # Downsampling rate for the input featuremap
        if isinstance(sub_sample_factor, tuple):
            self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list):
            self.sub_sample_factor = tuple(sub_sample_factor)
        else:
            self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            # Still reachable when assertions are disabled (python -O).
            raise NotImplementedError

        # Output transform
        self.W = nn.Sequential(
            conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0,
                             bias=False)
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0,
                           bias=True)

        # Initialise weights (``apply`` recurses into nested containers such as ``W``)
        for m in self.children():
            m.apply(weights_init_kaiming)

        # Define the operation
        self.operation_function = self._concatenation

    def forward(self, x, g):
        """Gate the feature map ``x`` with the gating signal ``g``.

        Returns:
            tuple: ``(W_y, attention)`` -- the transformed gated features and the
            attention map upsampled to the spatial size of ``x``.
        """
        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
        # phi   => (b, g_d) -> (b, i_c)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
        #  Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3)
        phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode, align_corners=True)
        f = F.relu(theta_x + phi_g, inplace=True)

        #  psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
        sigm_psi_f = torch.sigmoid(self.psi(f))

        # upsample the attentions and multiply
        sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode, align_corners=True)
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)

        return W_y, sigm_psi_f


def weights_init_kaiming(m):
    """Initialise one module in place, selected by its class name.

    * name contains ``Conv`` or ``Linear``: Kaiming-normal weights (``a=0``, ``fan_in``);
    * name contains ``BatchNorm``: weights ~ N(1.0, 0.02), bias set to 0.

    Modules without a ``weight`` tensor are left untouched. This covers containers,
    wrappers whose class name merely contains ``Conv`` (e.g. ``ConvBlock``) and
    non-affine normalisation layers. Intended for ``module.apply``.

    Args:
        m (nn.Module): Module to initialise.
    """
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.kaiming_normal_(weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            init.constant_(bias.data, 0.0)


class UnetGridGatingSignal3(nn.Module):
    """Operation to extract important features for a specific task using 1x1x1 convolution (Gating) which is used in the
    attention blocks.

    Args:
        in_size (int): Number of channels in the input image.
        out_size (int): Number of channels in the output image.
        kernel_size (tuple): Convolution kernel size.
        is_batchnorm (bool): Boolean indicating whether to apply batch normalization or not.

    Attributes:
        conv1 (Sequential): 3D convolution, optional batch normalization and ReLU activation.
    """

    def __init__(self, in_size, out_size, kernel_size=(1, 1, 1), is_batchnorm=True):
        super(UnetGridGatingSignal3, self).__init__()

        layers = [nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1), (0, 0, 0))]
        if is_batchnorm:
            layers.append(nn.BatchNorm3d(out_size))
        layers.append(nn.ReLU(inplace=True))
        self.conv1 = nn.Sequential(*layers)

        # initialise the blocks. ``apply`` is required: the only direct child is the
        # Sequential container, so calling the initialiser on the children alone would
        # never reach the convolution / batch-norm layers inside it.
        for m in self.children():
            m.apply(weights_init_kaiming)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        return outputs


class ConvBlock(nn.Module):
    """Perform convolution, batch normalization and activation, in that order.

    Args:
        in_chan (int): number of channels on input
        out_chan (int): number of channel on output
        ksize (int): size of kernel for the 2d convolution
        stride (int): strides for 2d convolution
        pad (int): pad for nn.conv2d
        activation (nn.layers): activation layer. default Leaky ReLu
    """

    def __init__(self, in_chan, out_chan, ksize=3, stride=1, pad=0, activation=nn.LeakyReLU()):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_chan, out_chan, kernel_size=ksize, stride=stride, padding=pad)
        self.activation = activation
        self.batch_norm = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        return self.activation(self.batch_norm(self.conv1(x)))


class SimpleBlock(nn.Module):
    """Inception module with 3 convolutions with different kernel size with separate activation.
    The 3 outputs are then concatenated along the channel axis and passed through
    a max pooling layer of kernel size 1.

    The three branches use (kernel, padding) = (3, 0), (5, 1) and (9, 3), so each one
    shrinks every spatial dimension by 2 and their outputs can be concatenated.

    Args:
        in_chan (int): number of channel of input
        out_chan_1x1 (int): number of channel after first convolution block
        out_chan_3x3 (int): number of channel for the other convolution blocks
        activation (nn.layers): activation layer used in convolution block
    """

    def __init__(self, in_chan, out_chan_1x1, out_chan_3x3, activation=nn.LeakyReLU()):
        super(SimpleBlock, self).__init__()
        self.conv1 = ConvBlock(in_chan, out_chan_1x1, ksize=3, pad=0, activation=activation)
        self.conv2 = ConvBlock(in_chan, out_chan_3x3, ksize=5, pad=1, activation=activation)
        self.conv3 = ConvBlock(in_chan, out_chan_3x3, ksize=9, pad=3, activation=activation)
        self.MP = nn.MaxPool2d(1)

    def forward(self, x):
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(x)
        conv3_out = self.conv3(x)
        # Output has out_chan_1x1 + 2 * out_chan_3x3 channels.
        output = torch.cat([conv1_out, conv2_out, conv3_out], 1)
        output = self.MP(output)
        return output
class Countception(Module):
    """Countception model.
    Fully convolutional model using inception module and used for keypoints detection.
    The inception model extracts several patches within each image. Every pixel is therefore processed by the
    network several times, allowing to average multiple predictions and minimize false negatives.

    .. seealso::
        Cohen JP et al. "Count-ception: Counting by fully convolutional redundant counting."
        Proceedings of the IEEE International Conference on Computer Vision Workshops. 2017.

    Args:
        in_channel (int): number of channels on input image
        out_channel (int): number of channels on output image
        use_logits (bool): if True, the final layer is a list of ``out_channel`` heads with
            ``logits_per_output`` channels each and ``forward`` returns a list of tensors;
            otherwise a single head with ``out_channel`` channels is used.
        logits_per_output (int): number of outputs of final convolution which will multiplied by the number
            of channels
        name (str): model's name used for call in configuration file.
    """

    def __init__(self, in_channel=3, out_channel=1, use_logits=False, logits_per_output=12, name='CC', **kwargs):
        super(Countception, self).__init__()

        # params
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.activation = nn.LeakyReLU(0.01)
        self.final_activation = nn.LeakyReLU(0.3)
        self.patch_size = 40
        self.use_logits = use_logits
        self.logits_per_output = logits_per_output

        # Each SimpleBlock(in, a, b) outputs a + 2 * b channels, which fixes the
        # input width of the layer that follows it.
        self.conv1 = ConvBlock(self.in_channel, 64, ksize=3, pad=self.patch_size, activation=self.activation)
        self.simple1 = SimpleBlock(64, 16, 16, activation=self.activation)
        self.simple2 = SimpleBlock(48, 16, 32, activation=self.activation)
        self.conv2 = ConvBlock(80, 16, ksize=14, activation=self.activation)
        self.simple3 = SimpleBlock(16, 112, 48, activation=self.activation)
        self.simple4 = SimpleBlock(208, 64, 32, activation=self.activation)
        self.simple5 = SimpleBlock(128, 40, 40, activation=self.activation)
        self.simple6 = SimpleBlock(120, 32, 96, activation=self.activation)
        self.conv3 = ConvBlock(224, 32, ksize=20, activation=self.activation)
        self.conv4 = ConvBlock(32, 64, ksize=10, activation=self.activation)
        self.conv5 = ConvBlock(64, 32, ksize=9, activation=self.activation)

        if use_logits:
            # conv5 produces 32 channels, so every logits head must accept 32 inputs
            # (it was previously declared with 64, which cannot consume conv5's output).
            self.conv6 = nn.ModuleList([
                ConvBlock(32, logits_per_output, ksize=1, activation=self.final_activation)
                for _ in range(out_channel)
            ])
        else:
            self.conv6 = ConvBlock(32, self.out_channel, ksize=20, pad=1, activation=self.final_activation)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                init.xavier_uniform_(m.weight, gain=init.calculate_gain('leaky_relu', param=0.01))
            elif isinstance(m, nn.BatchNorm2d):
                # Unit scale / zero shift, i.e. batch norm starts as a pure normalisation.
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply the network to ``x``.

        Returns:
            Tensor, or list of Tensor when ``use_logits`` is True (one per output head).
        """
        net = self.conv1(x)
        net = self.simple1(net)
        net = self.simple2(net)
        net = self.conv2(net)
        net = self.simple3(net)
        net = self.simple4(net)
        net = self.simple5(net)
        net = self.simple6(net)
        net = self.conv3(net)
        net = self.conv4(net)
        net = self.conv5(net)

        if self.use_logits:
            net = [head(net) for head in self.conv6]
        else:
            net = self.conv6(net)

        return net
def set_model_for_retrain(model_path, retrain_fraction, map_location, reset=True):
    """Set model for transfer learning.

    The first layers (defined by 1-retrain_fraction) are frozen (i.e. requires_grad=False).
    The weights of the last layers (defined by retrain_fraction) are reset unless reset option is False.

    "Layers" are the sub-modules exposing ``reset_parameters``, in ``named_modules`` order.
    Parameters that come before the first retrained layer are frozen; the ``requires_grad``
    flag of the remaining parameters is left as loaded.

    Args:
        model_path (str): Pretrained model path. The file is unpickled as a full model
            object, so it must come from a trusted source.
        retrain_fraction (float): Fraction of the model that will be retrained, between 0 and 1. If set to 0.3,
            then the 30% last fraction of the model will be re-initalised and retrained. With 0 every
            parameter is frozen and nothing is reset.
        map_location (str): Device.
        reset (bool): if the un-frozen weight should be reset or kept as loaded.

    Returns:
        torch.Module: Model ready for retrain. A model without any resettable layer is
        returned unchanged.

    Raises:
        ValueError: If ``retrain_fraction`` is outside ``[0, 1]``.
    """
    if not 0 <= retrain_fraction <= 1:
        raise ValueError(f"retrain_fraction must be between 0 and 1, got {retrain_fraction}")

    # Load pretrained model.
    # SECURITY: this unpickles arbitrary Python objects; only load trusted checkpoints.
    try:
        # Recent PyTorch releases default to weights-only loading, which rejects
        # pickled model objects, so full unpickling is requested explicitly.
        model = torch.load(model_path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the ``weights_only`` keyword.
        model = torch.load(model_path, map_location=map_location)

    # Get number of layers with learnt parameters
    layer_names = [name for name, layer in model.named_modules() if hasattr(layer, 'reset_parameters')]
    n_layers = len(layer_names)
    if n_layers == 0:
        # Nothing to partition into frozen / retrained parts.
        return model

    # Compute the number of these layers we want to freeze
    n_freeze = int(round(n_layers * (1 - retrain_fraction)))

    if n_freeze >= n_layers:
        # Everything is frozen: there is no "first retrained layer" to look up
        # (indexing layer_names[n_freeze] would raise IndexError) and nothing to reset.
        for param in model.parameters():
            param.requires_grad = False
        return model

    # First layer that stays trainable. Its parameters are named "<layer>.<param>";
    # matching on the dotted prefix avoids confusing e.g. "conv1" with "conv10".
    first_retrained_layer = layer_names[n_freeze]
    prefix = first_retrained_layer + '.' if first_retrained_layer else ''

    # Set freeze first layers
    for name, param in model.named_parameters():
        if name.startswith(prefix):
            break
        param.requires_grad = False

    # Reset weights of the last layers
    if reset:
        retrained_layers = set(layer_names[n_freeze:])
        for name, layer in model.named_modules():
            if name in retrained_layers:
                layer.reset_parameters()

    return model


def get_model_filenames(folder_model):
    """Get trained model filenames from its folder path.

    This function checks if the folder_model exists and get trained model path (.pt or .onnx based on
    model and GPU availability) and its configuration file (.json) from it. All three files are
    expected to be named after the folder itself.

    Args:
        folder_model (str): Path of the model folder.

    Returns:
        str, str: Paths of the model (.pt or .onnx) and its configuration file (.json).

    Raises:
        FileNotFoundError: If the folder does not exist, if it contains neither model file,
            or if the configuration file is missing.
    """
    folder = Path(folder_model)
    if not folder.is_dir():
        raise FileNotFoundError(folder_model)

    prefix_model = folder.name
    fname_model_onnx = Path(folder_model, prefix_model + '.onnx')
    fname_model_pt = Path(folder_model, prefix_model + '.pt')
    has_onnx = fname_model_onnx.is_file()
    has_pt = fname_model_pt.is_file()

    if not has_pt and not has_onnx:
        raise FileNotFoundError(f"Model files not found in model folder: "
                                f"'{str(fname_model_onnx)}' or '{str(fname_model_pt)}'")

    if torch.cuda.is_available():
        # '.pt' is preferred on GPU, '.onnx' is the fallback
        fname_model = fname_model_pt if has_pt else fname_model_onnx
    else:
        # '.onnx' is preferred on CPU, '.pt' is the fallback
        fname_model = fname_model_onnx if has_onnx else fname_model_pt

    fname_model_metadata = Path(folder_model, prefix_model + '.json')
    if not fname_model_metadata.is_file():
        raise FileNotFoundError(f"Model config file not found in model folder: '{str(fname_model_metadata)}'")

    return str(fname_model), str(fname_model_metadata)