Source: Deep Learning on Medium
Fast.ai is an amazing free online course for getting hands-on deep learning experience. I started watching the lectures a long time ago but never finished an actual project with it, because I only wanted to use it on data related to my PhD research, which is on chest CT scans. It is usually not easy to obtain a large, well-kept, and labeled data set of lung CT scans to start playing with CNNs. This time I finally decided to keep up with the course and follow other folks’ advice: start with a small project, complete it, and then move on to harder projects.
This blog is about a hands-on experience using the transfer learning technique with ResNet to classify types of flowers. I want to show how easy it can be to get started with applying deep learning, and to help others get motivated. When I first started the course homework, I ran into several errors before I could even read the data. Seeing others overcome such errors and build an actual project helped me stay motivated, so I want to contribute to that as well.
Although I am not going to discuss the theory behind CNNs, transfer learning, ResNet, etc., here is a brief introduction to ResNet that explains why I use this specific network as the backbone for my project:
ResNet was inspired by VGGNet and proposed a residual learning approach: each block learns a correction F(x) that is added back to its input x through a skip connection. This eases the difficulty of training deeper networks, which in turn can lead to higher performance.
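To make the idea of residual learning a bit more concrete, here is a minimal sketch in plain Python, with no deep learning library involved. The names `residual_block`, `zero_layer`, and `halve` are hypothetical and exist only for this illustration; a real ResNet block uses convolutions, batch normalization, and ReLU instead of these toy functions. The only point is that the block returns its input plus a learned correction, so a block whose learned part outputs zeros simply passes its input through.

```python
from typing import Callable, List

Vector = List[float]


def residual_block(x: Vector, layer: Callable[[Vector], Vector]) -> Vector:
    """Return layer(x) + x, i.e. the learned part plus the skip connection."""
    fx = layer(x)
    if len(fx) != len(x):
        raise ValueError("layer must preserve the input size for the skip connection")
    return [f + xi for f, xi in zip(fx, x)]


def zero_layer(x: Vector) -> Vector:
    """Toy stand-in for a layer whose weights are all zero (illustrative only)."""
    return [0.0 for _ in x]


def halve(x: Vector) -> Vector:
    """Toy stand-in for some learned transformation (illustrative only)."""
    return [0.5 * xi for xi in x]


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0]
    print(residual_block(x, zero_layer))  # the input passes through unchanged
    print(residual_block(x, halve))       # the learned correction is added to the input
```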
I have obtained the image data from Kaggle’s Flower Recognition. The images are divided into five classes: daisy, tulip, rose, sunflower, dandelion.
1. Data Preparation
Like many of you, I am very interested in building my own CNN by writing all the code myself, line by line, using what I have learned from building neural network models (for example with TensorFlow). For now, however, I want to stick with the Fast.ai library, which also lets me learn PyTorch. (Note: In this blog I will only discuss the high-level application of PyTorch through the Fast.ai library, and will not get into the details of the PyTorch framework.)
I use the fantastic and free Google Colab. To install the library there, I run:
!curl -s https://course.fast.ai/setup/colab | bash
######### Or the following:
#!pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cu92/torch_nightly.html
#!pip install fastai
Sometimes the course developers update the library, a command changes, and my code suddenly raises an error. So it is good practice to always run the code above to keep the library up to date. You can also check your installed version and compare it with the latest version on the course’s GitHub or its forum.
!pip show fastai
Import the library:
from fastai import *
from fastai.vision import *
from fastai.metrics import error_rate
Download the flower data to your Google Drive and mount the drive so that you can access the data in Colab:
from google.colab import drive
drive.mount('/content/gdrive')  # mount Google Drive; Colab asks for authorization
root_dir = "/content/gdrive/My Drive/"
base_dir = root_dir + 'deep_learning_data/'
path = base_dir + 'flowers/'
Read in the data and take a look at it:
I used the Fast.ai library to read in the images, normalize the data, and resize the images to 224 by 224 pixels. One thing worth noting is that in transfer learning (i.e. when using a pre-trained network) we should normalize the images with the same statistics that were used when the model was originally trained (i.e. for ResNet we should use the ImageNet statistics).
tfms = get_transforms()  # default augmentations (random flips, rotation, zoom, lighting, warp)
# Read the data. valid_pct=0.2 holds out 20% of the images as the validation set; bs sets the batch size during training.
# Reducing the image size speeds up training; reducing bs lowers GPU memory use.
data = ImageDataBunch.from_folder(path, valid_pct=0.2,
                                  ds_tfms=tfms, size=224, bs=32, num_workers=4).normalize(imagenet_stats)
# look at the classes:
print(data.classes, data.c, len(data.train_ds), len(data.valid_ds))
# show some images
data.show_batch(rows= 3, figsize=(7,8))
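To make the normalization step less of a black box, here is a minimal sketch in plain Python of what normalizing with ImageNet statistics means for a single RGB pixel. The function name `normalize_pixel` is hypothetical; the mean and standard deviation values are the per-channel ImageNet statistics commonly used with ImageNet-pretrained models. The library applies the same idea to whole batches of images at once.

```python
from typing import Sequence, Tuple

# Per-channel (R, G, B) mean and standard deviation commonly used
# with models pre-trained on ImageNet.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Sequence[float],
                    mean: Sequence[float] = IMAGENET_MEAN,
                    std: Sequence[float] = IMAGENET_STD) -> Tuple[float, ...]:
    """Normalize one RGB pixel whose channels are already scaled to [0, 1]."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))


if __name__ == "__main__":
    # A pixel equal to the ImageNet mean maps to zero in every channel.
    print(normalize_pixel(IMAGENET_MEAN))
    # A pure white pixel maps to positive values in every channel.
    print(normalize_pixel((1.0, 1.0, 1.0)))
```

Using the statistics the network was trained with keeps the inputs in the value range that the pre-trained weights expect.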
2. Perform Classification
Now the fun starts! Let’s classify images!
With the fast.ai library, you can do it in two ways: 1) take a pre-trained model, add a few new layers on top of it, and train only the weights of those new layers, or 2) train and fine-tune the whole model to adapt it better to our dataset.
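1. Start with the same architecture as the pre-trained model and train the model on the new flower data for 4 epochs:
learn = create_cnn(data, models.resnet34, metrics=error_rate)  # named cnn_learner in later fastai v1 releases
learn.fit_one_cycle(4)  # 4 epochs, as in the course notebooks; only the newly added layers are trained at this stage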
2. Unfreeze the model, run a script to find an optimal learning rate, and fine-tune the whole model.
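Here is a minimal sketch of this step, assuming `learn` is the fastai v1 learner created above. The helper names `find_lr` and `fine_tune` are hypothetical and only wrap the calls so they can be reused; the methods they call (`unfreeze`, `lr_find`, `recorder.plot`, `fit_one_cycle`) are the ones used in the course notebooks. The default values come from this post (2 epochs, learning rates between 1e-6 and 1e-4) and should be adjusted after looking at your own plot.

```python
def find_lr(learn):
    """Unfreeze the learner and plot loss against learning rate."""
    learn.unfreeze()        # make every layer trainable, not only the new head
    learn.lr_find()         # try a range of learning rates and record the loss
    learn.recorder.plot()   # inspect this plot to pick a learning rate range


def fine_tune(learn, epochs=2, lr_low=1e-6, lr_high=1e-4):
    """Fine-tune the whole model with a range of learning rates."""
    learn.unfreeze()        # harmless if the learner is already unfrozen
    # slice(lr_low, lr_high): the earliest layers get lr_low, the last layers lr_high
    learn.fit_one_cycle(epochs, max_lr=slice(lr_low, lr_high))
    return learn
```

In Colab I would call `find_lr(learn)` first, look at the plot, and only then call `fine_tune(learn, epochs=2)` with the range read off the plot.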
The optimal learning rate was obtained from the lr_find plot, where the lr_find script has calculated the loss for various learning rates. We see that after 1e-4 the loss starts to increase, so we choose the range (1e-6, 1e-4). This means that the earlier layers of the network are trained with a smaller learning rate, and the learning rate increases through the deeper layers up to the maximum value of 1e-4. In other words, we modify the earlier layers less than the deeper layers, because the deeper layers look at more specific properties of the image, and we want them to adapt to flowers rather than to the images they were originally trained on (e.g. ImageNet).
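To illustrate how a range such as (1e-6, 1e-4) can turn into one learning rate per group of layers, here is a minimal sketch in plain Python. The function name `spread_learning_rates` is hypothetical, and both the geometric (log-even) spacing and the choice of three layer groups are assumptions made for this illustration, not a statement about the library's internals in every version.

```python
from typing import List


def spread_learning_rates(lr_low: float, lr_high: float, n_groups: int) -> List[float]:
    """Space learning rates geometrically from lr_low (earliest layers) to lr_high (last layers)."""
    if n_groups < 1:
        raise ValueError("n_groups must be at least 1")
    if n_groups == 1:
        return [lr_high]
    step = (lr_high / lr_low) ** (1.0 / (n_groups - 1))
    return [lr_low * step ** i for i in range(n_groups)]


if __name__ == "__main__":
    # Three layer groups is an assumption made for this illustration.
    for group, lr in enumerate(spread_learning_rates(1e-6, 1e-4, 3)):
        print(f"layer group {group}: lr = {lr:.0e}")
```

With these assumptions the middle group ends up with a learning rate of about 1e-5, between the two ends of the range.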
In my experience with the flower data, training the model took much longer, so at first I decreased the batch size to 16 and the image size to 100, just to get an idea of the performance.
3. Performance Evaluation
We need to understand where our model performs poorly. We can look at the images that were misclassified:
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
When the model’s layers were frozen and it was trained for 4 epochs, my accuracy was around 93%. I looked at some examples where the model did not work well and saw that it sometimes confuses rose with tulip, or dandelion with sunflower. I also found that some images in the data do not contain any flowers at all, and I had failed to remove these useless images from my dataset before training. So the lesson learned was: check the quality of your data!
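The kind of confusion described above can be summarized by counting which (actual, predicted) pairs are wrong most often. Here is a minimal sketch in plain Python; the function name `confused_pairs` is hypothetical, and the labels in the example are made up purely for illustration and are not results from my model. In fastai v1, the `ClassificationInterpretation` object offers similar summaries, for example `interp.most_confused()` and `interp.plot_confusion_matrix()`.

```python
from collections import Counter
from typing import Iterable, List, Tuple


def confused_pairs(actual: Iterable[str], predicted: Iterable[str],
                   min_count: int = 1) -> List[Tuple[str, str, int]]:
    """Return (actual, predicted, count) for misclassified pairs, most frequent first."""
    pairs = Counter((a, p) for a, p in zip(actual, predicted) if a != p)
    return [(a, p, n) for (a, p), n in pairs.most_common() if n >= min_count]


if __name__ == "__main__":
    # Made-up labels purely for illustration; these are not results from this post.
    actual = ["rose", "rose", "tulip", "dandelion", "daisy", "rose"]
    predicted = ["tulip", "tulip", "tulip", "sunflower", "daisy", "rose"]
    for a, p, n in confused_pairs(actual, predicted):
        print(f"{a} predicted as {p}: {n}")
```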
Unfreezing the model and training it for 2 epochs improved the performance slightly, by about 1%.
Save the model or export it for future use:
learn.export()  # saves the model as 'export.pkl' in learn.path (by default the data folder)
# or run