Original article was published on Deep Learning on Medium
Training Networks to Identify X-rays with Pneumonia
Transfer the learning before the virus!
The year 2020 witnessed the outbreak of the COVID-19 pandemic, which has brought the entire world to a standstill. The scientific community has been working continuously towards a medical breakthrough and a potential cure. It has become a race against the rapid spread of the virus, which is also why we are seeing the fastest-progressing clinical trials ever. Data scientists across the world have been aiding this process by harnessing data, but data collection on COVID-19 patients is still ongoing.
With limited time and data being a challenge today, transfer learning seems like a good solution. It enables us to use models that have been pre-trained on data with a similar structure, for example models pre-trained on patients with similar diseases. It also lets us take advantage of the learning power of deep neural networks, which would require large amounts of data and computational resources if trained from the ground up.
About the Data
The data was sourced from Kaggle. It contains chest X-ray images (anterior-posterior) selected from retrospective cohorts of pediatric patients aged one to five years from Guangzhou Women and Children’s Medical Center, Guangzhou. The goal was to classify these X-ray images as Normal or positive for Pneumonia.
The entire dataset was imported directly into Google Colaboratory using the Kaggle API, and all analysis was done there on a GPU. The code can be found here.
Understanding Residual Networks or ResNets
Residual Networks (ResNets) are convolutional networks that were introduced as a solution to the degradation problem generally faced when using ‘plain’ convolutional networks. ResNets use ‘skip’ connections to jump over layers of a deep network.
The skip connections can be seen as identity mappings that add the output of an earlier layer to the output of a later layer. During forward propagation, the skipped layers only need to learn the difference between the desired output and their input, called ‘the residual’; when extra depth is not needed, this residual can be pushed towards zero. The block's output is then almost equal to the output of the layer where the skip connection began, which reduces the information loss that a deeper architecture would otherwise cause.
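As a minimal sketch of the identity mapping described above, the toy block below computes x + F(x) in plain Python, where F stands in for the skipped layers. The two small dense layers, the helper names, and the weights are illustrative assumptions, not ResNet50's actual convolutional layers; a real block also applies batch normalization and a final activation.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def dense(values, weights):
    """Multiply a weight matrix (given as a list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, values)) for row in weights]


def residual_block(x, w1, w2):
    """Return x + F(x), where F is two toy dense layers with a ReLU between."""
    residual = dense(relu(dense(x, w1)), w2)         # F(x): output of the skipped layers
    return [xi + ri for xi, ri in zip(x, residual)]  # skip connection adds x back


x = [1.0, -2.0, 3.0]
zero_weights = [[0.0] * 3 for _ in range(3)]

# With the residual pushed to zero, the block passes its input through unchanged.
identity_output = residual_block(x, zero_weights, zero_weights)
print(identity_output)
```

When the weights of the skipped layers are all zero, the block reduces to the identity, which is why adding such blocks should not make a deeper network worse than a shallower one.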
An additional benefit of ResNets is that the skip connections also carry the gradient during backpropagation. Because they bypass the layers with non-linear activation functions, the gradient from the later layers reaches the earlier layers more directly, which mitigates the problem of vanishing gradients.
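To illustrate why this helps, the sketch below compares the gradient that reaches the first layer of a plain stack with that of a residual stack, using scalar layers. It relies on the fact that the derivative of x + F(x) with respect to x is 1 + F'(x). The depth of 20 and the local derivative of 0.01 are made-up numbers chosen only for illustration.

```python
def plain_gradient(local_derivatives):
    """Chain rule through plain layers: multiply the local derivatives F'(x)."""
    grad = 1.0
    for d in local_derivatives:
        grad *= d
    return grad


def residual_gradient(local_derivatives):
    """Chain rule through residual blocks: each block contributes 1 + F'(x)."""
    grad = 1.0
    for d in local_derivatives:
        grad *= 1.0 + d
    return grad


derivatives = [0.01] * 20  # hypothetical small derivatives in a 20-layer stack

plain = plain_gradient(derivatives)        # product of many small numbers
residual = residual_gradient(derivatives)  # product of numbers close to 1
print(f"plain: {plain:.3e}, residual: {residual:.3f}")
```

In the plain stack the product of small derivatives collapses towards zero, while the "+1" contributed by each skip connection keeps the residual product near 1.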
For this project, ResNet50 was implemented.
The ResNet50 Architecture: ResNet50 is a 50-layer ResNet which uses a ‘bottleneck design’ for computational efficiency. The bottleneck design means that each residual block is a stack of 3 layers: 1×1, 3×3 and 1×1 convolutions. The 1×1 convolutions on either side decrease and then restore the dimensions (the number of channels), so the 3×3 layer becomes a bottleneck with smaller input and output dimensions. ResNet50 has more than 23 million trainable parameters. The network had been pre-trained on the ImageNet dataset of 1000 classes, with 1.28 million training images, 50k validation images and 100k test images.
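The efficiency gain can be checked with a little arithmetic. The sketch below counts convolution weights (ignoring biases and batch normalization) for a bottleneck block versus two plain 3×3 convolutions at full width. The channel counts of 256 and 64 follow the 256-d bottleneck example in the original ResNet paper and are not stated in this article.

```python
def conv_weights(kernel_size, in_channels, out_channels):
    """Number of weights in a square convolution, ignoring biases."""
    return kernel_size * kernel_size * in_channels * out_channels


wide, narrow = 256, 64  # assumed channel counts for illustration

bottleneck = (
    conv_weights(1, wide, narrow)      # 1x1: reduce 256 -> 64 channels
    + conv_weights(3, narrow, narrow)  # 3x3: operate on the narrow tensor
    + conv_weights(1, narrow, wide)    # 1x1: restore 64 -> 256 channels
)
plain = 2 * conv_weights(3, wide, wide)  # two 3x3 convolutions at full width

print(bottleneck, plain, round(plain / bottleneck, 1))
```

Under these assumptions the bottleneck needs roughly one seventeenth of the weights of the plain pair, which is what makes a 50-layer network affordable.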
‘SKIP’ to the modeling…
Adjusting the pre-trained model for our dataset
The data presented a binary classification problem with two categories: Normal and Pneumonia. Since the original ResNet50 network was built to classify an image into one of 1000 categories, the top part of the network was removed so that the pre-trained architecture and its weights could still be used. The original fully connected layer was replaced with a global average pooling layer, followed by a fully connected (dense) layer and an output layer. Other combinations were tried, but this one gave the best test set performance.
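A minimal Keras sketch of this adjusted architecture might look as follows. The function name, the 224×224 input size, the 128-unit dense layer, the ReLU activation, and the single sigmoid output are assumptions made for illustration; the article does not state the exact values used.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_classifier(weights="imagenet", hidden_units=128):
    """ResNet50 base without its 1000-class top, plus a small binary head."""
    base = keras.applications.ResNet50(
        include_top=False,          # drop the original fully connected classifier
        weights=weights,            # "imagenet" loads the pre-trained weights
        input_shape=(224, 224, 3),  # assumed input size
    )
    x = layers.GlobalAveragePooling2D()(base.output)      # one value per feature map
    x = layers.Dense(hidden_units, activation="relu")(x)  # assumed dense layer size
    outputs = layers.Dense(1, activation="sigmoid")(x)    # probability of pneumonia
    return keras.Model(inputs=base.input, outputs=outputs)
```

Calling `build_classifier()` downloads the ImageNet weights on first use. Passing `weights=None` builds the same architecture with random weights, which is handy for checking shapes without a download.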
Preprocessing for the Model
Adding more channels: The image data generator functions in Keras were used to preprocess the images. The chest X-ray images in the data are grayscale and consist of a single channel, whereas the ResNet50 model was trained on the RGB images of the ImageNet dataset, which have 3 channels. Using the color_mode argument of the generator functions, the grayscale images were converted to 3 channels.
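Conceptually, converting a grayscale image to 3 channels just repeats the single channel three times. The NumPy sketch below shows the idea on a tiny made-up array; it illustrates the result and is not the code Keras runs internally.

```python
import numpy as np

# A made-up 2x2 grayscale "image" with a single channel: shape (2, 2, 1).
gray = np.array([[[0], [64]], [[128], [255]]], dtype=np.uint8)

# Repeat the last axis so every pixel carries the same value in R, G and B.
rgb = np.repeat(gray, 3, axis=-1)

print(gray.shape, "->", rgb.shape)
```

The image carries no extra information after this step; it simply matches the input shape that the pre-trained network expects.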
More image transformations: Further, the images in the training set were augmented using horizontal flip, zoom, height/width shift and shear transformations. The ResNet50 preprocessing function was also applied to the augmented training images and to the original validation and test images.
With the above model the best test accuracy achieved was ~83%!
- The debate around batch normalization in Keras: Some literature suggests that, because the Batch Normalization layer in Keras behaves differently in the training and inference phases, it can create discrepancies in accuracy metrics. Hence, all layers of the ResNet50 base model were frozen except its batch normalization layers. The remaining unfrozen layers (the batch normalization layers and the additional layers) were then trained.
- Tuning of hyperparameters: To improve the slow convergence of the initial model, different values of the learning rate and beta_1 for the Adam optimizer were tried. A learning rate of 0.01 and a beta_1 of 0.9 were chosen. For the batch size, different powers of 2 were tried; a batch size of 32 gave the best test result.
- Custom callback function: Further, it was observed that the model gave the best test accuracy when the validation loss during training went below 0.1. To achieve this, a custom callback function was created to train the model until the validation loss fell below 0.1, with a patience parameter of 2.
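One way to read the stopping rule in the last bullet is: stop once the validation loss has stayed below 0.1 for `patience` consecutive epochs. The plain-Python sketch below implements that reading. The class name and the exact interpretation of the patience parameter are assumptions, since the article does not show its callback.

```python
class ValLossThresholdStopper:
    """Signal a stop once val_loss stays below `threshold` for `patience` epochs in a row."""

    def __init__(self, threshold=0.1, patience=2):
        self.threshold = threshold
        self.patience = patience
        self.epochs_below = 0  # consecutive epochs under the threshold so far

    def update(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.threshold:
            self.epochs_below += 1
        else:
            self.epochs_below = 0  # streak broken, start counting again
        return self.epochs_below >= self.patience


stopper = ValLossThresholdStopper(threshold=0.1, patience=2)
history = [0.40, 0.09, 0.15, 0.08, 0.07, 0.06]  # made-up validation losses

stop_epoch = None
for epoch, loss in enumerate(history):
    if stopper.update(loss):
        stop_epoch = epoch  # zero-based index of the epoch that triggers the stop
        break
print(stop_epoch)
```

In Keras, the same logic could sit inside the `on_epoch_end` method of a `keras.callbacks.Callback` subclass, reading `logs["val_loss"]` and setting `self.model.stop_training = True` when `update` returns `True`.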
The best test accuracy achieved after fine-tuning was 93.75%. The area under the ROC curve (AUROC) was 0.98.
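For reference, both metrics can be computed from the predicted probabilities as sketched below. AUROC is calculated here as the fraction of (pneumonia, normal) pairs in which the pneumonia image receives the higher score, with ties counted as half. The labels and scores are toy values, not the article's predictions.

```python
def accuracy(labels, scores, threshold=0.5):
    """Fraction of samples whose thresholded score matches the label."""
    predictions = [1 if s >= threshold else 0 for s in scores]
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def auroc(labels, scores):
    """Probability that a random positive outranks a random negative."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


labels = [0, 0, 1, 1]           # toy ground truth: 1 = pneumonia
scores = [0.1, 0.4, 0.35, 0.8]  # toy predicted probabilities

print(accuracy(labels, scores), auroc(labels, scores))
```

Accuracy depends on the chosen threshold (0.5 here), whereas AUROC summarizes how well the scores rank positives above negatives across all thresholds.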
While this model performed well on the test set, additional patient data representing varied regions and demographics, along with further hyperparameter tuning, could improve the results. Using a model pre-trained specifically on X-ray images, such as CheXNet, could also give better results. The purpose of this project was not to make any official claims, but rather to aid future research. Any diagnosis should only be made by a medical professional!
According to the World Health Organization, “the most common diagnosis in severe COVID-19 patients is severe pneumonia”. This article presents a use case of transfer learning relevant to COVID-19. A pre-trained ResNet model has been used to classify X-rays of patients as ‘Normal’ or infected with Pneumonia.