Using attention for medical image segmentation

Exploring two recent papers on the use of attention for segmentation, building some intuition, and looking at short PyTorch implementations.

The attention mechanism has been among the hottest areas of deep learning research over the last few years, starting with natural language processing and, more recently, computer vision tasks. In this article, we will focus on how attention has shaped recent architectures for medical image segmentation. To do so, we will describe the architectures presented in two recent papers and try to give some intuition as to what happens, in the hope that it will give you some ideas of how to apply attention to your own problem. We will also look at simple PyTorch implementations.

Segmentation of medical images differs from that of natural images in two main ways:

  • most medical images are very similar, since they are taken in standardized settings: this means little variation in terms of orientation, position in the image, range of pixel values, …
  • there is often a great imbalance between the positive class pixels (or voxels) and the negative class, for example when trying to segment tumors

Note: of course, both the code and the explanations are simplifications of the complex architectures described in the papers; the aim is mostly to give an intuition and a good idea of what is done, not to explain every detail.

1. Attention UNet

UNet is the go-to architecture for segmentation, and most of the current advances in segmentation use this architecture as a backbone. In this paper [1], the authors propose a way to apply the attention mechanism to a standard UNet. If you want a refresher on how a standard UNet works, this article does a perfect job.

1.1. What is proposed

The architecture uses the standard UNet as a backbone, and the contracting path is not changed. What changes is the expanding path, and more precisely, the attention mechanism is integrated into the skip connections.

Block diagram of the attention UNet [1], with an expanding path block in red.

To explain how a block of the expanding path works, let’s call g the input to this block coming from the previous block, and x the skip connection coming from the contracting path. The following equations sum up how this block works.
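In simplified pseudo-code (a paraphrase using the block names from the figure, not the paper’s exact notation):

x_att  = AttentionBlock(x, g)                    # gate the skip connection x with the gating signal g
output = ConvBlock(Concat(Upsample(g), x_att))   # standard expanding-path step on the concatenation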

The upsample block is pretty straightforward, and the ConvBlock is simply a sequence of two (convolution + batch norm + ReLU) blocks. The only block left to explain is the attention block.

Block diagram of the attention block [1]. The dimensions here assume a 3-dimensional input image.
  • both x and g are fed into 1×1 convolutions, to bring them to the same number of channels, without changing the size
  • after an upsampling operation (to have the same size), they are summed and passed through a ReLU
  • another 1×1 convolution and a sigmoid, to flatten to a single channel with a 0-to-1 score of how much importance to give to each part of the map
  • this attention map is then multiplied by the skip input to produce the final output of this attention block

1.2. Why it works

In a UNet, the contracting path can be seen as an encoder and the expanding path as a decoder. What is interesting about the UNet is that the skip connections allow features extracted by the encoder to be used directly in the decoder. This way, when “reconstructing” the mask of the image, the network learns to use these features, because the features of the contracting path are concatenated with those of the expanding path.

Applying an attention block before this concatenation allows the network to put more weight on the features of the skip connection that will be relevant. It lets the direct connection focus on a particular part of the input, rather than feeding in every feature.

The attention distribution is multiplied by the skip-connection feature map to keep only the important parts. This attention distribution is extracted from what is called the query (the input) and the values (the skip connection). The attention operation allows for a selective picking of the information contained in the values, and this selection is based on the query.

To summarize: the input and the skip connection are used to decide which parts of the skip connection to focus on. Then we use this subset of the skip connection, along with the input, in a standard expanding-path block.

1.3. Short implementation

The following code defines (a simplified version of) the attention block and the “up-block” that is used for the expanding path of the UNet. The “down-block” is unchanged from the original UNet.

Note: the ConvBatchNorm is a sequence of a Conv2d, a BatchNorm2d and a ReLU activation function.
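A minimal sketch of these blocks could look like the following. This is our own simplified PyTorch version, not the authors’ code: the channel bookkeeping (x and g are assumed to each carry half of the block’s input channels) and the bilinear upsampling are illustrative choices.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBatchNorm(nn.Module):
    """Convolution, followed by batch normalization and a ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.norm(self.conv(x)))


class AttentionBlock(nn.Module):
    """Computes a 0-to-1 attention map from (x, g) and applies it to the skip connection x."""
    def __init__(self, channels_x, channels_g, int_channels):
        super().__init__()
        self.Wx = nn.Conv2d(channels_x, int_channels, kernel_size=1)
        self.Wg = nn.Conv2d(channels_g, int_channels, kernel_size=1)
        self.psi = nn.Sequential(nn.Conv2d(int_channels, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, x, g):
        # bring x and g to the same number of channels and the same size,
        # sum them, then collapse the result to a single-channel attention map
        g = F.interpolate(g, size=x.shape[2:], mode='bilinear', align_corners=False)
        attention = self.psi(F.relu(self.Wx(x) + self.Wg(g)))
        # weight the skip connection by the attention map
        return x * attention


class UpBlock(nn.Module):
    """Expanding-path block: gate the skip connection, upsample g, concatenate, convolve."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.attention = AttentionBlock(in_channels // 2, in_channels // 2, in_channels // 4)
        self.conv1 = ConvBatchNorm(in_channels, out_channels)
        self.conv2 = ConvBatchNorm(out_channels, out_channels)

    def forward(self, x, g):
        # x: skip connection from the contracting path, g: output of the previous block
        x_att = self.attention(x, g)
        out = torch.cat([self.up(g), x_att], dim=1)
        return self.conv2(self.conv1(out))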

2. Multi-scale guided attention

The second architecture we will talk about is a bit more original than the first one. It does not rely on a UNet architecture, but on feature extraction followed by guided attention blocks.

Block diagram of the proposed solution [2].

The first step is to extract features from the image. To do so, the input image is fed to a pretrained ResNet, and we extract the feature maps at 4 different levels. This is interesting because low-level features tend to be present at the beginning of the network and high-level features towards the end, so we get access to features at multiple scales. All feature maps are upsampled to the size of the biggest one using bilinear interpolation. This gives us 4 feature maps of the same size, which are concatenated and fed into a convolutional block. The output of this convolutional block (the multi-scale feature map) is concatenated with each of the 4 feature maps, giving us the inputs of the attention blocks, which are a bit more complex than the previous ones.
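As a rough sketch of this feature-extraction stage (our own simplified version: the ResNet-34 backbone from torchvision and the number of fused channels are arbitrary choices for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34


class MultiScaleFeatures(nn.Module):
    """Extracts feature maps at 4 depths of a pretrained ResNet and fuses them."""
    def __init__(self, out_channels=128):
        super().__init__()
        backbone = resnet34(pretrained=True)
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1   # low-level features
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4   # high-level features
        # a ResNet-34 outputs 64 + 128 + 256 + 512 channels at these 4 levels
        self.fuse = nn.Conv2d(64 + 128 + 256 + 512, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        f1 = self.layer1(self.stem(x))
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        f4 = self.layer4(f3)
        # upsample everything to the size of the biggest feature map and concatenate
        size = f1.shape[2:]
        feats = [f1] + [F.interpolate(f, size=size, mode='bilinear', align_corners=False)
                        for f in (f2, f3, f4)]
        multi_scale = self.fuse(torch.cat(feats, dim=1))
        # the multi-scale map is concatenated with each individual scale
        # to form the inputs of the guided attention blocks
        inputs = [torch.cat([multi_scale, f], dim=1) for f in feats]
        return multi_scale, inputs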

2.1. What is proposed

The guided attention block relies on the position and channel attention modules, which we will start by describing.

Block diagram of the position and channel attention modules [2].

We’ll try to understand what is going on in these modules, without going into the detail of every operation in these two blocks (these can be followed in the code section below).

These two blocks are actually very similar; the only difference between them lies in the operations that extract information from either the channels or the positions. Applying a convolution before flattening gives more importance to the position, because the number of channels is reduced by the convolution. In the channel attention module, more weight is given to the channels, as the original number of channels is kept during the reshape operation.

In each block, it is important to notice that the top two branches are in charge of extracting the specific attention distribution. For example, in the position attention module, we have a (W*H)x(W*H) attention distribution, where the (i, j) element tells us how much position i impacts position j. In the channel module, we have a CxC attention distribution that tells us how much one channel impacts another. In the third branch of each module, this attention distribution is multiplied by a transformation of the input to get the position or channel attention features. As in the previous architecture, the attention distribution is multiplied by the input to extract the relevant information from it, given the multi-scale feature context. The outputs of these two modules are then summed element-wise to give the final self-attention features. Now let’s see how the outputs of these two modules are used in the global framework.

Block diagram of the guided attention module with 2 refinement steps [2].

Guided attention is built from a succession of multiple refinement steps for each scale (4 scales in the proposed architecture). The input feature map is fed to the position and channel attention modules, whose outputs are combined into a single feature map. It is also passed through an autoencoder, which produces a reconstructed version of the input. In each block, the attention map is produced by a multiplication of these two outputs. This attention map is then multiplied by the multi-scale feature map produced previously. The output therefore represents which parts of the multi-scale feature map we should focus on, given the specific scale on which we are working. You can then chain a sequence of such guided attention modules by concatenating the output of one block with the multi-scale feature map and feeding this to the following block.

Two additional losses are necessary here to ensure that the refinement steps work correctly (a minimal sketch of both follows this list):

  • the standard reconstruction loss, to make sure that the autoencoder correctly reconstructs the input feature map
  • a guided loss, which tries to minimize the distance between the two subsequent latent representations of the input
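These two losses can be sketched roughly as follows (a minimal version; encoder, decoder and latent_prev are assumed to be the two halves of the autoencoder and the latent code from the previous refinement step):

import torch.nn.functional as F


def refinement_losses(features, latent_prev, encoder, decoder):
    # latent representation and reconstruction produced by the autoencoder at this step
    latent = encoder(features)
    reconstruction = decoder(latent)
    # standard reconstruction loss on the input feature map
    loss_reconstruction = F.mse_loss(reconstruction, features)
    # guided loss: keep the latent representations of two subsequent steps close
    loss_guided = F.mse_loss(latent, latent_prev)
    return loss_reconstruction + loss_guided, latent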

Each attention feature is then passed through a convolution block to predict a mask. To produce the final prediction, an average of the four masks is taken, which can be seen as a sort of ensembling of the model over the various scales of features.
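In code, this last step is only a few lines (a sketch, assuming attention_features holds the four refined feature maps and heads the corresponding convolutional prediction blocks):

import torch


def predict_mask(attention_features, heads):
    # one mask per scale, each predicted by its own convolution block
    masks = [torch.sigmoid(head(features)) for head, features in zip(heads, attention_features)]
    # the final prediction is a simple average of the four masks
    return torch.stack(masks, dim=0).mean(dim=0)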

2.2. Why it works

As this architecture is much more complex than the previous one, it can be hard to understand what is going on behind the attention modules. Here is how I understand the contribution of the various blocks.

The position attention module tries to specify which positions of the specific-scale features to focus on, based on the multi-scale representation of the input image. The channel attention module does the same thing, by specifying how much attention to pay to each channel. The specific operations used in either block are what give more importance to the channel or the position information in the attention distribution. Combining both modules gives us an attention map giving a score to each position-channel pair, i.e. each element of the feature map.

The autoencoder is used to make sure that the latent representation of the feature map is not completely changed from one step to the next. As the latent space is lower dimensional, only the key information is kept there. We do not want this information to change from one refinement step to the next; we only want small adjustments to be made, and these won’t be seen in the latent representation.

The use of a sequence of guided attention modules allows the final attention map to be progressively refined, making noise vanish and giving more weight to the really important regions.

Extracting features at multiple scales gives the network a view of both global and local features, which are then combined into a multi-scale feature map. Applying attention to the multi-scale feature map, along with each specific scale, helps the network understand which features bring more value to the final output.

2.3. Short implementation

Short implementation of the position attention module, channel attention module, and one guided attention block.
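What follows is a simplified sketch of our own, built from the block diagrams above rather than from the authors’ exact code: the channel reduction in the position module, the tiny autoencoder and the assumption that the scale-specific and multi-scale feature maps share the same number of channels are all illustrative choices.

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionAttentionModule(nn.Module):
    """Self-attention over spatial positions: a (H*W) x (H*W) attention distribution."""
    def __init__(self, channels):
        super().__init__()
        self.query = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.query(x).view(b, -1, h * w).permute(0, 2, 1)          # (B, N, C')
        k = self.key(x).view(b, -1, h * w)                             # (B, C', N)
        attention = F.softmax(torch.bmm(q, k), dim=-1)                 # (B, N, N): position i vs position j
        v = self.value(x).view(b, c, h * w)                            # (B, C, N)
        out = torch.bmm(v, attention.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x


class ChannelAttentionModule(nn.Module):
    """Self-attention over channels: a C x C attention distribution."""
    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.shape
        flat = x.view(b, c, -1)                                         # keep the original channels
        attention = F.softmax(torch.bmm(flat, flat.permute(0, 2, 1)), dim=-1)  # (B, C, C)
        out = torch.bmm(attention, flat).view(b, c, h, w)
        return self.gamma * out + x


class GuidedAttentionBlock(nn.Module):
    """One refinement step: the self-attention features and the autoencoder
    reconstruction are multiplied into an attention map, which then weights
    the multi-scale feature map."""
    def __init__(self, channels):
        super().__init__()
        self.pam = PositionAttentionModule(channels)
        self.cam = ChannelAttentionModule()
        # a very small autoencoder, just for illustration
        self.encoder = nn.Sequential(nn.Conv2d(channels, channels // 4, kernel_size=3, padding=1),
                                     nn.ReLU(inplace=True))
        self.decoder = nn.Sequential(nn.Conv2d(channels // 4, channels, kernel_size=3, padding=1),
                                     nn.ReLU(inplace=True))

    def forward(self, x, multi_scale):
        self_attention = self.pam(x) + self.cam(x)     # element-wise sum of the two modules
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)          # used for the reconstruction loss
        attention_map = self_attention * reconstruction
        # latent and reconstruction are returned so the two extra losses can be computed
        return attention_map * multi_scale, latent, reconstruction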

Takeaway

So, what can be taken away from these articles? Attention can be seen as a mechanism that helps point out to the network which extracted features to focus on, based on the context.

In the UNet, this means deciding which features extracted during the contracting path to pay more attention to, given the features extracted during the expanding path. This helps to give more sense to the skip connections, i.e. to pass relevant information rather than every extracted feature. In the second article, it means deciding which multi-scale features to focus on, given the specific scale on which we are working.

The general idea can be applied to a wide range of problems, and I think that seeing multiple examples can help to better understand how attention can be adapted.

References

[1] O. Oktay et al., “Attention U-Net: Learning Where to Look for the Pancreas” (2018)
[2] A. Sinha and J. Dolz, “Multi-Scale Self-Guided Attention for Medical Image Segmentation” (2020)