Original article was published on Deep Learning on Medium
After preprocessing, the images look slightly different from the originals.
Great! We are done with the preprocessing step and ready to develop our models using transfer learning.
I used transfer learning for model development, testing three pretrained models of varying depth: VGG-16⁹, VGG-19⁹ and ResNet-34¹⁰. All three were pretrained on ImageNet, specifically its 1,000-class image classification subset (the full ImageNet database contains over 14 million images), where each achieved top-tier accuracy when it was published.
The first step is to decide whether to train on a GPU. I have a local NVIDIA GeForce GTX GPU, so I am going to move the model onto the GPU.
Note: setting up the GPU can be a complicated process if you don’t already have CUDA installed on your machine. I highly recommend reading this article for a simple setup process. It sets up a CUDA environment for TensorFlow, but the same setup works for PyTorch.
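Selecting the device takes one line in PyTorch. The snippet below is a minimal sketch that uses the first CUDA GPU PyTorch can see and falls back to the CPU otherwise; the variable name device is a common convention, not something taken from the original code.

```python
import torch

# Use the first CUDA GPU if PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")

# The model and every batch must then live on the same device, e.g.:
# model = model.to(device)
# images, labels = images.to(device), labels.to(device)
```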
Then, I have to “freeze” all of the weights in the pretrained model. This means making the weights untrainable so that they are not updated by gradient descent. To do this, I loop through the model parameters and set requires_grad to False. If you want to learn more about transfer learning in PyTorch, refer to this article.
Here is the code for this:
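A minimal sketch of the freezing step, assuming torchvision's pretrained VGG16 as the starting point (the weights="IMAGENET1K_V1" argument needs torchvision 0.13 or newer). The helper name freeze_weights is hypothetical:

```python
import torch.nn as nn


def freeze_weights(model: nn.Module) -> nn.Module:
    """Disable gradients for every parameter so gradient descent cannot update them."""
    for param in model.parameters():
        param.requires_grad = False
    return model


# Usage with a pretrained network (downloads the ImageNet weights on first use):
# from torchvision import models
# model = freeze_weights(models.vgg16(weights="IMAGENET1K_V1"))
```

Layers created after freezing, such as the new classifier, have requires_grad=True by default, so they are the only ones gradient descent will update.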
Finally, I create a sequence of trainable layers to replace the final classification layer of the model. The original model outputs 1000 classes, but this is a binary classification problem, so I need a two-class output.
I used the Sequential container from the torch.nn library. The nice thing about it is that it accepts an ordered dictionary as input, so I can give each layer a name and keep track of the various layers. I chose to use 4 layers to provide enough trainable parameters to learn from the data. Of course, the number of layers is a hyperparameter that can be tuned, as is the number of neurons in each layer. Powers of two are often used for the number of neurons, so I stuck with that. Finally, you may notice that I used dropout regularization between the layers. Dropout randomly zeroes a fraction of the activations during training, which helps keep the model from overfitting the training data (overfitting would lower accuracy on the validation and testing data).
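Here is a sketch of such a classifier, built with nn.Sequential and an OrderedDict so that each layer gets a name. The layer widths (512, 256, 128) and the dropout probability of 0.5 are illustrative assumptions, not the values used in the original code; 25088 is the number of features that torchvision's VGG16 and VGG19 feed into their classifier (512 channels × 7 × 7).

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def build_classifier(in_features: int = 25088, num_classes: int = 2,
                     p_drop: float = 0.5) -> nn.Sequential:
    """Four named fully connected layers with ReLU and dropout between them."""
    return nn.Sequential(OrderedDict([
        ("fc1", nn.Linear(in_features, 512)),
        ("relu1", nn.ReLU()),
        ("drop1", nn.Dropout(p_drop)),
        ("fc2", nn.Linear(512, 256)),
        ("relu2", nn.ReLU()),
        ("drop2", nn.Dropout(p_drop)),
        ("fc3", nn.Linear(256, 128)),
        ("relu3", nn.ReLU()),
        ("drop3", nn.Dropout(p_drop)),
        ("fc4", nn.Linear(128, num_classes)),  # one raw score (logit) per class
    ]))


# For VGG16/VGG19: model.classifier = build_classifier()
# Quick shape check: 3 flattened VGG feature vectors in, 3 x 2 logits out.
head = build_classifier().eval()
print(head(torch.randn(3, 25088)).shape)  # torch.Size([3, 2])
print(head.fc1)                           # layers can be looked up by name
```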
As I said earlier, I compared three state-of-the-art models. The process of freezing the weights and replacing the final classification layer is similar for all three models, so I am not going to show it for each one here.
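The one detail that does differ between the architectures is where the final layer lives: torchvision's VGG models keep it in model.classifier, while ResNets use model.fc. A minimal sketch for ResNet-34, with a hypothetical helper name and an illustrative replacement head:

```python
import torch.nn as nn


def replace_resnet_head(model: nn.Module, new_head: nn.Module) -> nn.Module:
    """Freeze a ResNet backbone and swap in a new final layer (ResNets call it `fc`)."""
    for param in model.parameters():
        param.requires_grad = False  # freeze the pretrained weights
    model.fc = new_head  # newly created layers are trainable by default
    return model


# Usage (ResNet-34's final layer takes 512 input features):
# from torchvision import models
# resnet = models.resnet34(weights="IMAGENET1K_V1")
# resnet = replace_resnet_head(resnet, nn.Sequential(
#     nn.Linear(resnet.fc.in_features, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 2)))
```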
Now our model is ready for training.
Unlike Keras (including TensorFlow's Keras API), which provides a built-in fit() method, PyTorch expects us to write our own training loop. But this is actually an intuitive process, so let’s begin.
I am going to define a function called train that takes as arguments the model, the training loader, the validation loader, the optimizer, the loss function, the number of epochs and, optionally, a learning rate scheduler. It trains the model and returns four lists: the training accuracy, training loss, validation accuracy and validation loss for each epoch.
The function starts like this:
We then loop over the epochs (one epoch is one full pass through the training data) and initialize variables to keep track of the loss and accuracy metrics (this will be important later in the article).
Then, we loop through the images and target classes in the training data loader, find the model’s prediction, compare the model’s prediction with the ground truth (computing the loss function) and update the weights accordingly.
I also wrote a smaller loop to calculate the number of correct predictions and the total predictions, so we can calculate the accuracy.
Next, we switch the model to evaluation mode (model.eval()) and compute the validation loss and accuracy. The process is similar to the training step, except that the model should not learn from this data, so we don’t update the weights. I created a helper function called validation to calculate the validation loss and accuracy. I don’t show it here, but you can find the full code on my GitHub.
We print the loss and accuracy metrics after each epoch and append them to the lists. You may be wondering why I stored all of the metrics in lists: this will come in handy when we plot the loss and accuracy curves.
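Putting these steps together, here is a minimal sketch of such a train function. It is not the author's code (that is on GitHub): it assumes a two-logit classifier with integer class labels, adds a device argument that the description does not mention, and includes one possible implementation of the validation helper.

```python
import torch


def validation(model, loader, loss_fn, device):
    """Return (average loss, accuracy) on `loader` without updating any weights."""
    model.eval()  # evaluation mode: dropout is switched off
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():  # no gradients are tracked, so nothing is learned here
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += loss_fn(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total


def train(model, train_loader, val_loader, optimizer, loss_fn, epochs, device,
          scheduler=None):
    """Train for `epochs` epochs; return per-epoch accuracy and loss lists."""
    train_acc, train_loss, val_acc, val_loss = [], [], [], []
    model.to(device)
    for epoch in range(epochs):
        model.train()  # training mode: dropout is active again
        running_loss, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()            # clear gradients from the previous batch
            outputs = model(images)          # forward pass: two logits per image
            loss = loss_fn(outputs, labels)  # compare predictions with the ground truth
            loss.backward()                  # backpropagate
            optimizer.step()                 # update only the unfrozen weights
            if scheduler is not None:
                scheduler.step()             # assumes a per-batch scheduler such as CyclicLR
            # Running totals for this epoch's loss and accuracy.
            running_loss += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        v_loss, v_acc = validation(model, val_loader, loss_fn, device)
        train_loss.append(running_loss / total)
        train_acc.append(correct / total)
        val_loss.append(v_loss)
        val_acc.append(v_acc)
        print(f"Epoch {epoch + 1}/{epochs} | "
              f"train loss {train_loss[-1]:.4f}, acc {train_acc[-1]:.3f} | "
              f"val loss {v_loss:.4f}, acc {v_acc:.3f}")

    return train_acc, train_loss, val_acc, val_loss
```

An epoch-based scheduler (for example StepLR) would instead be stepped once per epoch, after the inner loop.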
Now we can use the method to train our model.
For this trial, I used binary cross-entropy loss and the Adam optimizer, as these are standard choices for binary classification tasks. I trained the model for 30 epochs because previous attempts showed that the loss stabilizes around that point.
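A minimal sketch of this setup. One assumption to flag: with a two-output head, nn.CrossEntropyLoss computes the two-class cross-entropy (equivalent to binary cross-entropy on the softmax probability of the positive class), so the sketch uses it; nn.BCEWithLogitsLoss would instead expect a single output logit and float targets. The learning rate of 1e-3 and the helper name are also assumptions. Only the unfrozen parameters are handed to Adam:

```python
import torch
import torch.nn as nn


def make_loss_and_optimizer(model: nn.Module, lr: float = 1e-3):
    """Two-class cross-entropy loss and an Adam optimizer over the unfrozen parameters only."""
    loss_fn = nn.CrossEntropyLoss()  # expects logits of shape (N, 2) and integer labels 0/1
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)
    return loss_fn, optimizer


# Usage: build these once, then pass them to the training function for a 30-epoch run.
# loss_fn, optimizer = make_loss_and_optimizer(model)
```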
This was the output after 30 epochs:
I trained VGG16, VGG19 and ResNet-34 with varying numbers of epochs, batch sizes and learning rate schedulers. These are the best models I obtained for each pretrained network.
Before we get too far into analyzing the models’ performance, let’s look more closely at cyclic learning rate scheduling. According to Smith, the author of the original cyclic learning rate paper, the idea behind this type of scheduling is to let the learning rate vary cyclically between a lower and an upper bound, instead of monotonically decreasing it¹¹.
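Smith’s simplest policy, called “triangular”, moves the learning rate linearly from the lower bound up to the upper bound and back again, spending step_size iterations (batches) on each half of the cycle. A plain-Python sketch of that schedule, with illustrative bounds:

```python
import math


def triangular_lr(iteration: int, base_lr: float, max_lr: float, step_size: int) -> float:
    """Learning rate at a given batch index under Smith's triangular cyclic policy."""
    cycle = math.floor(1 + iteration / (2 * step_size))  # which cycle we are in (1-based)
    x = abs(iteration / step_size - 2 * cycle + 1)        # 1 at the cycle edges, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


# Illustrative bounds: the rate climbs for 4 batches, falls for 4, and then repeats.
for it in range(9):
    print(it, round(triangular_lr(it, 1e-4, 1e-2, 4), 6))
```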
I used cyclic learning rate scheduling because it was easier to tune and showed better results than other schedulers that I tried.
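PyTorch ships this policy as torch.optim.lr_scheduler.CyclicLR. A minimal sketch of pairing it with Adam, using a stand-in linear model and illustrative bounds (the values behind the results above are not given here). The cycle_momentum=False argument matters with Adam: older PyTorch versions try to cycle a momentum hyperparameter, which Adam does not have, and raise an error otherwise.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the trainable classifier head
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=1e-4,          # lower bound of the cycle (illustrative)
    max_lr=1e-2,           # upper bound of the cycle (illustrative)
    step_size_up=4,        # batches spent climbing; descending takes as long by default
    mode="triangular",
    cycle_momentum=False,  # Adam has no "momentum" setting to cycle
)

lrs = []
for _ in range(16):
    # In real training: zero_grad, forward pass and loss.backward() come first.
    optimizer.step()
    scheduler.step()  # advance the cycle once per batch, not once per epoch
    lrs.append(scheduler.get_last_lr()[0])
```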
Alright, back to the model analysis.
The VGG16 model had the highest validation and testing accuracy after 30 epochs, while the VGG19 model had the highest training accuracy. The ResNet-34 model performed worst on all of the sets. VGG16 was the only model that did not overfit, probably because it is the shallowest network and therefore has less capacity to fit overly complex functions. In general, the shallower networks performed better, likely because the dataset was very small. ResNet-34 probably performed poorly because it is deeper than the other networks, so it may have needed a longer training time or a larger batch size.
One thing I noticed while training the various models is that the validation/testing accuracy was sometimes higher than the training accuracy. This is most likely due to the small amount of data and the validation split size. There are around 750 images in total, and I used a validation split of 0.1, meaning there were only about 75 images for validation. So if even a few more images were classified correctly in a given epoch, the validation accuracy would increase by much more than the training accuracy would. This is a limitation of this project, because there isn’t enough data to meaningfully compare the validation accuracy with the training accuracy. However, we can still distinguish between the models’ performance by comparing their validation accuracies, since the validation set is the same size for every model.
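The arithmetic behind this is simple: one image switching from wrong to right moves accuracy by 100 / n percentage points, where n is the number of images in the set. A small sketch using the approximate sizes quoted in this article (about 75 validation images here, and about 550 training images as mentioned in the evaluation section); the function name is hypothetical:

```python
def accuracy_step(n_images: int) -> float:
    """Percentage-point change in accuracy when one image switches between wrong and right."""
    return 100.0 / n_images


val_images = round(750 * 0.1)  # ~750 images with a 0.1 validation split
train_images = 550             # approximate training-set size

print(f"validation: {accuracy_step(val_images):.2f} points per image")    # 1.33
print(f"training:   {accuracy_step(train_images):.2f} points per image")  # 0.18
```

So five extra correct validation images raise validation accuracy by about 6.7 points, while five extra correct training images move training accuracy by less than one point.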
Although the results show that the VGG16 model performed best, the testing dataset was quite small (only 74 images), so that alone is not enough to conclude that VGG16 is the better model. We will look at further analysis in the evaluation section.
Evaluation on Testing Data
Remember how I said storing the accuracy and loss metrics would be important? Now we can plot those values against the number of epochs and visualize how they change as the training process progresses.
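A minimal plotting sketch with matplotlib, assuming the four lists come back in the order described earlier (training accuracy, training loss, validation accuracy, validation loss). The function name plot_history and the two-panel layout are illustrative choices:

```python
import matplotlib

matplotlib.use("Agg")  # off-screen rendering for scripts; remove this line in a notebook
import matplotlib.pyplot as plt


def plot_history(train_acc, train_loss, val_acc, val_loss, title="", path=None):
    """Two panels: loss per epoch (left) and accuracy per epoch (right)."""
    epochs = range(1, len(train_loss) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

    ax_loss.plot(epochs, train_loss, label="training")
    ax_loss.plot(epochs, val_loss, label="validation")
    ax_loss.set_xlabel("epoch")
    ax_loss.set_ylabel("loss")
    ax_loss.legend()

    ax_acc.plot(epochs, train_acc, label="training")
    ax_acc.plot(epochs, val_acc, label="validation")
    ax_acc.set_xlabel("epoch")
    ax_acc.set_ylabel("accuracy")
    ax_acc.legend()

    fig.suptitle(title)
    fig.tight_layout()
    if path is not None:
        fig.savefig(path)  # e.g. "vgg16_curves.png"
    return fig


# Usage, with the tuple of lists returned by the training function:
# fig = plot_history(*history, title="VGG16", path="vgg16_curves.png")
```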
A trend you may notice in the plots is that the validation curves are much noisier than the training curves. This is because the validation set has only 74 images, while the training set has around 550. So a few more mistakes on the validation set change its loss and accuracy far more than a few more mistakes on the training set would. Overall, though, the loss does decrease over time for each model. Also, looking at the training curves, ResNet-34 and VGG19 appear to flatten out towards the end, while VGG16 looks as if it would continue to improve. In future work, I plan to train the models for more epochs to see whether VGG16 keeps improving.
We can continue to evaluate our models based on the Receiver Operating Characteristic (ROC) curve, which plots the true positive rate against the false positive rate as the classification threshold varies.
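To make the definition concrete, here is a plain-Python sketch that sweeps the decision threshold over the model’s predicted probability for the positive class and records one (false positive rate, true positive rate) point per distinct score. The area under the resulting curve (AUC) is computed with the trapezoidal rule. In practice, sklearn.metrics.roc_curve and roc_auc_score do the same job; the function names here are hypothetical.

```python
def roc_points(labels, scores):
    """Return (fpr, tpr) lists for every distinct score threshold, from (0, 0) to (1, 1)."""
    pos = sum(labels)
    neg = len(labels) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need at least one positive and one negative example")
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Everything scoring >= threshold is predicted positive; tied scores form one point.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / neg)
        tpr.append(tp / pos)
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    """Area under the ROC curve via the trapezoidal rule."""
    return sum((fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2 for k in range(1, len(fpr)))


# Toy example: two negatives and two positives.
fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(trapezoid_auc(fpr, tpr))  # 0.75
```

A perfect classifier reaches the top-left corner (FPR 0, TPR 1) and has an AUC of 1.0, while random guessing follows the diagonal with an AUC of about 0.5.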