SEMANTIC SEGMENTATION ON MEDICAL IMAGES


We will train a deep learning model with the architecture known as U-Net on an electron microscopy dataset. Our model will learn to transform a grayscale EM image of nerve cells (left) into an accurate boundary map separating the cell walls (right) at the pixel level, as shown above. U-Net remains a state-of-the-art architecture for semantic segmentation, and the same model, with minor hyper-parameter tuning and a task-specific head, can be used for almost any image segmentation problem.

Note: This tutorial expects the reader to have an idea of how convolutional neural networks work, and my implementation of U-Net mostly uses the Keras functional API. If you are new to deep learning, I suggest gaining some insight by reading and implementing my Medium article on how convolutional neural networks work by clicking here.

  • The model that we’ll be building in this post was trained on an Nvidia GTX 1060 graphics card; it would take several hours to train on a CPU. To achieve good accuracy in reasonable time, I suggest running it on a GPU backend of Keras if you have one.

The deep learning model that I will be building in this post is based on the paper U-Net: Convolutional Networks for Biomedical Image Segmentation, which remains state-of-the-art for image segmentation even beyond medical images. The paper introduced a new architecture for semantic segmentation that is significantly better than the ones that came before it; most earlier approaches used a sliding-window convolutional neural network, and U-Net is a significant departure from that in every way.

Before we jump into the theory behind our neural network, I will first introduce the kinds of visual recognition tasks we see in the computer vision area of machine learning.

The simplest one is image classification (a), where we try to identify what is in the image; but here the problem is that we have no idea where a certain object class is located, or how many instances of it are present in the image. Later, object localisation/detection (b) emerged, which tells us not only what is in the picture but also where it is located, which is very helpful. Even this approach, however, gives us only bounding boxes, rectangles marked over the objects located in the image. A deeper level of object localisation is semantic segmentation, the main topic of this article. Semantic segmentation can be described as per-pixel classification of images: we label each pixel with its respective class, as shown below:

The above image is one of the real-world examples where semantic segmentation is applied as part of building self-driving cars, helping them better understand the environment around them.

But the model we will be building today segments bio-medical images, and the paper I am implementing was published in 2015 and won the ISBI 2015 challenge. This architecture can be applied where training data is very scarce. Especially in the medical sector, the available training samples are few, because domain expertise is limited and it is very hard to get well-labelled, high-quality data; yet U-Net still remains state-of-the-art in solving such tasks.

Theory behind U-Net

Key Points :

  • The network can be divided into two paths, one is the contracting path and the other is an expanding path.
  • The contracting path performs down-sampling for feature extraction and is constructed the same way as a typical convolutional neural network; it is followed by an expanding path that performs up-sampling for precise localisation of features in the higher-resolution layers.
  • Another important aspect that makes the network special is that the feature maps produced by the convolution layers in the down-sampling path are concatenated to the corresponding layers of the up-sampling path (skip connections).

Network Architecture :

The above image, taken from the original paper, depicts the U-Net architecture. As mentioned earlier, our network has two paths: a down-sampling path and an up-sampling path.

Down-sampling path :

The left side of the network is the down-sampling part: the path where we run the image through multiple convolution layers with max-pooling in between to downsample and reduce the spatial size of the feature maps, while simultaneously increasing the number of feature channels by doubling the number of filters in each convolution block (see the sketch after the list below).

  • The path has 4 convolution blocks (2 convolutions each), followed by max-pooling layers of size 2×2 with stride 2 for downsampling.
  • The 5th convolution block is not followed by max-pooling; instead it connects to the up-sampling path.
  • The first convolution block contains 64 filters on each convolution layer.
  • The number of filters is doubled with each consecutive convolution block.
  • Resolution is reduced with increasing depth (number of layers).
  • No padding is used (‘valid’ padding).
  • The convolution filters are of size 3×3 with ReLU as activation function.
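
To make these points concrete, here is a minimal sketch of a single contracting step in Keras, using the paper’s ‘valid’ padding and 3×3 filters. This is an illustrative example only, not the model we build later (which uses ‘same’ padding); the input size and filter count are taken from the original paper.

# sketch: one illustrative contracting step (not the final model)
from keras.layers import Input, Conv2D, MaxPooling2D
from keras.models import Model

inp = Input((572, 572, 1))                                          # input size used in the paper
c = Conv2D(64, (3, 3), activation='relu', padding='valid')(inp)     # 572 -> 570
c = Conv2D(64, (3, 3), activation='relu', padding='valid')(c)       # 570 -> 568
p = MaxPooling2D((2, 2), strides=2)(c)                              # 568 -> 284
Model(inp, p).summary()  # shows how 'valid' padding shrinks the resolution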

Up-sampling path :

In the up-sampling path, the pooling layers are replaced by upsampling operators that increase the resolution of the output.

  • The up-sampling path remains symmetric to the down-sampling path, turning the network into a U shaped neural network, hence the name “U-Net”.
  • There are 4 convolution blocks with 2 convolution layers each, and each block is preceded by a transposed convolution (up-convolution) layer, which upsamples the feature map using a learned filter.
  • The number of filters in each consecutive convolution block is half that of the previous block.
  • Resolution is increased as the depth (number of layers) decreases.
  • No padding is used here also (‘valid’ padding).
  • The convolution filters are of size 3×3 with ReLU as activation function.
  • The corresponding feature maps from the down-sampling path are concatenated to the respective up-sampling layers for achieving precise localisation.
  • The final convolution layer has a 1×1 filter to map each 64-component feature vector to the desired number of classes (in this case, cell and background).

Note: During the up-sampling path, a learned kernel (the transposed convolution) maps the feature maps from a lower resolution back to a higher one, while the inter-connections between the two paths carry the high-resolution feature maps from the down-sampling path across, so fine detail can be recovered.
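
Here is a minimal sketch of one up-sampling step showing how the transposed convolution and the skip connection fit together. The tensors ‘down_features’ and ‘x’ and their shapes are placeholders for illustration only.

# sketch: one illustrative up-sampling step with a skip connection
from keras.layers import Input, Conv2D, Conv2DTranspose, concatenate

down_features = Input((64, 64, 128))   # hypothetical feature map from the contracting path
x = Input((32, 32, 256))               # hypothetical feature map from one level deeper

u = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(x)  # 32x32 -> 64x64
u = concatenate([u, down_features])                                   # skip connection
u = Conv2D(128, (3, 3), activation='relu', padding='same')(u)
u = Conv2D(128, (3, 3), activation='relu', padding='same')(u)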

Implementation of U-Net

The dataset we will be using in this tutorial is the 2015 ISBI cell tracking challenge dataset. It contains 30 electron microscope images with their respective annotated images (labels). You can find the dataset and the code explained in this tutorial on my GitHub. So let us construct the model in Keras.

# unet_model.py
from keras.models import Model
from keras.backend import int_shape
from keras.layers import BatchNormalization, Conv2D, Conv2DTranspose, MaxPooling2D, Dropout, UpSampling2D, Input, concatenate

Let us look at what we are importing and why :

‘Model’ is from the Keras functional API, used for building complex deep learning models, directed acyclic graphs, etc. ‘int_shape’ returns the shape of a tensor or variable as a tuple of int or None entries.

‘BatchNormalization’ : Normalises the output of activations from the layers.

‘Conv2D’ : Used to create convolution layer.

‘Conv2DTranspose’ : To perform a transposed convolution.

‘MaxPooling2D’ : Does max pooling operation on spatial data.

‘Dropout’ : Used for dropping units (hidden and visible) in a neural network.

‘Input’ : Used to instantiate a Keras tensor.

‘concatenate’ : Returns a tensor which is the concatenation of inputs alongside the axis passed.

def upsample_conv(filters, kernel_size, strides, padding):
    return Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)

def upsample_simple(filters, kernel_size, strides, padding):
    return UpSampling2D(strides)

The above two functions perform two different kinds of upsampling.

The ‘upsample_conv’ function performs a transposed convolution operation, which means upsampling an image based on a learned filter; we make use of ‘Conv2DTranspose’ to do it. The parameters passed are self-explanatory. Click here if not sure.

The ‘upsample_simple’ function performs a simple, straightforward upsampling operation that repeats the rows and columns of the input by the specified factor; we make use of ‘UpSampling2D’ to do it.
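
As a quick sanity check (a hedged example; the tensor shape is illustrative), both variants double the spatial resolution, but only the transposed convolution has trainable weights:

# quick check of the two upsampling variants on a hypothetical 16x16x32 tensor
from keras.layers import Input
from keras.models import Model

t = Input((16, 16, 32))
deconv = upsample_conv(32, (2, 2), strides=(2, 2), padding='same')(t)    # learned upsampling
simple = upsample_simple(32, (2, 2), strides=(2, 2), padding='same')(t)  # fixed upsampling
print(Model(t, deconv).output_shape)  # (None, 32, 32, 32)
print(Model(t, simple).output_shape)  # (None, 32, 32, 32)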

def conv2d_block(
    inputs,
    use_batch_norm=True,
    dropout=0.3,
    filters=16,
    kernel_size=(3, 3),
    activation='relu',
    kernel_initializer='he_normal',
    padding='same'):

    c = Conv2D(filters, kernel_size, activation=activation,
               kernel_initializer=kernel_initializer, padding=padding)(inputs)
    if use_batch_norm:
        c = BatchNormalization()(c)
    if dropout > 0.0:
        c = Dropout(dropout)(c)
    c = Conv2D(filters, kernel_size, activation=activation,
               kernel_initializer=kernel_initializer, padding=padding)(c)
    if use_batch_norm:
        c = BatchNormalization()(c)
    return c

The ‘conv2d_block’ function above handles the convolution operations in the network (two convolutions per block). We make use of the classic ‘Conv2D’ layer from Keras to perform them. The arguments that can be passed are the input tensor, whether to use batch normalisation within the block, the dropout rate, the number of filters, the kernel size, the activation function, the kernel initialiser ‘he_normal’ (which draws the initial weights from a truncated normal distribution scaled to the layer’s fan-in), and finally the padding (‘same’ in our case, i.e. the layer’s outputs have the same spatial dimensions as its inputs). Note that, unlike the ‘valid’ padding described in the paper, this implementation uses ‘same’ padding, which keeps the spatial size unchanged and makes the later concatenations simpler.
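
For instance, here is a small, hedged usage check (reusing the imports at the top of unet_model.py; the input shape is illustrative): applying the block to a 128×128 grayscale input keeps the resolution and produces the requested number of channels.

# quick usage check of conv2d_block on a hypothetical 128x128 grayscale input
inp = Input((128, 128, 1))
block = conv2d_block(inputs=inp, filters=64, use_batch_norm=True, dropout=0.3)
print(Model(inp, block).output_shape)  # (None, 128, 128, 64)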

def unet_model(
    input_shape,
    num_classes=1,
    use_batch_norm=True,
    upsample_mode='deconv',  # 'deconv' (de-convolution) or simple upsampling
    use_dropout_on_upsampling=False,
    dropout=0.3,
    dropout_change_per_layer=0.0,
    filters=16,
    num_layers=4,
    output_activation='sigmoid'):  # 'sigmoid' or 'softmax'

    if upsample_mode == 'deconv':
        upsample = upsample_conv
    else:
        upsample = upsample_simple

    # Build U-Net model
    inputs = Input(input_shape)
    x = inputs
    down_layers = []

    # down-sampling path
    for l in range(num_layers):
        x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                         dropout=dropout)
        down_layers.append(x)
        x = MaxPooling2D((2, 2))(x)
        dropout += dropout_change_per_layer
        filters = filters * 2  # double the number of filters with each layer

    # bottleneck block (5th convolution block, no max-pooling)
    x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                     dropout=dropout)

    if not use_dropout_on_upsampling:
        dropout = 0.0
        dropout_change_per_layer = 0.0

    # up-sampling path with skip connections
    for conv in reversed(down_layers):
        filters //= 2  # decrease the number of filters with each layer
        dropout -= dropout_change_per_layer
        x = upsample(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, conv])
        x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                         dropout=dropout)

    # 1x1 convolution maps each feature vector to the desired number of classes
    outputs = Conv2D(num_classes, (1, 1), activation=output_activation)(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

The above function ‘unet_model’ assembles the whole U-Net model. Click here to see the graphical structure of the above model.
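
If you prefer to generate that graph yourself, Keras can render the model to an image. This is a small optional aside; the output file name is arbitrary and the call requires the pydot and graphviz packages.

# optional: render the model graph to a PNG
from keras.utils import plot_model

model = unet_model((512, 512, 1), num_classes=1, filters=64, num_layers=4)
plot_model(model, to_file='unet.png', show_shapes=True)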

# train.py
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import sys
from PIL import Image

masks = glob.glob("./dataset/isbi2015/train/label/*.png")
orgs = glob.glob("./dataset/isbi2015/train/image/*.png")

We are importing the dataset in the above code using ‘glob’. Make sure to download or clone my GitHub repository to find the dataset.

imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    imgs_list.append(np.array(Image.open(image).resize((512, 512))))

    im = Image.open(mask).resize((512, 512))
    masks_list.append(np.array(im))

imgs_np = np.asarray(imgs_list)
masks_np = np.asarray(masks_list)

Here we have initialised two lists, resized the raw images and the annotated (label) images to a resolution of 512×512, and appended them to ‘imgs_list’ and ‘masks_list’ respectively.
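
Before splitting, the arrays need to be normalised to the [0, 1] range and given an explicit channel dimension so they can be fed to the network. The snippet below mirrors the preprocessing used in test.py further down (an assumption on my part that train.py does the same):

# scale pixel values to [0, 1] and add a single channel dimension (grayscale)
x = np.asarray(imgs_np, dtype=np.float32) / 255
y = np.asarray(masks_np, dtype=np.float32) / 255
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)
y = y.reshape(y.shape[0], y.shape[1], y.shape[2], 1)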

from sklearn.model_selection import train_test_split

x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.5, random_state=0)

Here we split our imported dataset into a training set and a validation set by making use of the ‘train_test_split’ function from sklearn. With test_size=0.5, 15 images go to the training set and the other 15 to the validation set.

from utils import get_augmented

train_gen = get_augmented(
    x_train, y_train, batch_size=2,
    data_gen_args=dict(
        rotation_range=15.,
        width_shift_range=0.05,
        height_shift_range=0.05,
        shear_range=50,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='constant'
    ))

The above call performs data augmentation on our dataset. The ‘get_augmented’ function comes from the ‘utils.py’ file included in my GitHub repository and uses ‘ImageDataGenerator’ from ‘keras.preprocessing.image’ internally. The names of the parameters passed above describe the types of augmentations performed.
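
For readers who don’t want to open utils.py, below is a minimal sketch of what such a helper might look like. This is my own illustrative version, assuming the repository pairs two ‘ImageDataGenerator’ instances with a shared seed so that images and masks receive identical random transforms; the real implementation may differ.

# illustrative sketch of a get_augmented helper (the real one lives in utils.py)
from keras.preprocessing.image import ImageDataGenerator

def get_augmented(x_train, y_train, batch_size=2, data_gen_args=None):
    data_gen_args = data_gen_args or {}
    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)
    seed = 42  # same seed -> identical random transforms for images and masks
    image_gen = image_datagen.flow(x_train, batch_size=batch_size, seed=seed)
    mask_gen = mask_datagen.flow(y_train, batch_size=batch_size, seed=seed)
    return zip(image_gen, mask_gen)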

sample_batch = next(train_gen)
xx, yy = sample_batch
print(xx.shape, yy.shape)

from keras_unet.utils import plot_imgs
plot_imgs(org_imgs=xx, mask_imgs=yy, nm_img_to_plot=2, figsize=6)

You can plot and look into the augmented images by running the above code snippet.

from unet_model import unet_model

input_shape = x_train[0].shape

model = unet_model(
    input_shape,
    num_classes=1,
    filters=64,
    dropout=0.2,
    num_layers=4,
    output_activation='sigmoid'
)
print(model.summary())

Initialising the network and printing a summary of the implemented model.

from keras.callbacks import ModelCheckpoint

model_filename = 'segm_model_v0.h5'
callback_checkpoint = ModelCheckpoint(
    model_filename,
    verbose=1,
    monitor='val_loss',
    save_best_only=True,
)

from keras.optimizers import Adam, SGD
from metrics import iou, iou_thresholded

model.compile(
    optimizer=SGD(lr=0.01, momentum=0.99),
    loss='binary_crossentropy',
    metrics=[iou, iou_thresholded]
)

Here we compile the above model using stochastic gradient descent as our optimiser, with a learning rate of 0.01 and a high momentum of 0.99 (as suggested in the U-Net paper), and ‘binary_crossentropy’ as our loss function. The ‘iou’ and ‘iou_thresholded’ metrics are imported from the accompanying ‘metrics.py’ file.
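
Since ‘metrics.py’ isn’t shown in this post, here is a minimal sketch of what an intersection-over-union metric of this kind might look like. This is my own hedged version, not necessarily the exact code in the repository.

# sketch of IoU metrics (the repository's metrics.py may differ)
from keras import backend as K

def iou(y_true, y_pred, smooth=1.0):
    # soft IoU over the batch: intersection / union of predicted and true masks
    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3]) - intersection
    return K.mean((intersection + smooth) / (union + smooth))

def iou_thresholded(y_true, y_pred, threshold=0.5, smooth=1.0):
    # binarise the prediction before computing IoU
    y_pred = K.cast(y_pred > threshold, K.floatx())
    return iou(y_true, y_pred, smooth)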

history = model.fit_generator(
    train_gen,
    steps_per_epoch=100,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[callback_checkpoint]
)

The above code trains the model; the figure below shows the plot of the training loss and accuracy:

Once the training is done, the weights of the trained network are saved in the same directory as a file with the ‘.h5’ extension (here, ‘segm_model_v0.h5’).

# test.py
from PIL import Image
import numpy as np
import glob

masks = glob.glob("./dataset/isbi2015/train/label/*.png")
orgs = glob.glob("./dataset/isbi2015/train/image/*.png")

imgs_list = []
masks_list = []
for image, mask in zip(orgs, masks):
    imgs_list.append(np.array(Image.open(image).resize((512, 512))))

    im = Image.open(mask).resize((512, 512))
    masks_list.append(np.array(im))

imgs_np = np.asarray(imgs_list)
masks_np = np.asarray(masks_list)

x = np.asarray(imgs_np, dtype=np.float32)/255
y = np.asarray(masks_np, dtype=np.float32)/255
y = y.reshape(y.shape[0], y.shape[1], y.shape[2], 1)
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)
from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.1, random_state=0)

from unet_model import unet_model

input_shape = x_train[0].shape
model = unet_model(
    input_shape,
    num_classes=1,
    filters=64,
    dropout=0.2,
    num_layers=4,
    output_activation='sigmoid'
)

model_filename = 'segm_model_v0.h5'
model.load_weights(model_filename)
y_pred = model.predict(x_val)

from utils import plot_imgs
plot_imgs(org_imgs=x_val, mask_imgs=y_val, pred_imgs=y_pred, nm_img_to_plot=3)

The above script imports the data and creates the model, but instead of training it, we load our saved weights and predict the labels. Below are the results:

This ends my semantic segmentation tutorial, and what we’ve seen here is just the tip of the iceberg considering the wide range of applications semantic segmentation has, from medical imagery to self-driving cars. Thank you.