Medical Image Classification with Transfer Learning



Introduction

With more powerful network architectures for image analysis being published every week, it’s hard to keep up with the most recent papers (even with resources like Arxiv-Sanity). Moreover, more complicated architectures generally have more trainable parameters, which means longer training times. How can we harness the strengths of large neural networks without needing multiple GPUs, massive amounts of training data, and lots of time? Transfer learning is the answer!

What is Transfer Learning?

Transfer learning means taking a model pre-trained on one problem and reusing it to solve a different problem. Applying transfer learning is an art in and of itself and depends on the following variables:

  • Amount of Data Available
  • Computational Power
  • Complexity of Data
  • Similarity of the Data to the Original Problem

The general rule of thumb is to either:

a) freeze all the pre-trained model’s weights during training

or

b) freeze a subset of the pre-trained model’s weights

In either case, we almost always replace the top layer (i.e. the fully-connected classifier) with a custom layer tailored to our problem, and train that layer on our dataset.
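To make option (a) concrete, here is a minimal NumPy sketch, not the Keras code used in this project: a fixed random projection followed by a ReLU stands in for the frozen pre-trained base (an assumption purely for illustration), and only a new logistic-regression top layer is trained on a toy, hypothetical dataset. The frozen weights are never touched by the update loop.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a pre-trained feature extractor: a fixed projection + ReLU.
# These weights are "frozen": the training loop below never updates them.
W_frozen = rng.normal(size=(64, 16)) / 8.0
W_snapshot = W_frozen.copy()

def frozen_features(x):
    return np.maximum(x @ W_frozen, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Toy data: the label is a simple function of the frozen features.
X = rng.normal(size=(500, 64))
feats = frozen_features(X)
y = (feats[:, 0] > feats[:, 1]).astype(float)

# New, problem-specific top layer: logistic regression trained by gradient descent.
w, b, lr = np.zeros(feats.shape[1]), 0.0, 0.5
for _ in range(2000):
    p = sigmoid(feats @ w + b)
    err = p - y  # gradient of binary cross-entropy with respect to the logits
    w -= lr * feats.T @ err / len(y)
    b -= lr * err.mean()

train_acc = np.mean((sigmoid(feats @ w + b) > 0.5) == y)
print(f"top-layer training accuracy: {train_acc:.3f}")
```

Option (b) would simply move some of the base's weights into the set that the loop updates.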

Why?

In practice, training a large neural network from scratch requires significant data and compute power. In most cases, training a large network locally until it converges would take far too long, even with techniques such as learning-rate scheduling and batch normalization.

Our current understanding of how neural networks extract features from images is that the deeper the data propagates through the network, the more specific the features it extracts. The initial layers detect basic features such as edges, colors, and shapes, whereas the deeper layers learn more data-specific features. This is why we can use a pre-trained model to detect the basic features for us and then add our own layer to learn the data-specific features (read Colah.io for more information).

Medical Image Classification

One field that benefits greatly from transfer learning is healthcare. Good-quality medical imaging datasets are hard to find, and they are generally small (on the order of 10² to 10⁴ images), so training a CNN from scratch without high variance (i.e. without overfitting) is out of the question.

Let’s look at an example of applying transfer learning to diagnose pneumonia from a chest X-ray image. The dataset, obtained from Kaggle, contains |D| = 5,863 images. As described by Paul Mooney (the dataset publisher):

For the analysis of chest x-ray images, all chest radiographs were initially screened for quality control by removing all low quality or unreadable scans. The diagnoses for the images were then graded by two expert physicians before being cleared for training the AI system.

Hence, we can assume that our dataset is of sufficient quality. Shown below are a few images from D:

Pre-processing

D is already split into training / validation / test sets, with |D_train| = 5,126, |D_val| = 16, and |D_test| = 631. With such a small validation set, we decided to merge it into the training set rather than use a separate validation set.

We define our target variable as y ∈ {0, 1}, since we are dealing with a binary classification problem.

First, we looked at the number of samples in each class:

Our classes are imbalanced, with roughly 85% of images being pneumonia-positive. This is expected, since you generally don’t get an X-ray unless something is wrong or the doctor wants to confirm their initial diagnosis.

There are a number of ways to combat this issue: oversampling the under-represented class, undersampling the over-represented class, weighting the loss function, etc. We decided to weight our loss function by computing the relative weights of our two classes and passing them to the loss function during training:

https://gist.github.com/bd31e3ce470ec0e34783e6627c9b7452
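The gist is not reproduced here, but one common way to compute such weights is the "balanced" heuristic used by scikit-learn's compute_class_weight, w_c = n_samples / (n_classes · n_c), which gives the rarer class a proportionally larger weight. The pure-Python sketch below assumes labels are encoded as 1 = pneumonia and 0 = normal (our assumption) and uses a hypothetical 85/15 split; the resulting dict has the shape Keras expects for the class_weight argument of model.fit.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Return {class: n_samples / (n_classes * n_samples_in_class)}."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in sorted(counts.items())}

# Hypothetical labels mirroring the ~85% / 15% split (1 = pneumonia, 0 = normal).
labels = [1] * 85 + [0] * 15
class_weights = balanced_class_weights(labels)
print(class_weights)

# With Keras, the dict would then be passed to training, e.g.:
# model.fit(x_train, y_train, class_weight=class_weights, ...)
```

With these weights each class contributes the same total weight to the loss (15 × 100/30 = 85 × 100/170 = 50), so the minority class is no longer drowned out.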

In addition to converting all of our images to grayscale and normalizing them, we applied the following random transformations during training to reduce overfitting:

  • Horizontal Flips
  • Rotations up to 20 degrees
  • Random Brightness
  • Edge Highlighting

Initially, we used Keras’s ImageDataGenerator for image augmentation, but we found that it offered a limited set of options. We then found imgaug, an image augmentation library with a wider selection of augmentation methods and a simple API:

https://gist.github.com/1db9ee1c852cd1cb228600af5fe87875
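Since the imgaug pipeline lives in the gist, here is a NumPy-only sketch of the four transformations listed above, written from scratch for illustration. It is not imgaug's implementation (in imgaug the closest augmenters would be along the lines of iaa.Fliplr, iaa.Affine(rotate=...), iaa.Multiply and iaa.EdgeDetect). It assumes a single normalized grayscale image with values in [0, 1]; the brightness range and edge-blending strength are hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(42)

def horizontal_flip(img, p=0.5):
    # Mirror the image left-to-right with probability p.
    return img[:, ::-1] if rng.random() < p else img

def rotate(img, max_degrees=20.0):
    # Rotate about the centre by a random angle in [-max_degrees, max_degrees] using
    # nearest-neighbour sampling; output pixels whose source falls outside become 0.
    theta = np.deg2rad(rng.uniform(-max_degrees, max_degrees))
    h, w = img.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]
    # Inverse mapping: for each output pixel, locate its source pixel.
    src_x = np.cos(theta) * (xs - cx) + np.sin(theta) * (ys - cy) + cx
    src_y = -np.sin(theta) * (xs - cx) + np.cos(theta) * (ys - cy) + cy
    src_x = np.rint(src_x).astype(int)
    src_y = np.rint(src_y).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[inside] = img[src_y[inside], src_x[inside]]
    return out

def random_brightness(img, low=0.8, high=1.2):
    # Scale intensities by a random factor, then clip back to [0, 1].
    return np.clip(img * rng.uniform(low, high), 0.0, 1.0)

def edge_highlight(img, max_alpha=0.3):
    # Blend the image with its normalised gradient magnitude to emphasise edges.
    gy, gx = np.gradient(img)
    edges = np.hypot(gx, gy)
    if edges.max() > 0:
        edges = edges / edges.max()
    alpha = rng.uniform(0.0, max_alpha)
    return np.clip((1 - alpha) * img + alpha * edges, 0.0, 1.0)

def augment(img):
    for step in (horizontal_flip, rotate, random_brightness, edge_highlight):
        img = step(img)
    return img

image = rng.random((64, 64))  # stand-in for a normalised grayscale X-ray in [0, 1]
augmented = augment(image)
```

Because every call draws fresh random parameters, each pass over the training set sees a slightly different version of every image.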

Model

We decided to use Google’s pre-trained Inception-v3 model as the basis for our transfer learning. Its Inception modules apply convolutions of several sizes in parallel, so the network effectively learns which kind of convolution to rely on while still capturing both local and higher-level features. Shown below is the model’s architecture:

When it came time to construct the fully-connected layer, a naive approach would be to add a hidden layer or two with a suitable activation function: https://gist.github.com/02d8f57f2d0e4017c5688c80786d64cd

However, such a network architecture is prone to overfitting. Techniques such as dropout, proposed by Hinton et al., have become standard practice. An alternative, described in Lin et al.’s paper Network In Network, is to drop the fully-connected layers altogether and use a global average pooling (GAP) layer instead. The GAP layer reduces the spatial dimensions of a tensor without changing its depth by averaging over all spatial positions. For instance, say we have a rank-3 tensor T with dimensions (h, w, d). Applying GAP preserves the rank of T but yields new dimensions (1, 1, d):

Source: Alexis Cook
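A minimal NumPy illustration of the operation: averaging an (h, w, d) tensor over its two spatial axes leaves one value per channel. Here keepdims=True keeps the (1, 1, d) shape described above; note that Keras's GlobalAveragePooling2D layer by default drops the spatial axes and returns a (batch, d) tensor instead.

```python
import numpy as np

def global_average_pool(t):
    """Average an (h, w, d) tensor over its two spatial axes, keeping rank 3."""
    return t.mean(axis=(0, 1), keepdims=True)

T = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)  # h=2, w=3, d=4
pooled = global_average_pool(T)
print(pooled.shape)     # (1, 1, 4)
print(pooled.ravel())   # [10. 11. 12. 13.]
```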

This technique has several benefits:

  • It adds no trainable weights, which reduces overfitting
  • Averaging over spatial positions makes the network more robust to spatial translations in the input
  • It produces one feature map per class, enforcing correspondence between feature maps and classes
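The first benefit is easy to quantify with a little arithmetic. The sketch below assumes a 299 × 299 input, for which Inception-v3's last convolutional block outputs an 8 × 8 × 2048 feature map, and a single sigmoid output unit; both are assumptions about this setup rather than details taken from the gist.

```python
def dense_params(n_inputs, n_units):
    # A dense layer has one weight per input per unit, plus one bias per unit.
    return n_inputs * n_units + n_units

h, w, d = 8, 8, 2048   # assumed Inception-v3 feature-map shape for a 299x299 input
n_outputs = 1          # assumed single sigmoid unit for the binary label

flatten_head = dense_params(h * w * d, n_outputs)  # Flatten -> Dense: 131,073 params
gap_head = dense_params(d, n_outputs)              # GAP -> Dense: 2,049 params
print(flatten_head, gap_head)
```

The pooling step itself has no parameters, so the head shrinks by a factor of h × w = 64.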

Hence, our final fully-connected layer architecture looked as follows: https://gist.github.com/5ed9261408db9dd76c0a73210e3f69de

We added batch normalization to speed up training and stabilize the network. Combined with dropout, these techniques help reduce overfitting further, which is always a concern when applying deep learning to a small dataset.
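The two regularizers can be sketched in NumPy as follows. This is a simplified illustration of training-mode batch normalization (per-feature standardization over the batch, then a learned scale and shift) and inverted dropout, not Keras's implementation; the feature width of 2048 and the dropout rate of 0.5 are assumptions, not the values from our gist.

```python
import numpy as np

rng = np.random.default_rng(0)

def batch_norm(x, gamma, beta, eps=1e-3):
    # Training-mode batch normalisation: standardise each feature over the batch,
    # then apply the learned scale (gamma) and shift (beta).
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def dropout(x, rate, training=True):
    # Inverted dropout: zero a random fraction `rate` of activations and rescale
    # the survivors by 1 / (1 - rate) so the expected activation is unchanged.
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return x * keep / (1.0 - rate)

features = rng.normal(loc=5.0, scale=3.0, size=(32, 2048))  # e.g. a batch of GAP outputs
normed = batch_norm(features, gamma=np.ones(2048), beta=np.zeros(2048))
dropped = dropout(normed, rate=0.5)
```

At inference time dropout is switched off (training=False), and a full batch-norm layer would use running averages of the mean and variance instead of batch statistics.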

Evaluation

After many hours of hyperparameter tuning, we obtained the following results:

Even though our model is by no means perfect, we achieved high recall, which is what we want most here: we would rather be cautious and have more false positives than false negatives. Because our dataset is imbalanced, accuracy alone is not a good way to evaluate our model. Based on our ROC curve and F1-score, we were able to train a good model.
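Why accuracy misleads here can be shown with a small pure-Python sketch (our own illustration, not the evaluation code behind the results above): on a hypothetical test set with about 85% positives, a classifier that labels every image as pneumonia gets 85% accuracy and perfect recall, while the rank-based ROC AUC exposes it as no better than chance.

```python
def confusion_counts(y_true, y_pred):
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    return tp, fp, fn, tn

def summarize(y_true, y_pred):
    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(y_true)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

def roc_auc(y_true, scores):
    # AUC equals the probability that a random positive outranks a random negative
    # (ties count as one half).
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Hypothetical test set with ~85% positives (1 = pneumonia) and a useless
# classifier that flags every image as pneumonia with the same score.
y_true = [1] * 85 + [0] * 15
always_positive = [1] * 100
print(summarize(y_true, always_positive))
print(roc_auc(y_true, [0.9] * 100))
```

Its pneumonia-class F1 is also flattering (about 0.92), which is why F1 is best read alongside the ROC curve rather than on its own.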

Given more time, we would try unfreezing and training some layers of the Inception model to better fit our data. We would also try an automated process for selecting hyperparameters, whether via random search or AutoML; see Bergstra and Bengio’s Random Search for Hyper-Parameter Optimization for further reading.
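Random search is simple enough to sketch in a few lines. Everything below is hypothetical: the search space (log-uniform learning rate, uniform dropout rate, a few batch sizes) and the toy objective are placeholders, and a real objective would train the model with each configuration and return a score on held-out data.

```python
import math
import random

def sample_config(rng):
    # Hypothetical search space: log-uniform learning rate, uniform dropout rate,
    # and a small set of batch sizes.
    return {
        "learning_rate": 10 ** rng.uniform(-5, -2),
        "dropout": rng.uniform(0.2, 0.6),
        "batch_size": rng.choice([16, 32, 64]),
    }

def random_search(evaluate, n_trials=20, seed=0):
    """Evaluate n_trials independently sampled configs and keep the best one."""
    rng = random.Random(seed)
    best_config, best_score = None, -math.inf
    for _ in range(n_trials):
        config = sample_config(rng)
        score = evaluate(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score

def toy_objective(config):
    # Placeholder only: a real objective would train the model with `config`
    # and return a score (e.g. recall or ROC AUC) on held-out data.
    return -abs(math.log10(config["learning_rate"]) + 3.5) - abs(config["dropout"] - 0.4)

best_config, best_score = random_search(toy_objective, n_trials=50)
print(best_config, best_score)
```

Sampling the learning rate on a log scale is the usual choice because its useful values span several orders of magnitude.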

Interestingly, training was much slower than expected, even though we could load the entire training set into memory. Many people have found that fit_generator is substantially slower than fit. After some research, we concluded that there was no workaround: we needed to augment each batch of images independently, and the only way to do that was to use a generator to create the training batches. Hopefully this will be resolved in the future.
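The generator pattern itself looks roughly like the sketch below: it reshuffles the in-memory images every epoch and applies a fresh augmentation to each image as its batch is built. The function name, the toy data, and the flip used as the augmentation are all hypothetical placeholders; in practice the augment callable would be the imgaug pipeline.

```python
import math
import numpy as np

def augmented_batches(images, labels, batch_size, augment, seed=0):
    """Yield (x, y) batches indefinitely, augmenting every image independently."""
    rng = np.random.default_rng(seed)
    n = len(images)
    while True:
        order = rng.permutation(n)  # reshuffle at the start of every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            x = np.stack([augment(images[i]) for i in idx])
            yield x, labels[idx]

# Hypothetical in-memory data: 10 grayscale 8x8 images with integer labels.
images = np.random.default_rng(1).random((10, 8, 8))
labels = np.arange(10)
batch_size = 4
steps_per_epoch = math.ceil(len(images) / batch_size)

gen = augmented_batches(images, labels, batch_size, augment=lambda img: img[:, ::-1])
first_epoch = [next(gen) for _ in range(steps_per_epoch)]

# In Keras 2.x this kind of generator would be consumed with something like:
# model.fit_generator(gen, steps_per_epoch=steps_per_epoch, epochs=...)
```

Because the generator never terminates, Keras needs steps_per_epoch to know where one epoch ends.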

Conclusion

Transfer learning is a powerful tool in any machine learning practitioner’s toolbox. Being able to harness the power of pre-trained DNNs for your own problem is enormously valuable; as Andrew Ng put it, “Transfer Learning will be the next driver of ML success.”

I hope this article has given you a glimpse into how to apply transfer learning to your own problems!

Source: Deep Learning on Medium