Deep Learning for Medical Image Analysis

Source: Deep Learning on Medium

By Pawan Jain

A beginner’s tutorial on implementing transfer learning with the VGG-16 architecture in PyTorch on retinal OCT images

Deep learning has the potential to revolutionize disease diagnosis and management by performing classifications that are difficult for human experts and by rapidly reviewing immense numbers of images. Let’s put that idea into practice.

If terms like epoch and CNN, or other basic deep learning and PyTorch terminology, are completely new to you, I recommend that you check this first.

About Dataset

This dataset of retinal OCT images comes from Kaggle Datasets. It has four classes of images: CNV, DME, DRUSEN and NORMAL.

  • (Far left) choroidal neovascularization (CNV) with a neovascular membrane (white arrowheads) and associated subretinal fluid (arrows).
  • (Middle left) Diabetic macular edema (DME) with retinal-thickening-associated intraretinal fluid (arrows).
  • (Middle right) Multiple drusen (arrowheads) present in early AMD.
  • (Far right) A normal retina with preserved foveal contour and absence of any retinal fluid/edema.

About OCT

Optical coherence tomography (OCT) is an imaging technique that uses coherent light to capture high-resolution images of biological tissues.

  • OCT is heavily used by ophthalmologists to obtain high-resolution images of the retina. The retina functions much like the film in a camera.
  • OCT images can be used to diagnose many retina-related eye diseases.
  • In some cases, OCT alone may yield the diagnosis (e.g. macular hole). Yet, in other disorders, especially retinal vascular disorders, it may be helpful to order additional tests (e.g. fluorescein angiogram).

Exploring the Dataset

Let’s try to look at the number of images in each category and the size of the images.

We get a data frame with the number of images of each category in the train, validation and test folders, which gives us some basic intuition about our dataset:

  • Just 9 images per class in the validation set (very few)
  • The training set has approximately 37k images of CNV, 26k of NORMAL, 11k of DME and 8k of DRUSEN
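The counting step can be sketched with the standard library alone. This is a minimal sketch, assuming the dataset is laid out as root/split/category/image files; the function name count_images and the list of file extensions are my own choices for illustration, not taken from the original kernel.

```python
from collections import Counter
from pathlib import Path

# Assumed image file extensions (hypothetical; adjust to the actual dataset)
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def count_images(root):
    """Return {split: Counter({category: n_images})} for a directory laid out
    as root/<split>/<category>/<image files>."""
    counts = {}
    for split_dir in sorted(Path(root).iterdir()):
        if not split_dir.is_dir():
            continue
        per_class = Counter()
        for class_dir in sorted(split_dir.iterdir()):
            if class_dir.is_dir():
                # Count only files that look like images
                per_class[class_dir.name] = sum(
                    1 for f in class_dir.iterdir()
                    if f.suffix.lower() in IMAGE_SUFFIXES
                )
        counts[split_dir.name] = per_class
    return counts
```

Calling count_images on the dataset root returns one counter per folder, which is easy to turn into a data frame or print directly.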

Image Preprocessing

To prepare the images for our network, we have to resize them to 224 x 224 and normalize each color channel by subtracting a mean value and dividing by a standard deviation. We will also augment our training data at this stage. All of these operations are done with image transforms, which prepare our data for the neural network.

Plotting the distribution of image widths and heights helps us see the range of image sizes in our training set.

When we use the images in the pre-trained network, we’ll have to reshape them to 224 x 224. This is the size of the ImageNet images the network was pre-trained on and is therefore what the model expects. Images that are larger than this will be shrunk or cropped, while smaller images will be interpolated up to size.

Data Augmentation

Due to the limited number of images, we can use image augmentation to artificially increase the number of images “seen” by the network. This means that for training, we randomly resize and crop the images and also flip them horizontally. A different random transformation is applied in each epoch (while training), so the network effectively sees many different versions of the same image.

  • All of the data is also converted to Torch Tensors before normalization.
  • The validation and testing data are not augmented but are only resized and normalized.

Data Iterators

To avoid loading all of the data into memory at once, we use training DataLoaders.

  • First, we create a dataset object from the image folders, and then we pass these to a DataLoader.
  • At training time, the DataLoader will load the images from disk, apply the transformations, and yield a batch.
  • For training and validation, we’ll iterate through all the batches in the respective DataLoader.

One crucial aspect is to shuffle the data before passing it to the network. This means that the ordering of the image categories changes on each pass through the data (one pass through the data is one training epoch).
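Here is a minimal sketch of the dataset-to-DataLoader step with shuffling. To keep it runnable without the images on disk, it uses a TensorDataset of random tensors as a stand-in; in the real pipeline the dataset object would be built from the image folders (for example with torchvision.datasets.ImageFolder) and the transforms would be applied as each image is loaded.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 10 random "images" with labels cycling through 0..3.
# Hypothetical data, used only so the sketch runs without the OCT images.
images = torch.randn(10, 3, 224, 224)
labels = torch.arange(10) % 4
train_data = TensorDataset(images, labels)

batch_size = 4  # small for the sketch; the article uses 128
# shuffle=True reorders the samples at the start of every epoch
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

for epoch in range(2):
    for batch_images, batch_labels in train_loader:
        # Each iteration yields one batch; the last batch may be smaller
        pass
```

With 10 samples and a batch size of 4, each pass yields three batches, and only one batch is held in memory at a time.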

Using Pre-trained Model (VGG-16)

The idea behind pre-training is that the early convolutional layers of a CNN extract features that are relevant for many image recognition tasks. The later, fully connected layers specialize in the specific dataset by learning higher-level features.

Therefore, we can reuse the already trained convolutional layers and train only the fully connected layers on our own dataset. Pre-trained networks have proven to be reasonably successful for a variety of tasks; they significantly reduce training time and usually improve performance.

VGG-16 Architecture

The classifier is the part of the model that we’ll train. However, for VGG, we only need to train the last few layers of the classifier, not all of the fully connected layers.

Classifier part of VGG-16, with 1000 out_features in its final layer

  • We freeze all of the existing layers in the network by setting requires_grad to False.
  • To build our custom classifier, we use the nn.Sequential() module, which lets us specify the layers one after the other.
  • The final output will be log probabilities, which we can use with the negative log-likelihood loss (NLLLoss).
  • Finally, we move the whole model onto the GPU.

I have skipped the code for the parts above; you can find it in either my GitHub repository or my Kaggle kernel.

135,310,404 total parameters.
1,049,860 training parameters.
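Freezing the backbone, replacing the head and counting parameters might look roughly like this. This is a sketch, not the code from the repository. The hidden size of 256 is inferred from the trainable count reported above (4096 x 256 + 256 + 256 x 4 + 4 = 1,049,860); the dropout probability of 0.4 is an assumption. To avoid downloading weights, the sketch runs on TinyStandIn, a hypothetical miniature network whose classifier ends in a 4096-to-1000 layer like VGG-16; in practice you would pass in torchvision’s VGG-16 with its pre-trained ImageNet weights.

```python
import torch
from torch import nn


def add_custom_classifier(model, n_classes, hidden=256, p_drop=0.4):
    """Freeze every existing parameter, then swap the last classifier layer
    for a small trainable head that outputs log probabilities."""
    for param in model.parameters():
        param.requires_grad = False
    last = len(model.classifier) - 1  # index 6 in torchvision's VGG-16
    n_inputs = model.classifier[last].in_features
    # Newly created layers have requires_grad=True by default
    model.classifier[last] = nn.Sequential(
        nn.Linear(n_inputs, hidden),
        nn.ReLU(),
        nn.Dropout(p_drop),  # assumed dropout probability
        nn.Linear(hidden, n_classes),
        nn.LogSoftmax(dim=1),
    )
    return model


def count_parameters(model):
    """Return (total, trainable) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


class TinyStandIn(nn.Module):
    """Hypothetical miniature network with a VGG-like features/classifier layout."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(8 * 2 * 2, 4096), nn.ReLU(), nn.Dropout(),
            nn.Linear(4096, 1000),
        )

    def forward(self, x):
        x = torch.flatten(self.features(x), 1)
        return self.classifier(x)


model = add_custom_classifier(TinyStandIn(), n_classes=4)
total, trainable = count_parameters(model)
print(f"{total:,} total parameters.")
print(f"{trainable:,} training parameters.")

# Move the whole model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
```

The total differs from the article’s because the stand-in backbone is tiny, but the trainable count depends only on the new head and its 4096 inputs, so it matches.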

Mapping of Classes to Indexes

To keep track of the predictions made by the model, we create a mapping of classes to indexes and indexes to classes. This will let us know the actual class for a given prediction.

[(0, 'CNV'), (1, 'DME'), (2, 'DRUSEN'), (3, 'NORMAL')]
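A minimal sketch of how such a mapping can be built. torchvision’s ImageFolder exposes a class_to_idx dictionary with class folders indexed in sorted order; the plain-Python version below reproduces that ordering from the class names alone.

```python
# Class names as they appear in the dataset folders
classes = ["NORMAL", "DME", "CNV", "DRUSEN"]

# Sorted order gives each class a stable index
class_to_idx = {name: i for i, name in enumerate(sorted(classes))}

# Invert the mapping to go from a predicted index back to a class name
idx_to_class = {i: name for name, i in class_to_idx.items()}

print(list(idx_to_class.items()))
```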

So our complete model is the pre-trained VGG-16 plus a custom classifier. The full model summary is too long to post here, so here is a snapshot of just the final custom module, produced with torchsummary.

Custom classifier of our model
  • 128 is our batch size. You might need to decrease batch_size if this does not fit on your GPU.
  • We have 4 classes to classify, which is clearly visible in the output size of our last layer.

Training Loss and Optimizer

Loss (criterion): keeps track of the loss itself and the gradients of the loss with respect to the model parameters (weights)

Optimizer: updates the parameters (weights) with the gradients

  • The loss is the negative log-likelihood and the optimizer is the Adam optimizer.
  • The negative log-likelihood in PyTorch expects log probabilities so we need to pass it the raw output from the log softmax in our model’s final layer.
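To make the pairing of loss and final layer concrete, here is a small sketch. The one-layer head and the random inputs are stand-ins of my own; the point is only that nn.NLLLoss consumes the log probabilities produced by nn.LogSoftmax, and that this combination gives the same value as nn.CrossEntropyLoss applied to the raw scores.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Hypothetical stand-in for the model's trainable head
head = nn.Sequential(nn.Linear(16, 4), nn.LogSoftmax(dim=1))

criterion = nn.NLLLoss()                   # expects log probabilities
optimizer = optim.Adam(head.parameters())  # Adam with its default learning rate

features = torch.randn(8, 16)              # a batch of 8 fake feature vectors
targets = torch.randint(0, 4, (8,))        # fake class indexes in 0..3

log_probs = head(features)
loss = criterion(log_probs, targets)

# LogSoftmax followed by NLLLoss equals CrossEntropyLoss on the raw scores
raw_scores = head[0](features)
same_loss = nn.CrossEntropyLoss()(raw_scores, targets)
```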


Training

For training, we iterate through the train DataLoader, each time passing one batch through the model. We train for a set number of epochs or until early stopping kicks in (more below).

  • After each batch, we calculate the loss (with criterion(output, targets)) and then calculate the gradients of the loss with respect to the model parameters with loss.backward(). This uses auto differentiation and backpropagation to calculate the gradients.
  • After calculating the gradients, we call optimizer.step() to update the model parameters with the gradients. This is done on every training batch so we are implementing stochastic gradient descent (or rather a version of it with momentum known as Adam).
  • For each batch, we also compute the accuracy for monitoring. After the training loop for an epoch has completed, we start the validation loop, whose loss is used to carry out early stopping.
  • Early stopping halts the training when the validation loss has not decreased for a number of epochs. Each time the validation loss does decrease, the model weights are saved so we can later load in the best model.
  • Early stopping is an effective method to prevent overfitting on the training data. If we continue training, the training loss will continue to decrease, but the validation loss will increase because the model is starting to memorize the training data. Early stopping prevents this from happening.
  • Early stopping is implemented by iterating through the validation data at the end of each training epoch and calculating the loss. We use the complete validation data every time and record whether or not the loss has decreased. If it has not for a number of epochs, we stop training, retrieve the best weights, and return them. When in the validation loop, we make sure not to update the model parameters.
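The training procedure above can be condensed into the following sketch. It is not the function from the repository: the name train, the patience of 3 epochs and the synthetic data are assumptions for illustration, the best weights are kept in memory rather than saved to disk, and everything stays on the CPU to keep the example short.

```python
import copy

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def train(model, criterion, optimizer, train_loader, valid_loader,
          max_epochs=20, patience=3):
    """Train with early stopping on the validation loss; return best model."""
    best_loss = float("inf")
    best_weights = copy.deepcopy(model.state_dict())
    epochs_no_improve = 0
    history = []  # (train_loss, valid_loss, train_acc) per epoch

    for epoch in range(max_epochs):
        model.train()
        train_loss, train_correct, n_train = 0.0, 0, 0
        for data, target in train_loader:
            optimizer.zero_grad()             # clear gradients from last batch
            output = model(data)
            loss = criterion(output, target)
            loss.backward()                   # gradients via backpropagation
            optimizer.step()                  # update the trainable weights
            train_loss += loss.item() * data.size(0)
            train_correct += (output.argmax(dim=1) == target).sum().item()
            n_train += data.size(0)

        # Validation: no gradient tracking, no parameter updates
        model.eval()
        valid_loss, n_valid = 0.0, 0
        with torch.no_grad():
            for data, target in valid_loader:
                valid_loss += criterion(model(data), target).item() * data.size(0)
                n_valid += data.size(0)
        valid_loss /= n_valid
        history.append((train_loss / n_train, valid_loss, train_correct / n_train))

        if valid_loss < best_loss:            # improvement: remember weights
            best_loss = valid_loss
            best_weights = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                break                         # early stopping

    model.load_state_dict(best_weights)       # restore the best model
    return model, history


# Tiny synthetic problem (hypothetical data) to exercise the loop
torch.manual_seed(0)
x = torch.randn(64, 10)
y = (x[:, 0] > 0).long()
train_loader = DataLoader(TensorDataset(x[:48], y[:48]), batch_size=16, shuffle=True)
valid_loader = DataLoader(TensorDataset(x[48:], y[48:]), batch_size=16)

model = nn.Sequential(nn.Linear(10, 2), nn.LogSoftmax(dim=1))
optimizer = optim.Adam(model.parameters(), lr=0.05)
model, history = train(model, nn.NLLLoss(), optimizer, train_loader, valid_loader)
```

The returned history holds one tuple per epoch, which is what we would plot to inspect the training and validation curves.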

Training Results

We can inspect the training progress by looking at the history.

As expected, the training loss decreases continually with epochs. There is not a massive amount of overfitting, likely because we were using Dropout. With the divergence in losses, there is not much more to gain from further training.

The validation loss behaves erratically because there are so few validation images.

As with the losses, the training accuracy increases while the validation accuracy generally plateaus. The model is able to achieve around 79% accuracy right away, an indication that the convolutional weights learned on ImageNet transfer easily to our dataset.

Note: Here we have only 9 images of each class in our validation dataset.

Testing Our Model

After the model has been trained to the point of no further improvement on the validation data, we need to test it on data it has never seen. For a final estimate of the model’s performance, we use the holdout testing data. Here, we’ll look at individual predictions along with the loss and accuracy on the entire testing dataset.

The predict function makes a prediction on a single image. It returns the top probabilities and classes, along with the image and its real class.

img, top_p, top_classes, real_class = predict(random_test(), model)
top_p, top_classes, real_class
(array([0.5736144,0.26276276,0.15282246,0.01080029], dtype=float32), 
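Both kinds of test, a top-k prediction for one image and loss and accuracy over a whole test DataLoader, can be sketched as follows. These are not the repository’s functions: predict_topk and evaluate are hypothetical names, predict_topk takes an already transformed image tensor instead of a file path, and it returns plain Python lists rather than a NumPy array. The stand-in models and inputs exist only so the sketch runs on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def predict_topk(image_tensor, model, idx_to_class, topk=4):
    """Return the top-k probabilities and class names for one image tensor."""
    model.eval()
    with torch.no_grad():
        log_probs = model(image_tensor.unsqueeze(0))  # add a batch dimension
        probs = torch.exp(log_probs)                  # log probs -> probs
        top_p, top_idx = probs.topk(topk, dim=1)      # sorted, largest first
    top_classes = [idx_to_class[i] for i in top_idx[0].tolist()]
    return top_p[0].tolist(), top_classes


def evaluate(model, loader, criterion):
    """Return (mean loss, accuracy) over every batch in the loader."""
    model.eval()
    total_loss, correct, n = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            total_loss += criterion(output, target).item() * data.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()
            n += data.size(0)
    return total_loss / n, correct / n


idx_to_class = {0: "CNV", 1: "DME", 2: "DRUSEN", 3: "NORMAL"}

# Single prediction with a hypothetical stand-in model and a random "image"
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4), nn.LogSoftmax(dim=1))
top_p, top_classes = predict_topk(torch.randn(3, 8, 8), toy_model, idx_to_class)

# Whole-set evaluation: the "model" just turns given scores into log probs,
# so three of the four fake samples below are classified correctly
scores = torch.eye(4) * 5
targets = torch.tensor([0, 1, 2, 0])
test_loader = DataLoader(TensorDataset(scores, targets), batch_size=2)
test_loss, test_acc = evaluate(nn.LogSoftmax(dim=1), test_loader, nn.NLLLoss())
```

Exponentiating is needed because the model outputs log probabilities; after that, topk returns the probabilities in descending order together with their class indexes.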

Model Investigation

Although the model does well, there are likely steps to take which can make it even better. Often, the best way to figure out how to improve a model is to investigate its errors (note: this is also an effective self-improvement method.)

It seems our model is working well on the test set. Let’s try to get more intuition about its predictions.

The display function shows the picture along with the top-k predictions from the model. The title over the image shows the true class.

These are pretty easy, so I’m glad the model has no trouble!


Conclusion

  • We saw the basics of using PyTorch as well as the concept of transfer learning.
  • We also saw how such image analysis techniques can be applied in the field of medical science.

The complete code is available in my GitHub repository.