Source: Deep Learning on Medium
Set up the environment
First things first. Using a GPU speeds up training, so I tried both Google Colab and Kaggle Notebooks to build this model. Because I like being able to edit (add, change, move around, or delete) my data, I prefer Google Colab, where the data is stored in my Google Drive.
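To confirm that the notebook really is running on a GPU, a quick check along these lines can help. This is only a minimal sketch; the variable name `device` is my own choice for illustration, not something taken from the original notebook.

```python
import torch

# Use the GPU when the runtime provides one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")
```

The model and every batch of data would then be moved to `device` before training.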
Load and check the data
We have 12 classes of images that we would like our PyTorch model to classify. It is a good idea to divide the data into three groups: training data, on which we train the model, and validation and test data, which we use to check that the model does not overfit.
Let’s first check the distribution of these data.
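One way to get those counts is to walk the folders and tally the image files per class. The sketch below assumes a layout of `root/<split>/<class>/<image>` (for example `train/goat/001.jpg`); that layout, the extension list, and the function name `count_images` are assumptions for illustration.

```python
from collections import Counter
from pathlib import Path

# File extensions treated as images (an assumption; adjust to the dataset).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def count_images(root):
    """Return {split: Counter({class_name: image_count})} for root/<split>/<class>/<file>."""
    counts = {}
    for split_dir in sorted(Path(root).iterdir()):
        if not split_dir.is_dir():
            continue
        per_class = Counter()
        for class_dir in sorted(split_dir.iterdir()):
            if class_dir.is_dir():
                per_class[class_dir.name] = sum(
                    1
                    for f in class_dir.iterdir()
                    if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
                )
        counts[split_dir.name] = per_class
    return counts
```

Calling `count_images` on the dataset folder in Google Drive returns one counter per split, which makes an odd class (such as one with 599 instead of 600 images) easy to spot.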
OK, so for each class we have 600 training images (one short for goat), 54 validation images (one extra for goat, most likely because that one image got misplaced), and 54-55 test images.
Looks good to me.
Let’s see a sample of those images.
Looks good to me as well. We can also see that the dataset includes drawings (because dragons don’t exist and all).
Creating the model using PyTorch
One of the good things about PyTorch (as well as other machine learning/deep learning frameworks) is that it provides ready-made helpers for common boilerplate. One of them is loading the train, validation, and test data.
Now to build the model.
I will be using a pretrained ResNet34 and transfer learning to build this classification model. I also tried other pretrained models, such as ResNet101 and VGG 19 with BatchNorm, but ResNet34 gives pretty good performance, so I’m going with that. The pretrained ResNet34 expects input images with a width and height of 224.
Full model architecture is below.
I won’t use anything very complex here. Just two additional FC layers with 512 neurons each, and one output layer with 12 neurons (one for each zodiac sign class, of course).
Training the model
Now comes the first exciting part, training the model.
We simply need to iterate over the train data loader whilst (a) running a forward and backward pass on the model, and (b) measuring the current/running performance of the model. I chose to do (b) every 100 minibatches.
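Steps (a) and (b) can be sketched as a single function. The name `train_one_epoch`, its arguments, and the choice to reset the running statistics after every report are assumptions for illustration, not the original notebook's code.

```python
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader


def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                    optimizer: Optimizer, device="cpu", log_every=100):
    """Run one pass over `loader`; return a list of (batch_index, mean_loss, accuracy)."""
    model.train()
    logs = []
    running_loss, running_correct, running_seen = 0.0, 0, 0
    for batch_index, (images, labels) in enumerate(loader, start=1):
        images, labels = images.to(device), labels.to(device)

        # (a) forward and backward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # (b) accumulate running statistics (criterion is assumed to average over the batch)
        running_loss += loss.item() * labels.size(0)
        running_correct += (outputs.argmax(dim=1) == labels).sum().item()
        running_seen += labels.size(0)

        if batch_index % log_every == 0:
            mean_loss = running_loss / running_seen
            accuracy = running_correct / running_seen
            logs.append((batch_index, mean_loss, accuracy))
            print(f"batch {batch_index}: loss {mean_loss:.3f}, acc {accuracy:.1%}")
            running_loss, running_correct, running_seen = 0.0, 0, 0
    return logs
```

Calling this once per epoch, and evaluating on the validation loader in between, gives the kind of training and validation curves discussed next.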
I chose to train the model over 7-15 epochs. You’ll see why in the chart below.
With every 100 minibatches, the model clearly improves on the training dataset. But when we look at the validation dataset, the performance does not improve nearly as much (the accuracy does creep up a tiny bit over time, but I don’t think it is enough).
But when we take a look at other models with different architectures, the same thing happens.
The second model uses a ResNet50 and is pretty much the same as the first, except that I changed the learning rate from 0.001 to 0.003. The third uses VGG 19 with Batch Norm and a learning rate of 0.001.
Three different combinations of model and parameters tell the same story: validation accuracy does not improve over multiple epochs nearly as much as training accuracy does (especially for the last two models).
We do not concern ourselves too much with the models’ loss, since it mainly measures how ‘confident’ the model is in its predictions, and we are focusing more on accuracy.
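To see why loss and accuracy can tell different stories, here is a small worked example. The probabilities are made up for illustration: both predictions pick the right class, so they count equally towards accuracy, but the hesitant one gets a noticeably higher cross-entropy loss.

```python
import math


def cross_entropy(probabilities, true_index):
    """Negative log of the probability assigned to the true class."""
    return -math.log(probabilities[true_index])


def is_correct(probabilities, true_index):
    """A prediction counts as correct when the true class has the highest probability."""
    predicted = max(range(len(probabilities)), key=probabilities.__getitem__)
    return predicted == true_index


# Two made-up predictions for the same image, whose true class is index 0.
confident = [0.90, 0.05, 0.05]
hesitant = [0.40, 0.30, 0.30]

for name, probs in (("confident", confident), ("hesitant", hesitant)):
    print(f"{name}: correct={is_correct(probs, 0)}, loss={cross_entropy(probs, 0):.3f}")
```

This is how two models can land on nearly the same accuracy while one of them reports a clearly higher loss.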
Test the model
Let’s see if the models are actually good, or if they break apart when they meet the test dataset.
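Measuring loss and accuracy on the test set could look like the sketch below. The function name `evaluate` and its arguments are assumptions for illustration; the same function would work for the validation loader as well.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients are needed when only measuring performance
def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module, device="cpu"):
    """Return (mean loss, accuracy) of `model` over every sample in `loader`."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # criterion is assumed to average over the batch, so weight it by batch size.
        total_loss += criterion(outputs, labels).item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen
```

Weighting each batch loss by its size keeps the mean correct even when the last batch is smaller than the others.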
- ResNet (lr 0.001) — loss: 0.355 acc: 90.5%
- ResNet (lr 0.003) — loss: 0.385 acc: 90.6%
- VGG 19 with Batch Norm — loss: 0.586 acc: 90.8%
Looking at the accuracy, they are pretty much the same. For loss, the ResNet with a learning rate of 0.001 reigns supreme. These numbers are almost the same as our training and validation results, so we can say our model does not overfit (or at least we stopped it just before it did, thanks to the small number of epochs) and works pretty well at classifying Chinese zodiac signs.
I am curious as to which images it has the most trouble with.
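A confusion matrix answers exactly that question. Below is a minimal, dependency-free sketch; the function names are hypothetical, and the labels in the example are made up purely for illustration, not results from this post.

```python
def confusion_matrix(true_labels, predicted_labels, class_names):
    """Count predictions per class: rows are true classes, columns are predicted classes."""
    index = {name: i for i, name in enumerate(class_names)}
    matrix = [[0] * len(class_names) for _ in class_names]
    for truth, prediction in zip(true_labels, predicted_labels):
        matrix[index[truth]][index[prediction]] += 1
    return matrix


def most_confused(matrix, class_names):
    """Return (count, true_class, predicted_class) for the largest off-diagonal cell."""
    return max(
        (matrix[i][j], class_names[i], class_names[j])
        for i in range(len(class_names))
        for j in range(len(class_names))
        if i != j
    )


# Made-up labels purely for illustration.
classes = ["dragon", "goat", "ox"]
truth = ["dragon", "dragon", "dragon", "goat", "ox", "ox"]
predicted = ["dragon", "ox", "ox", "goat", "ox", "dragon"]

matrix = confusion_matrix(truth, predicted, classes)
for name, row in zip(classes, matrix):
    print(f"{name:>6}: {row}")
print("most confused:", most_confused(matrix, classes))
```

The diagonal holds the correct predictions, so any large number off the diagonal points to a pair of classes the model mixes up.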
Loving the numbers in that confusion matrix.
We can see that the model rarely mistakes a goat, but if it does, it is always with
We can also see that the model had a bit of difficulty distinguishing dragon from other zodiac signs. Most notably with oxen (horns) and