Malaria Detection using Deep Learning



Importing the Libraries
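The original import cell isn't reproduced in this excerpt; a minimal set that covers the steps below might look like this (matplotlib, NumPy, and the torchvision utilities are assumptions based on what follows):

```python
import random

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
```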

Let’s do some Data Exploration now:

  1. First, we will input our data and apply various transforms to it.

a) The dataset contains images of varying shapes and sizes, which would hinder model training, so we resize every image to 128 x 128.

b) We also convert the images into tensors, since that is the format PyTorch works with during training.

The beauty of PyTorch is that it allows us to apply multiple transformations by using very few lines of code.
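A sketch of the pipeline described above, assuming the dataset sits in class-named folders (e.g. Parasitized/ and Uninfected/) under a directory such as ./cell_images (the path is a placeholder):

```python
# Resize every image to 128 x 128 and convert it to a tensor
transform = tt.Compose([
    tt.Resize((128, 128)),
    tt.ToTensor()
])

data_dir = './cell_images'  # placeholder path; point this at the downloaded dataset
dataset = ImageFolder(data_dir, transform=transform)

print(len(dataset), dataset.classes)
```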

Now, we will write a helper function to visualize a few images.

Let’s see a set of images for both classes.
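The helper below is an illustrative version (show_example is a name chosen here, not necessarily the author's); it permutes the tensor so matplotlib can display it, then shows one image from each class:

```python
def show_example(img, label):
    # Tensors are (channels, height, width); matplotlib expects (height, width, channels)
    plt.imshow(img.permute(1, 2, 0))
    plt.title('Label: ' + dataset.classes[label])
    plt.show()

# One image from each class (the indices are illustrative)
show_example(*dataset[0])
show_example(*dataset[-1])
```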

To get reproducible results, we need to set a random seed. The reasoning behind setting a seed can be found here.
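A typical way to fix the seeds; the value 42 is a placeholder, not necessarily what the original used:

```python
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
```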

Now, let us divide our whole set of images into training, validation, and test sets. The training set is, obviously, for training the model, whereas the validation set is for making sure that the training is proceeding in the right direction. The testing set is for testing the model’s performance at the end.
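One way to carve out the three subsets is with random_split; the split sizes below are assumptions, not the article's exact numbers:

```python
val_size = 3000
test_size = 3000
train_size = len(dataset) - val_size - test_size

train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
print(len(train_ds), len(val_ds), len(test_ds))
```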

We will train our model using batches of images. Here, PyTorch’s DataLoader utility comes in handy: it provides an iterable over a given dataset.

We will use the DataLoader to create our batches for training and validation. We make sure the training batches are shuffled each epoch to introduce some randomness into training. We don’t shuffle the validation set, because we only use it to check the model’s performance after each epoch.
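A sketch of the two loaders; the batch size of 128 and the num_workers/pin_memory settings are assumptions:

```python
batch_size = 128

# Shuffle the training batches each epoch; keep the validation order fixed
train_dl = DataLoader(train_ds, batch_size, shuffle=True,
                      num_workers=2, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size * 2,
                    num_workers=2, pin_memory=True)
```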

Now, let’s write a helper function to visualize a batch of images.

Let’s visualize a batch of images now.
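A common way to do this uses torchvision’s make_grid (show_batch is an illustrative name):

```python
def show_batch(dl):
    # Grab one batch and lay the images out in a grid
    for images, labels in dl:
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images, nrow=16).permute(1, 2, 0))
        plt.show()
        break

show_batch(train_dl)
```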

Since our data consists of images, we are going to train a convolutional neural network. If you feel intimidated by that, you are not alone. I was also quite scared when I first heard about CNNs. But, to be frank, they are simple to understand and even simpler to implement, thanks to deep learning frameworks like TensorFlow and PyTorch.

You can get a brief understanding of the convolution operation by reading this article by Irhum Shafkat. CNNs use the convolution operation in the initial layers to extract features; the final layers are ordinary linear layers.

We are going to define a Base class for the model which contains various helper methods. These methods can be helpful in the future if we try to solve a similar problem.
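The base class below follows the pattern popularized by the Jovian.ml course this post draws on; the class and method names come from that pattern, so treat it as a sketch rather than the author's exact code:

```python
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   # forward pass
        return F.cross_entropy(out, labels)  # training loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        batch_accs = [x['val_acc'] for x in outputs]
        return {'val_loss': torch.stack(batch_losses).mean().item(),
                'val_acc': torch.stack(batch_accs).mean().item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['val_loss'], result['val_acc']))
```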

Now let’s define our main class, which inherits from the Base class. Let’s name it Malaria2CnnModel.
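The exact architecture isn't shown in this excerpt; the sketch below is one plausible stack of convolution/pooling blocks followed by linear layers, sized for 128 x 128 inputs and two output classes:

```python
class Malaria2CnnModel(ImageClassificationBase):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # output: 32 x 64 x 64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # output: 64 x 32 x 32
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # output: 128 x 16 x 16
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # output: 256 x 8 x 8
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, 2)                 # two classes: parasitized / uninfected
        )

    def forward(self, xb):
        return self.network(xb)

model = Malaria2CnnModel()
```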

Training a deep learning model is very time-consuming on a CPU. Platforms like Kaggle and Google’s Colab offer free GPU compute for training models. The helper functions below check whether a GPU is available on our system; if so, we can move our data and model onto the GPU for faster computation.
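These helpers follow the same common pattern; get_default_device and to_device are the usual names in that pattern:

```python
def get_default_device():
    """Pick a GPU if one is available, otherwise fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

def to_device(data, device):
    """Move a tensor (or a list/tuple of tensors) to the chosen device."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

device = get_default_device()
print(device)
```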

This project was done on Kaggle which provides 30 hrs of GPU compute time every week.

We have defined a DeviceDataLoader class to move our training and validation data onto the chosen device (the model is moved separately).
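A minimal version of that class, again following the familiar pattern:

```python
class DeviceDataLoader():
    """Wrap a DataLoader so every batch it yields is already on the chosen device."""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)

train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)
```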

Now, we will define our fit and evaluate functions. fit() is used to train the model, while evaluate() measures the model’s performance at the end of each epoch. An epoch is one complete pass of the training data through the model; training consists of a series of such passes.
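A sketch of the two functions, built on the base-class methods defined earlier:

```python
@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Training phase
        model.train()
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Validation phase
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history
```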

Let’s transfer our model to a GPU device.

We will evaluate the model to see how it performs on the validation set before training.
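Something like this, using the to_device helper from above (the exact number varies from run to run, but an untrained two-class model hovers around 50% accuracy):

```python
model = to_device(Malaria2CnnModel(), device)
evaluate(model, val_dl)
```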

We get around 50% accuracy before training, which for critical applications in the healthcare sector is far too low. We are going to set the number of epochs and the optimizer (torch.optim.Adam) to be used, and we set the learning rate to 0.001.
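The optimizer and learning rate are as stated; the epoch count below is a placeholder since the excerpt doesn't give the exact number:

```python
num_epochs = 10              # placeholder; adjust as needed
opt_func = torch.optim.Adam
lr = 0.001
```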

We are going to define a few functions to plot losses and accuracies at the end of each epoch.
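Two small matplotlib helpers along these lines:

```python
def plot_accuracies(history):
    accuracies = [x['val_acc'] for x in history]
    plt.plot(accuracies, '-x')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. number of epochs')

def plot_losses(history):
    losses = [x['val_loss'] for x in history]
    plt.plot(losses, '-x')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss vs. number of epochs')
```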

Now, let’s train our model by using the fit function.
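With the hyperparameters above, the call looks like this:

```python
history = fit(num_epochs, lr, model, train_dl, val_dl, opt_func)
```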

By the end of training, our model has improved a lot, from the earlier 50% accuracy to about 95.54%.

Let’s plot the accuracies and losses for each epoch to get a better picture of how training progressed.
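Using the plotting helpers defined earlier:

```python
plot_accuracies(history)
plt.show()
plot_losses(history)
plt.show()
```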

We will now write a function to predict the class of a single image. Then we will run predictions on the whole test set and check the overall accuracy.

Predicting on a single image:
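An illustrative predict_image function and a single test example:

```python
def predict_image(img, model):
    # Add a batch dimension, move the image to the device,
    # and pick the class with the highest score
    xb = to_device(img.unsqueeze(0), device)
    yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return dataset.classes[preds[0].item()]

img, label = test_ds[0]
plt.imshow(img.permute(1, 2, 0))
plt.show()
print('Label:', dataset.classes[label], '| Predicted:', predict_image(img, model))
```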

Now let’s try to predict on the whole test set:
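Wrapping the test set in a DataLoader and reusing evaluate:

```python
test_dl = DeviceDataLoader(DataLoader(test_ds, batch_size * 2), device)
result = evaluate(model, test_dl)
print(result)
```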

We got some pretty good results here. 96% accuracy is a strong result, but I think it can still be improved by tuning the hyperparameters. We could also train for more epochs.

We will stop here and pick up from this point in a separate blog post, where we will use techniques like data augmentation, regularisation, and batch normalization. Thanks for staying with me till the end. Keep an eye out for the next post in a few days.

Future Work:

  1. We will try to apply transfer-learning techniques and see if it increases the accuracy further.
  2. We will try Image Segmentation techniques and Image Localization techniques to cluster the red globular structures and analyze them for evidence.
  3. We will try to use data augmentation techniques to restrict our model from overfitting.
  4. We will learn how to deploy our model in production so that we can showcase our work to people who do not understand code.

References:

  1. PyTorch’s Documentation
  2. The research paper by Nitish Srivastava and Geoffrey Hinton: http://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf
  3. Jovian.ml’s lecture series.