Original article was published by Javier Herbas on Deep Learning on Medium
Deep Learning Neural Network: Complex vs. Simple Model
Recently I wrote a blog titled “Pneumonia Detection From X-ray Images Using Deep Learning Neural Network” where I presented the results of what I chose to be the best out of 15 different model architectures that I created to solve a binary classification problem. For the readers not familiar with this, what it means is that my model will predict only a “0” or a “1”. A “0” corresponds to NORMAL or NO PNEUMONIA, and a “1” to PNEUMONIA.
Before I continue, don’t worry if you have not read my previous blog. You won’t need it to understand this one, as it contains all the information needed to compare, as the title says, a complex model to a simple one.
In my previous blog I described how I started with a simple 2-layer Convolutional Neural Network (CNN) model using minimal parameters to account for overfitting, and ended up with a more complex architecture made of 5 convolutional blocks, each with dual convolutional layers, plus some hyper-parameters to tune the model.
I called this last model Model_15, and it reached validation and test accuracies of 94.83% and 90.84% respectively. These are really good accuracy values but, as I explained in my previous blog, accuracy alone can be misleading if you are dealing with an unequal number of observations in each class (an unbalanced dataset) or if you have more than two classes in your dataset. If you read my last blog, you already know that the X-ray dataset was very unbalanced; therefore, I should also be looking at “Precision” and “Recall”.
In that blog I described these two metrics only briefly, without any equations or mathematical background, and without explaining their importance for this particular health-related problem. What I will do now is set out some basic theory so that the reader can understand how, with some more time, I kept testing my different models using “Confusion Matrices” and found out that one of my simplest architectures beat the more complex Model_15.
First of all, it is important to understand that, between these two metrics (Precision and Recall), Precision is often treated as the more important one. Why? Because high precision means that an algorithm returns more relevant results than irrelevant ones, whereas high recall means that an algorithm returns most of the relevant results (whether or not irrelevant ones are also returned). Which one matters more depends on the problem, though, and X-ray image classification is one of those cases where Recall is more important, as I will explain later on. For the time being, let’s first understand what a Confusion Matrix is.
A Confusion Matrix is a table that summarises the performance of a classification algorithm, giving you a better idea of what the model is getting right and what types of errors it is making. In other words, it shows the ways in which your classification model is confused when it makes predictions (Figure 1).
Figure 1 is a basic example of a two-class problem Confusion Matrix, where the number “1” represents the positive (having the disease) and the number “0” the negative (not having the disease). So let’s decipher that simple matrix:
- The target variable has two values: positive or negative
- The rows represent the actual values
- The columns represent the predicted values
Not too complicated really, but now to the crucial part of the matrix which is understanding the following four concepts:
- True Positive (TP): The predicted value matches the actual value, meaning that the actual value was positive and the model predicted a positive value.
- True Negative (TN): The predicted value matches the actual value, meaning that the actual value was negative and the model predicted a negative value.
- False Positive (FP): The predicted value was falsely predicted, meaning that the actual value was negative but the model predicted a positive value.
- False Negative (FN): The predicted value was falsely predicted, meaning that the actual value was positive but the model predicted a negative value.
If these concepts are not clear at this point, that’s alright: I have added Figure 2 to make them clearer, and hopefully that, together with the case discussed in this blog, will be enough.
Now that we know what the confusion matrix tells us, let’s move to the dual concept of Precision and Recall and why they are important. Let’s do it through an example that assumes the values from Figure 3:
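The four counts defined above can also be tallied by hand in a few lines of Python. This is only a minimal sketch: the function name `confusion_counts` and the labels are made up for illustration and are not taken from the X-ray dataset.

```python
def confusion_counts(actual, predicted):
    """Return (TP, TN, FP, FN) for binary labels where 1 = positive, 0 = negative."""
    tp = tn = fp = fn = 0
    for a, p in zip(actual, predicted):
        if a == 1 and p == 1:
            tp += 1  # sick, and predicted sick
        elif a == 0 and p == 0:
            tn += 1  # healthy, and predicted healthy
        elif a == 0 and p == 1:
            fp += 1  # healthy, but predicted sick (false alarm)
        else:
            fn += 1  # sick, but predicted healthy (missed case)
    return tp, tn, fp, fn


# Made-up labels for illustration only.
actual = [1, 1, 1, 0, 0, 1, 0, 1]
predicted = [1, 0, 1, 0, 1, 1, 0, 1]

tp, tn, fp, fn = confusion_counts(actual, predicted)
print(f"TP={tp} TN={tn} FP={fp} FN={fn}")
```

Every observation falls into exactly one of the four cells, so the four counts always add up to the number of observations.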
Before we start, look at the data distribution and how unbalanced this dataset is with a 20:1 ratio between the TP and TN. With that in mind, let’s calculate the accuracy and discuss the results:
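In terms of the four confusion-matrix counts, accuracy has the standard definition below: the correct predictions of both classes divided by all predictions.

```latex
\[
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
\]
```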
Not bad, an accuracy of 94.46%. But what is that number really telling us? Is it telling us that the model can predict 94.46% of the time when someone is sick? Or is it telling us that it can predict 94.46% of the time when a patient is not sick? The reality is that you can’t rely on this metric alone because the data is unbalanced and the number is therefore open to misinterpretation, which is why we need to look at Precision and Recall.
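A quick way to see why accuracy can mislead on unbalanced data is to score a "model" that has learned nothing. The sketch below uses a hypothetical test set of 95 sick and 5 healthy patients (not the numbers from Figure 3) and a predictor that labels everyone as sick.

```python
# Hypothetical, deliberately unbalanced test set: 95 sick patients (1), 5 healthy (0).
actual = [1] * 95 + [0] * 5

# A "model" that has learned nothing and labels every patient as sick.
predicted = [1] * len(actual)

correct = sum(a == p for a, p in zip(actual, predicted))
accuracy = correct / len(actual)

# How many of the healthy patients were recognised as healthy?
healthy_total = sum(a == 0 for a in actual)
healthy_found = sum(a == 0 and p == 0 for a, p in zip(actual, predicted))

print(f"accuracy: {accuracy:.2%}")
print(f"healthy patients correctly identified: {healthy_found} of {healthy_total}")
```

The accuracy looks excellent even though the predictor never identifies a single healthy patient, which is exactly the kind of blind spot a single accuracy figure hides.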
Precision tells us how many of the cases predicted as positive actually turned out to be positive:
And Recall tells us how many of the actual positive cases we were able to predict correctly with our model:
Precision and Recall ended up being even better than the accuracy, with values of 97.09% and 99.01% respectively.
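Written with the confusion-matrix counts, the standard definition of Precision divides the true positives by everything the model labelled as positive:

```latex
\[
\text{Precision} = \frac{TP}{TP + FP}
\]
```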
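In the same notation, the standard definition of Recall divides the true positives by all the cases that were actually positive:

```latex
\[
\text{Recall} = \frac{TP}{TP + FN}
\]
```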
Now, let’s stop for a second and return to my previous comment about Precision vs Recall. Precision is more important than Recall when you would rather have fewer FP at the cost of more FN, meaning that a FP is very costly and a FN is not as much. But in the medical field, Recall is more important because you care less about the FP: you want to hit every single positive case.
The easiest way to understand this statement is by analysing how Recall is calculated from the Confusion Matrix (Figure 3). Put simply, Recall is the number of cases where the model predicted that a patient was sick and the patient was indeed sick (TP), divided by the sum of that same number and the cases where the model mistakenly predicted a patient to be healthy when in reality the patient was sick (TP + FN). If you think about it, you as a doctor want a high recall, even if this means that you will sometimes raise a false alarm, because the actual positive cases should not go undetected. And you as a patient will be happier if the doctor tells you that you are sick when you aren’t than if you are told that you are not sick and it turns out that you are. Hence, when it comes to health, Recall overtakes Precision.
Now, hopefully you got my point, so let’s compare the two model architectures (Figures 4 and 5) and look at their metrics.
As Figures 4 and 5 illustrate, neither model is actually a fancy one. They both use basic parameters, with the difference that Model_15’s processing time is about three times longer than that of Model_5, which is mainly due to the dual convolutional layers in each of its 5 convolutional blocks, compared to the single layer that Model_5 has per block. Furthermore, Model_15 has Dropout on the fully connected layers to account for overfitting, which can cause some more processing delays.
Let’s compare the Confusion Matrices of both models (Figures 6 and 7) and then look at the metrics of interest (Precision and Recall) to make a fair comparison of their performance.
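To get a feel for why doubling the convolutional layers in each block makes a model heavier, the sketch below counts convolution parameters for one versus two layers per block. Everything numeric here is an assumption for illustration: the 3x3 kernels, the single-channel input, the filter widths, and the idea that both models have the same number of blocks. None of these are the real Model_5 or Model_15 settings, and parameter count is only a rough proxy for processing time.

```python
def conv_params(kernel, in_channels, out_channels):
    """Trainable parameters of one 2-D convolution: weights plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * out_channels


def total_conv_params(filters, layers_per_block, kernel=3, input_channels=1):
    """Sum the convolution parameters over all blocks of a plain stacked CNN."""
    total = 0
    in_ch = input_channels
    for out_ch in filters:
        for _ in range(layers_per_block):
            total += conv_params(kernel, in_ch, out_ch)
            in_ch = out_ch  # the next layer sees this layer's output channels
    return total


# Hypothetical filter widths; the article does not state the real ones.
filters = [32, 64, 128, 256, 512]

single = total_conv_params(filters, layers_per_block=1)
dual = total_conv_params(filters, layers_per_block=2)
print(f"one conv layer per block : {single:,} parameters")
print(f"two conv layers per block: {dual:,} parameters")
```

Under these assumptions the second layer in each block maps a wide feature map onto an equally wide one, so it adds more parameters than the first layer of the same block.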
Without calculating any metrics yet, let’s focus on the TP and the FN of both matrices. For Model_15 (Figure 6), TP = 289 and FN = 51, whereas for Model_5 (Figure 7), TP = 333 and FN = 7. What this is telling us is that Model_15 will wrongly tell 51 patients that they are not sick when in reality they are, and Model_5 will only do this with 7 patients. At the same time, Model_15 will correctly predict that 289 patients are sick, against 333 for Model_5.
If we put these numbers into our equations, we obtain the metrics in Figure 8.
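The Recall of each model follows directly from the TP and FN values quoted above. The short sketch below only reuses those four numbers; the helper name `recall` is chosen here for illustration.

```python
def recall(tp, fn):
    """Recall = TP / (TP + FN): the share of actual positives the model caught."""
    return tp / (tp + fn)


# TP and FN as quoted from the two confusion matrices (Figures 6 and 7).
models = {
    "Model_15": {"tp": 289, "fn": 51},
    "Model_5": {"tp": 333, "fn": 7},
}

for name, counts in models.items():
    r = recall(counts["tp"], counts["fn"])
    print(f"{name}: recall = {r:.1%}, missed pneumonia cases = {counts['fn']}")
```

Both matrices contain the same 340 actual pneumonia cases (TP + FN), so the difference in Recall comes entirely from how many of them each model misses.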
The metrics suggest that Model_15’s values are more consistent but, again, this is a medical/health problem with an unbalanced dataset, so Recall is the most important metric for us. That suggests that Model_5 is outperforming Model_15, even though Precision and Accuracy are higher for Model_15.
We could not have seen this if we had only looked at the Train, Validation and Test accuracy and scores in the plots. Check them out in Figures 9 and 10 and see the similarities between both models, trained with the same number of epochs (50).
Clearly, the training and validation accuracy and losses for both models are very similar, so selecting a model based on these plots alone is not an easy task. With the help of the Precision and Recall metrics, however, we saw that Model_5 is a better choice for this particular problem.
For those who read my previous blog, you might recall that I didn’t train all my models with the same number of epochs due to time constraints. I was running 10 epochs on average, and if the Validation Score was greater than 75% I increased the epochs to 30 and/or 50 depending on the model’s complexity. Model_5 was a victim of this tight deadline and I only trained it for 10 epochs. With more time on my hands, I went back and ran longer trainings, and that is how I realised that I should have double-checked this particular model. For comparison purposes, Figure 11 shows the accuracy and loss that resulted at epoch number 50 for each model.
Putting all of this together, it is easy to conclude that Model_5 is better suited to classify X-ray images for Pneumonia detection due to its considerably higher Recall (97.9%).
Before I close this blog I want to share three learnings that might sound obvious but could have made me choose differently in terms of which model to keep:
- Early in the process, check your data’s distribution so that you can properly interpret your model after training it. An unbalanced dataset cannot be properly evaluated with accuracy alone; it also needs other metrics such as precision and recall. Confusion Matrices are very useful here.
- Don’t overcomplicate the model’s architecture, as more complexity won’t necessarily translate into better performance. Start simple and add layers as you progress.
- Don’t limit yourself to just one model. Try different options so that you can compare results and reduce uncertainty.
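The first learning above, checking the class distribution early, takes only a few lines. This is a minimal sketch with toy labels; in practice the labels would come from wherever your dataset stores them, and the helper name `class_distribution` is made up for illustration.

```python
from collections import Counter


def class_distribution(labels):
    """Return {label: (count, share_of_total)} for a list of class labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: (n, n / total) for label, n in counts.items()}


# Toy labels for illustration only; these are not the X-ray dataset counts.
labels = ["PNEUMONIA"] * 6 + ["NORMAL"] * 2

for label, (n, share) in class_distribution(labels).items():
    print(f"{label}: {n} images ({share:.0%})")
```

If one class clearly dominates, that is the signal to plan for Precision, Recall and a Confusion Matrix from the start rather than relying on accuracy.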
Keep in mind that my conclusions aren’t saying that the initial model is wrong, only that for this type of problem Model_5 is better suited.
If you have made it this far, thank you for reading my blog. If you have any comments or feedback, don’t hesitate to contact me through my LinkedIn, and if you want to see all the additional progress made on this project, feel free to visit my GitHub.