Original article was published on Deep Learning on Medium
Classifying Natural Scenes using Transfer Learning
I went from knowing little about artificial intelligence to implementing a number of different machine learning (ML) algorithms in just 6 weeks by taking part in a course called Zero to GANs, created by Jovian.ml in collaboration with freeCodeCamp.org.
In the course, we used PyTorch, a beginner-friendly ML library that provides many useful tools for creating, training and evaluating models.
The final assignment of the course was to create an ML project entirely on my own, and this article describes the result of that work.
Picking a Dataset
During the course, I learned about numerous ML algorithms, such as linear regression, deep neural networks (DNNs) and generative adversarial networks (GANs, hence the name Zero to GANs).
I knew from the start that I wanted to create a project dealing with image classification, and that is how I stumbled upon the Intel Image Classification Dataset. This dataset contains images of natural scenes, which need to be classified into the following six categories: buildings, forest, glacier, mountain, sea and street.
Dealing with the Data
Before changing the data, I wanted to get to know the dataset a little better, so that I knew what I was working with.
Firstly, I looked at the size and contents of the dataset. There are around 14,000 RGB images in the training set and 3,000 in the test set, and each image is 150×150 pixels. Here are a few examples from the dataset:
Next, I graphed the distribution of the images among the classes:
As you can see, the distribution is fairly balanced, which makes it easier to create a model that is not biased towards any of the classes.
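Counting the images per class, which is what a distribution chart like this is built from, is straightforward when the data uses the one-folder-per-class layout that torchvision's ImageFolder expects. The sketch below assumes that layout; the helper name class_distribution and the set of accepted file extensions are illustrative choices, not taken from the original project.

```python
from collections import Counter
from pathlib import Path

# File extensions counted as images (an assumption; adjust to the files you have).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def class_distribution(root: str) -> Counter:
    """Count image files in each class subfolder of `root` (hypothetical helper)."""
    counts = Counter()
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip stray files next to the class folders
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

The resulting Counter can be passed to a plotting library, for example matplotlib's bar function, to draw the per-class bar chart.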
To achieve higher accuracy, it is common practice to apply random transformations to the training data; this is called data augmentation. It is a useful technique, because it makes the model “think” that there is more data available for training.
For example, if the model receives the same image twice, but the second time it is horizontally flipped, the model might produce an entirely different output, because it views the image as a different one.
Training with data augmentation can help the model generalize better; that is, it is likely to perform better on data it has never seen before.
I applied several transformations to the training dataset, such as randomly converting the image to grayscale, randomly rotating it, randomly erasing part of it and several others.
Here is an example image before and after applying transformations:
The test dataset does not need to be augmented, because it is only used to evaluate the model. Therefore, I only converted the test images to PyTorch tensors, so that I could feed them to the model.
You might have noticed that I have not mentioned a validation dataset. That is because I used the test dataset as the validation dataset too. Rather than setting aside part of the training images for validation, this let me keep all of them for training, which was an easy way to improve the accuracy of the model, given that the dataset is not very large.
After dealing with the data, it was time to choose an appropriate model for training. There are many to choose from, but in week 4 we learned about convolutional neural networks (CNNs), which are among the most effective deep learning architectures for image classification. For this reason, I decided to use convolutional layers in my model.
During the course, we also learned about transfer learning. The idea of transfer learning is to reuse a network architecture that is already proven to work, typically together with the weights it learned on another task, as the starting point for a new task.
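To make "convolutional layers" concrete, the sketch below passes a dummy 150×150 RGB image through one convolutional block. The channel count and kernel size are arbitrary choices for illustration, not the author's model; the point is that a convolution produces feature maps while sharing a small set of weights across the whole image.

```python
import torch
from torch import nn

# One convolutional block: 3 input channels (RGB) -> 16 feature maps.
conv_block = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),  # padding=1 keeps 150x150
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),  # halves height and width
)

x = torch.randn(1, 3, 150, 150)  # a batch containing one dummy image
print(conv_block(x).shape)       # torch.Size([1, 16, 75, 75])

# Only 16 filters of size 3x3x3 plus 16 biases, regardless of image size.
print(sum(p.numel() for p in conv_block.parameters()))  # 448
```

A fully connected layer mapping the same input to even a single 16-unit layer would need over a million weights, which is one reason convolutions suit images so well.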
I tried many different architectures, including ResNet34, ResNet50 and ResNeXt-101 32x8d. One of the advantages of using these models is that versions already trained on ImageNet (the ImageNet-1k subset of about 1.2 million images) are available, and if we want to, we can use the weights of the pre-trained network as a starting point instead of random values. This way, the network might learn faster, because the first few convolutional layers recognize general, low-level features such as edges and textures, so their weights do not need to change much.
Training the Models
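To train the models, I used the One Cycle Policy, which starts with a low learning rate, gradually raises it to a maximum value and then lowers it again, adjusting it after every batch. The high learning rates in the middle of the cycle act as a form of regularization, which helps to avoid overfitting and hopefully leads the model to a flatter minimum that generalizes better. Here is an image of how the learning rate might change during one cycle: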
Training a model also requires some experimentation, because several hyperparameters need to be tuned, and each of them can affect the outcome of training. These hyperparameters include the maximum learning rate, the number of epochs, the gradient clipping threshold and the weight decay, among others.
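PyTorch ships the One Cycle Policy as torch.optim.lr_scheduler.OneCycleLR, and two of the hyperparameters above, the maximum learning rate and the number of epochs, are passed to it directly. The sketch below steps the scheduler once per batch and records the learning rate; the stand-in model and the values of max_lr, epochs and steps_per_epoch are assumptions chosen for illustration.

```python
import torch
from torch import nn

# Hypothetical values: a stand-in model and illustrative hyperparameters.
model = nn.Linear(10, 6)
epochs, steps_per_epoch, max_lr = 8, 50, 0.01

optimizer = torch.optim.Adam(model.parameters(), lr=max_lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=steps_per_epoch
)

lrs = []
for epoch in range(epochs):
    for batch in range(steps_per_epoch):
        # In real training, the forward pass and loss.backward() would come first.
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used for this batch
        scheduler.step()  # advance the schedule once per batch
```

With the default settings, the learning rate starts at max_lr / 25, peaks at max_lr after about 30% of the steps and then anneals to a value far below the starting point; plotting lrs shows this rise-then-fall curve.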
It is also important to choose a suitable optimizer and loss function, so that training reduces the loss efficiently. I found Adaptive Moment Estimation (Adam) to work well as the optimization algorithm, and for the loss function I chose cross-entropy, the standard choice for multi-class classification.
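Putting these pieces together, a training loop with Adam, cross-entropy loss, weight decay, gradient clipping and a per-batch one-cycle schedule might look like the sketch below. The optimizer and loss come from the text; the helper name fit_one_cycle and the default weight_decay and grad_clip values are placeholders, not the settings used in the project.

```python
import torch
import torch.nn.functional as F
from torch import nn


def fit_one_cycle(model, train_dl, epochs, max_lr, weight_decay=1e-4, grad_clip=0.1):
    """Train with Adam, cross-entropy and OneCycleLR; return the mean loss of each epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=len(train_dl)
    )
    history = []
    for _ in range(epochs):
        model.train()
        losses = []
        for xb, yb in train_dl:
            loss = F.cross_entropy(model(xb), yb)
            optimizer.zero_grad()
            loss.backward()
            if grad_clip:
                # Clamp every gradient element to [-grad_clip, grad_clip].
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()  # the one-cycle schedule advances after every batch
            losses.append(loss.item())
        history.append(sum(losses) / len(losses))
    return history
```

For the scene classifier, model would be the modified ResNet and train_dl the training DataLoader; tracking the returned per-epoch losses makes it easier to compare different hyperparameter settings.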
I got the best result with a pre-trained ResNet34. The highest test accuracy that I could achieve was 92%, which I think is good considering that this is my first ML project.
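An accuracy figure like this is the fraction of test images whose highest-scoring class matches the true label. The sketch below shows one common way to compute it in PyTorch; the function names are illustrative, and the logits in the example are made up rather than taken from the model.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    preds = outputs.argmax(dim=1)
    return (preds == labels).float().mean().item()


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Accuracy over every batch in `loader`, without tracking gradients."""
    model.eval()
    correct = total = 0
    for xb, yb in loader:
        correct += (model(xb).argmax(dim=1) == yb).sum().item()
        total += yb.numel()
    return correct / total


logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.9, 0.8, 0.1],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 1, 1, 2])
print(accuracy(logits, labels))  # 0.75
```

Counting correct predictions across batches, as evaluate does, gives the exact accuracy even when the last batch is smaller than the others.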
To get a sense of how the model performs on the test dataset, here are some pictures with their original and predicted labels:
And an example that fools the network:
I think it is actually quite hard to decide whether this image belongs to the buildings or the street class. If I had to guess, I would say that these two classes account for most of the misclassifications, because some of their pictures are very similar.
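Predicted labels like the ones in these examples can be obtained by running a single image through the model and mapping the highest-scoring index back to a class name. In this sketch, predict_image is a hypothetical helper, and the class list assumes folder names matching the category names; torchvision's ImageFolder assigns label indices in sorted folder-name order, which for these six names is alphabetical.

```python
import torch
from torch import nn

# Assumed label order: ImageFolder sorts class folder names alphabetically.
CLASSES = ["buildings", "forest", "glacier", "mountain", "sea", "street"]


@torch.no_grad()
def predict_image(img: torch.Tensor, model: nn.Module, classes=CLASSES) -> str:
    """Return the predicted class name for one (3, H, W) image tensor."""
    model.eval()
    logits = model(img.unsqueeze(0))  # add a batch dimension: (1, 3, H, W)
    return classes[logits.argmax(dim=1).item()]
```

Comparing predict_image(img, model) with CLASSES[label] for each test image is enough to collect misclassified examples like the one above.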
Let us see whether my guess is correct using some graphs:
After seeing these graphs, I suspect that the model might misclassify the glacier and mountain classes more often than the buildings and street classes.
Now it is clear where the model fails the most.
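The article does not say exactly how these graphs were produced, but a confusion matrix is a common way to see which pairs of classes a model mixes up. The sketch below is an assumption, not the author's code: confusion_matrix and most_confused are hypothetical helpers operating on tensors of predicted and true labels collected over the test set.

```python
import torch


def confusion_matrix(preds: torch.Tensor, labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """cm[i, j] counts images of true class i that were predicted as class j."""
    flat = labels * num_classes + preds  # encode each (true, predicted) pair as one index
    counts = torch.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def most_confused(cm: torch.Tensor, classes, k: int = 3):
    """Return the k largest off-diagonal cells as (true class, predicted class, count)."""
    off_diag = cm.clone().float()
    off_diag.fill_diagonal_(0)  # ignore correct predictions
    values, idx = off_diag.flatten().topk(k)
    n = cm.shape[0]
    return [(classes[i // n], classes[i % n], int(v))
            for v, i in zip(values.tolist(), idx.tolist())]
```

The diagonal holds the correct predictions, so large off-diagonal cells, such as glacier images predicted as mountain, point directly at the classes the model confuses most.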
I am satisfied with the results of my project, and I am glad that I took part in the course, because I learned so much in such a short period. Working on this project let me experience the power of CNNs, ResNets and transfer learning first-hand.
The great thing about PyTorch is that most of the code is reusable, so this project could easily be adapted to classify other kinds of images, for example medical images.
If you are interested in the code for my project, you can view it here.