My foray into FastAI by building a Classification Model for Malaria diagnosis

Source: Deep Learning on Medium

By Eumie Jhong

I love learning about Deep Learning, but to be honest, I often feel intimidated about where to begin, since I am a relative “n00b” to the scene. Thankfully, FastAI is an accessible library and platform with awesome tutorials, so beginners like me can start building models quickly! Best of all, FastAI has me working with real data to build models that can make a difference, such as models that help diagnose life-threatening diseases like malaria.

FastAI can run in Google Colab notebooks, which provide a free Tesla K80 GPU for up to 12 hours at a time. This means anyone with internet access can train models on a GPU, and therefore much faster than on a typical laptop CPU, without owning one!

I will use the pre-trained ResNet34, ResNet50, and ResNet101 models because I want to see whether deeper networks (34, 50, and 101 layers) improve the accuracy of my model.
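The number in each model name counts its weighted layers. As a quick sanity check, the sketch below recomputes each depth from the standard block configurations of the torchvision ResNets: ResNet34 stacks two-convolution "basic" blocks, while ResNet50 and ResNet101 stack three-convolution "bottleneck" blocks. The dictionary and function names are my own illustration, not part of FastAI.

```python
# Residual blocks per stage for the standard ResNet variants.
# "basic" blocks hold two 3x3 convolutions; "bottleneck" blocks hold
# three convolutions (1x1, 3x3, 1x1).
RESNET_CONFIGS = {
    "resnet34": {"blocks": [3, 4, 6, 3], "convs_per_block": 2},
    "resnet50": {"blocks": [3, 4, 6, 3], "convs_per_block": 3},
    "resnet101": {"blocks": [3, 4, 23, 3], "convs_per_block": 3},
}


def weighted_depth(name: str) -> int:
    """Count the weighted layers that give a ResNet its name.

    1 stem convolution + every convolution inside the residual blocks
    + 1 final fully connected layer. Shortcut projection convolutions
    are conventionally left out of the count.
    """
    cfg = RESNET_CONFIGS[name]
    return 1 + sum(cfg["blocks"]) * cfg["convs_per_block"] + 1


if __name__ == "__main__":
    for model_name in RESNET_CONFIGS:
        print(model_name, weighted_depth(model_name))
```

ResNet50 has the same number of blocks as ResNet34; its extra depth comes from swapping in the three-convolution bottleneck block, while ResNet101 adds 17 more blocks to the third stage.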

For the first step, I load the dataset into the Colab notebook by mounting my Google Drive with ‘from google.colab import drive’ and selecting the files from there.

After loading my dataset, I import the FastAI vision and metrics modules and set the batch size for training to 32.

Next, I declare the two classes, Parasitized and Uninfected, and set the Path to the folder of cell images so I can start working with them.
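One common way to label images like these, and the one ImageDataBunch.from_folder relies on, is to take each label from the name of the folder the image sits in. The plain-Python sketch below shows that idea. It assumes one sub-folder per class; the cell_images/... paths and file names are hypothetical, and the helper names are illustrative rather than FastAI's API.

```python
from pathlib import PurePosixPath
from typing import List, Tuple

# The two classes from this post. The folder layout is an assumption:
# one sub-folder per class, e.g. cell_images/Parasitized/... (hypothetical).
CLASSES = ["Parasitized", "Uninfected"]
CLASS_TO_IDX = {name: i for i, name in enumerate(CLASSES)}


def label_from_folder(path: str) -> str:
    """Use the name of the image's parent folder as its label."""
    label = PurePosixPath(path).parent.name
    if label not in CLASS_TO_IDX:
        raise ValueError(f"unexpected class folder: {label!r}")
    return label


def label_items(paths: List[str]) -> List[Tuple[str, int]]:
    """Pair each image path with the integer index of its class."""
    return [(p, CLASS_TO_IDX[label_from_folder(p)]) for p in paths]


if __name__ == "__main__":
    example_paths = [
        "cell_images/Parasitized/img_001.png",  # hypothetical file
        "cell_images/Uninfected/img_002.png",   # hypothetical file
    ]
    print(label_items(example_paths))
```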

Since a FastAI model cannot train directly on a folder of image files, I use ImageDataBunch to package the images into labeled training and validation sets, which supply the batches we will train on.
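At its core, a data bunch holds out a validation set and serves the rest in mini-batches. Below is a minimal plain-Python sketch of that split-then-batch step. It assumes a random 20% validation split (valid_pct and seed are my assumptions, since the post does not say how the split was made) and the batch size of 32 chosen earlier, and it leaves out the resizing, augmentation, and tensor conversion the real library also performs.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_valid(items: Sequence[T], valid_pct: float = 0.2,
                      seed: int = 42) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of items and hold out valid_pct of them for validation."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is reproducible
    n_valid = int(len(shuffled) * valid_pct)
    return shuffled[n_valid:], shuffled[:n_valid]


def batches(items: Sequence[T], bs: int = 32) -> List[List[T]]:
    """Chunk items into mini-batches of size bs (the last one may be smaller)."""
    return [list(items[i:i + bs]) for i in range(0, len(items), bs)]


if __name__ == "__main__":
    toy_items = list(range(100))  # stand-ins for labeled image paths
    train, valid = split_train_valid(toy_items)
    print(len(train), len(valid))                   # 80 20
    print([len(b) for b in batches(train, bs=32)])  # [32, 32, 16]
```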

I then use data.show_batch to display a sample of the training images.

Now I train and then validate the model using ResNet34, which has already been pre-trained on a large image dataset (ImageNet) and is a great base model to start from. By default, FastAI keeps these pre-trained layers frozen at this stage and trains only the new classification head.

Next, I plot the learning rate finder and then fit the model with learn.fit_one_cycle, which trains using the 1cycle policy. Rather than picking the learning rate with the lowest loss, I look for where the loss falls most steeply, which appears to be between 1e-03 and 1e-02. I also pass a discriminative learning rate to the pre-trained model just to see what happens, and end up with an accuracy of 96.6%, which is pretty good.
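To make the two ideas in that step concrete, here is a plain-Python sketch: an LR-finder-style helper that returns the learning rate where the loss drops fastest per decade, and a cosine 1cycle schedule that warms the learning rate up to a maximum and then anneals it far below where it started. The loss curve in the usage example is made up, and the schedule's defaults (div_factor=25, pct_start=0.3, final_div=25e4) are my assumption about FastAI v1's, so treat this as an illustration of the technique rather than FastAI's implementation.

```python
import math
from typing import List, Sequence


def lr_range(lo: float, hi: float, n: int) -> List[float]:
    """Exponentially spaced learning rates from lo to hi, as an LR finder sweeps them."""
    ratio = (hi / lo) ** (1 / (n - 1))
    return [lo * ratio ** i for i in range(n)]


def steepest_descent_lr(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """Return the LR at the start of the segment where loss falls fastest per log10(lr)."""
    best_i, best_slope = 0, math.inf
    for i in range(len(lrs) - 1):
        d_log_lr = math.log10(lrs[i + 1]) - math.log10(lrs[i])
        slope = (losses[i + 1] - losses[i]) / d_log_lr
        if slope < best_slope:
            best_i, best_slope = i, slope
    return lrs[best_i]


def one_cycle_lr(step: int, total_steps: int, max_lr: float,
                 div_factor: float = 25.0, pct_start: float = 0.3,
                 final_div: float = 25.0 * 1e4) -> float:
    """Learning rate at `step` under a cosine 1cycle schedule (assumed defaults)."""
    def cos_anneal(start: float, end: float, pct: float) -> float:
        # Cosine interpolation: pct=0 gives start, pct=1 gives end.
        return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))

    warmup = round(total_steps * pct_start)
    if step < warmup:
        # Phase 1: warm up from max_lr / div_factor to max_lr.
        return cos_anneal(max_lr / div_factor, max_lr, step / warmup)
    # Phase 2: anneal from max_lr down to max_lr / final_div.
    pct = (step - warmup) / max(1, total_steps - warmup)
    return cos_anneal(max_lr, max_lr / final_div, pct)


if __name__ == "__main__":
    lrs = lr_range(1e-6, 1e-1, 6)                   # 1e-6, 1e-5, ..., 1e-1
    toy_losses = [0.70, 0.69, 0.66, 0.60, 0.35, 2.0]  # made-up curve
    print(steepest_descent_lr(lrs, toy_losses))      # about 1e-3
    for s in (0, 15, 30, 65, 100):
        print(s, one_cycle_lr(s, total_steps=100, max_lr=1e-3))
```

The warm-up phase lets the pre-trained weights adjust gently before the largest steps, and the long annealing phase settles the model into a good minimum.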

However, I want to fine-tune this model further. I call the unfreeze method so that the lower layers of the ResNet architecture can be trained too, then apply a discriminative learning rate, which gives the earlier layers smaller learning rates than the later ones.
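The sketch below models freezing, unfreezing, and discriminative learning rates in plain Python. In FastAI v1 you can pass something like slice(1e-6, 1e-4) as the learning rate, and, as I understand it, the library spreads it geometrically across the model's layer groups; discriminative_lrs mimics that idea. The LayerGroup class, the three group names, and the 1e-6 to 1e-4 range are hypothetical, not values from my notebook.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LayerGroup:
    name: str
    trainable: bool = True


def freeze(groups: List[LayerGroup]) -> None:
    """Train only the last group (the new head); keep the pre-trained body fixed."""
    for g in groups[:-1]:
        g.trainable = False
    groups[-1].trainable = True


def unfreeze(groups: List[LayerGroup]) -> None:
    """Make every group trainable again."""
    for g in groups:
        g.trainable = True


def discriminative_lrs(lo: float, hi: float, n_groups: int) -> List[float]:
    """Spread learning rates geometrically from lo (earliest group) to hi (head)."""
    if n_groups == 1:
        return [hi]
    step = (hi / lo) ** (1 / (n_groups - 1))
    return [lo * step ** i for i in range(n_groups)]


if __name__ == "__main__":
    # Hypothetical three-way split of a ResNet into layer groups.
    groups = [LayerGroup("early body"), LayerGroup("late body"), LayerGroup("head")]
    freeze(groups)
    print([g.trainable for g in groups])  # only the head is trainable
    unfreeze(groups)
    for g, lr in zip(groups, discriminative_lrs(1e-6, 1e-4, len(groups))):
        print(f"{g.name}: trainable={g.trainable}, lr={lr:.0e}")
```

The reasoning behind the smaller rates for early layers is that they capture generic features such as edges and textures, which transfer well from ImageNet and need only small adjustments, while the later layers must adapt more to cell images.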

Once I fine-tune the model, I get an improved accuracy rate of 96.9%!

I also plot the images with the highest losses, then build a confusion matrix to see how many of the samples are diagnosed correctly and incorrectly. 186 of the 5373 samples are misdiagnosed, which works out to an error rate of about 3.5% (an accuracy of 96.5%). That is pretty good, but I wanted to see whether the deeper ResNet50 and ResNet101 architectures could improve on it.
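Both diagnostics come down to simple bookkeeping. The sketch below builds a two-class confusion matrix from actual and predicted labels, computes the error rate from it, and ranks samples by cross-entropy loss, -log p(true class), which is typically the quantity a "top losses" plot sorts by. The labels and probabilities in the usage example are toy values, not outputs from my model.

```python
import math
from collections import Counter
from typing import Dict, List, Sequence, Tuple

CLASSES = ["Parasitized", "Uninfected"]


def confusion_matrix(actual: Sequence[str],
                     predicted: Sequence[str]) -> Dict[Tuple[str, str], int]:
    """Count (actual, predicted) pairs for every combination of classes."""
    counts = Counter(zip(actual, predicted))
    return {(a, p): counts[(a, p)] for a in CLASSES for p in CLASSES}


def error_rate(cm: Dict[Tuple[str, str], int]) -> float:
    """Off-diagonal cells (misdiagnoses) divided by all samples."""
    total = sum(cm.values())
    wrong = sum(n for (a, p), n in cm.items() if a != p)
    return wrong / total


def top_losses(probs_of_true_class: Sequence[float],
               k: int) -> List[Tuple[int, float]]:
    """Return (index, loss) for the k samples with the largest -log p(true class)."""
    losses = [(i, -math.log(p)) for i, p in enumerate(probs_of_true_class)]
    return sorted(losses, key=lambda t: t[1], reverse=True)[:k]


if __name__ == "__main__":
    actual = ["Parasitized", "Parasitized", "Uninfected", "Uninfected", "Uninfected"]
    predicted = ["Parasitized", "Uninfected", "Uninfected", "Uninfected", "Parasitized"]
    cm = confusion_matrix(actual, predicted)
    print(cm)
    print(error_rate(cm))  # 2 wrong out of 5 -> 0.4
    # Toy probabilities the model assigned to each sample's true class.
    print(top_losses([0.95, 0.2, 0.9, 0.8, 0.35], k=2))  # samples 1 and 4
```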

I apply the same procedure to ResNet50 and ResNet101: first train the pre-trained models, then fine-tune each one with the unfreeze method and discriminative learning rates. Here are the confusion matrices for each model:

The first matrix is for the fine-tuned ResNet50 model (after unfreezing). It misdiagnoses 145 of the 5373 images, an error rate of about 2.7%, for an even higher accuracy of 97.3%.

Next is the confusion matrix for the fine-tuned ResNet101 model (after unfreezing). It misdiagnoses 134 of the 5373 images, an error rate of about 2.5%, for the highest accuracy of the three at 97.5%.
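To compare the three reported results side by side, this short script recomputes each error rate from the misclassified counts given in this post and how many fewer mistakes ResNet101 makes than ResNet34.

```python
from typing import Dict

# Misclassified counts reported in this post, each out of 5373 samples.
TOTAL = 5373
MISCLASSIFIED: Dict[str, int] = {"ResNet34": 186, "ResNet50": 145, "ResNet101": 134}


def error_rate(wrong: int, total: int) -> float:
    """Fraction of samples misdiagnosed."""
    return wrong / total


def relative_reduction(before: int, after: int) -> float:
    """Fraction of errors removed when going from `before` to `after` mistakes."""
    return 1 - after / before


if __name__ == "__main__":
    for name, wrong in MISCLASSIFIED.items():
        err = error_rate(wrong, TOTAL)
        print(f"{name}: error {err:.1%}, accuracy {1 - err:.1%}")
    cut = relative_reduction(MISCLASSIFIED["ResNet34"], MISCLASSIFIED["ResNet101"])
    print(f"ResNet101 vs ResNet34: {cut:.0%} fewer misdiagnoses")
```

Running it prints error rates of 3.5%, 2.7%, and 2.5% (accuracies of 96.5%, 97.3%, and 97.5%) and shows that ResNet101 makes about 28% fewer misdiagnoses than ResNet34. Since each figure comes from a single training run, a small gap such as 145 versus 134 images may partly reflect run-to-run variation rather than depth alone.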

Overall, running all three ResNet models takes relatively little effort, and the steadily improving accuracy suggests that trying deeper architectures can pay off with better results, which we always want for our models. Next time I want to try other techniques, such as progressive resizing, to improve my model even more!

My notebook can be found here

Thanks for the read!