Original article was published on Deep Learning on Medium
RSNA Pneumonia Classification
In this project, I use the RSNA Pneumonia dataset provided by Kaggle to classify whether a patient is infected or not. You can check all the details about the competition here.
Note that the competition was mainly about pneumonia detection, but I used the data for classification.
Before we start, you need to accept the competition rules in the Rules section in order to work with the data. After accepting the rules, go to your account page and download the API token (a .json file), which is needed to download the data.
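As a rough sketch of the token step: the code below assumes the downloaded file is named kaggle.json and that the official kaggle command-line client is installed. The helper name install_kaggle_token and the Colab path in the comments are hypothetical, chosen only for illustration.

```python
import shutil
import stat
from pathlib import Path


def install_kaggle_token(token_path, config_dir=None):
    """Copy the API token to the folder the kaggle client reads, readable by the owner only."""
    config_dir = Path(config_dir) if config_dir else Path.home() / ".kaggle"
    config_dir.mkdir(parents=True, exist_ok=True)
    target = config_dir / "kaggle.json"
    shutil.copy(token_path, target)
    # Restrict permissions so other users on the machine cannot read the key.
    target.chmod(stat.S_IRUSR | stat.S_IWUSR)
    return target


# Example usage on Colab (the upload path is an assumption):
# install_kaggle_token("/content/kaggle.json")
# Then, in a shell cell:
# !kaggle competitions download -c rsna-pneumonia-detection-challenge
```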
Downloading and Visualizing data
The very first step in any computer vision problem is to visualize the data. We should look at the images and check their shapes to get an intuition about the data and to make sure they are not corrupted in any way.
The images come with the .dcm (DICOM) extension, so we have to install pydicom to be able to work with them. After downloading the data and installing the required packages:
Notice the bounding boxes, which are there because the data was prepared for the detection task mentioned before.
The bounding box coordinates are available in the stage_2_train_labels.csv file.
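To make the box format concrete, here is a small sketch that draws one box on a grayscale array with NumPy only. It assumes each row gives x and y for the upper-left corner plus a width and a height in pixels; the helper draw_box and the numbers in the example are made up for illustration.

```python
import numpy as np


def draw_box(image, x, y, width, height, value=255, thickness=3):
    """Return a copy of a 2-D grayscale image with a rectangle outline drawn on it."""
    out = image.copy()
    x0, y0 = int(round(x)), int(round(y))
    x1, y1 = int(round(x + width)), int(round(y + height))
    t = thickness
    out[y0:y0 + t, x0:x1] = value  # top edge
    out[y1 - t:y1, x0:x1] = value  # bottom edge
    out[y0:y1, x0:x0 + t] = value  # left edge
    out[y0:y1, x1 - t:x1] = value  # right edge
    return out


# Toy example: a blank 1024x1024 image and one made-up box.
image = np.zeros((1024, 1024), dtype=np.uint8)
boxed = draw_box(image, x=264.0, y=152.0, width=213.0, height=379.0)
```

The original array is left untouched, so the same image can be reused to compare several boxes.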
Validating data and labels match for the classification
After investigating the labels, I found that some patients appear more than once: the labels file has one row per bounding box, and an infected patient can have several boxes. So, I dropped the duplicates, as my goal is only to classify whether the patient is infected or not. After dropping all the unnecessary columns from the train labels CSV file:
Now, I have a CSV file with only 2 columns, patientId and Target, without any duplicates.
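A minimal sketch of that cleanup with pandas, run on a toy table rather than the real file. The column names follow stage_2_train_labels.csv as I understand it (patientId, x, y, width, height, Target); the IDs and numbers are made up.

```python
import pandas as pd

# Toy rows in the layout of the labels file; patient "b" has two boxes.
labels = pd.DataFrame(
    {
        "patientId": ["a", "b", "b", "c"],
        "x": [None, 264.0, 562.0, None],
        "y": [None, 152.0, 152.0, None],
        "width": [None, 213.0, 256.0, None],
        "height": [None, 379.0, 453.0, None],
        "Target": [0, 1, 1, 0],
    }
)

# Keep only what classification needs, then collapse repeated patients.
clean = (
    labels[["patientId", "Target"]]
    .drop_duplicates()
    .reset_index(drop=True)
)

# Every patient should now appear exactly once.
assert clean["patientId"].is_unique
```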
Converting images into the png format
After splitting the data into train and test sets, I had to convert the DICOM files into PNG images in order to feed them to the DataLoader.
I used this script for the conversion:
import os
import cv2
import pydicom

# converting the training data
inputdir = '/content/Train/'
outdir = '/content/train/'
os.makedirs(outdir, exist_ok=True)  # make sure the output folder exists
train_list = sorted(os.listdir(inputdir))
for f in train_list:
    ds = pydicom.dcmread(inputdir + f, force=True)  # read dicom image
    img = ds.pixel_array  # get image array
    cv2.imwrite(outdir + f.replace('.dcm', '.png'), img)  # write png image
Now the images are ready for the transformation and modeling phases.
Transformation and DataLoader
I applied a very basic transformation to the data: resizing the images to 128x128, RandomHorizontalFlip, and, of course, converting them to PyTorch tensors.
Then, I built my own dataset and DataLoader that return each image together with its label, ready for the training phase. Here are the DataLoader parameters:
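As an illustration of the idea (not the original implementation), a custom dataset of this kind could look like the sketch below. It assumes the PNG files are named after the patientId and that the labels live in a DataFrame with patientId and Target columns; the class name PneumoniaDataset and the values in the usage comment are hypothetical.

```python
import os

from PIL import Image
from torch.utils.data import Dataset


class PneumoniaDataset(Dataset):
    """Pairs each PNG image with its 0/1 target from a labels DataFrame."""

    def __init__(self, labels, image_dir, transform=None):
        self.labels = labels.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        row = self.labels.iloc[index]
        path = os.path.join(self.image_dir, row["patientId"] + ".png")
        image = Image.open(path).convert("L")  # single-channel grayscale
        if self.transform is not None:
            image = self.transform(image)
        return image, int(row["Target"])


# Example usage (batch size and paths are placeholders, not the article's values):
# from torch.utils.data import DataLoader
# dataset = PneumoniaDataset(labels_df, "/content/train/", transform=some_transform)
# loader = DataLoader(dataset, batch_size=32, shuffle=True)
```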
After examining different architectures, I found that this architecture works well with the data:
I used ReLU as the activation function, dropout to prevent overfitting, and batch normalization, which, according to my research, helps performance on similar datasets.
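To show how these three ingredients fit together, here is a deliberately small, hypothetical network. It is not the architecture used in this project: the layer sizes, the dropout rate, the single input channel, and the name SmallCNN are all assumptions.

```python
import torch
import torch.nn as nn


class SmallCNN(nn.Module):
    """Three conv blocks (Conv -> BatchNorm -> ReLU -> MaxPool) and a dropout-regularized head."""

    def __init__(self, num_classes=2, dropout=0.3):
        super().__init__()

        def block(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),  # halves height and width
            )

        self.features = nn.Sequential(block(1, 16), block(16, 32), block(32, 64))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(64 * 16 * 16, num_classes),  # 128 -> 64 -> 32 -> 16 after three pools
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = SmallCNN()
logits = model(torch.randn(4, 1, 128, 128))  # a fake batch of four 128x128 images
print(logits.shape)  # torch.Size([4, 2])
```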
import torch.nn as nn
import torch.optim as optim
criterion_scratch = nn.CrossEntropyLoss()
optimizer_scratch = optim.Adam(model.parameters(), lr=0.001)
I used CrossEntropyLoss as the loss function. It applies (log-)softmax internally, so the network does not need its own softmax layer. Finally, I used the Adam optimizer with a learning rate of 0.001.
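A quick check of that softmax remark, on random toy logits: nn.CrossEntropyLoss gives the same value as log-softmax followed by the negative log-likelihood loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 2)             # raw scores for 4 samples, 2 classes
targets = torch.tensor([0, 1, 1, 0])   # made-up labels

loss_ce = nn.CrossEntropyLoss()(logits, targets)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Both routes compute the same number.
same = torch.allclose(loss_ce, loss_manual)
```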
I used Google Colab for training. The data needed a lot of RAM and disk space, so to avoid crashes I trained the model for 5 epochs, saved the parameters, and then loaded them and continued for another 5 epochs. The results of the first 5 epochs:
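The save-and-resume pattern described above could be sketched like this. The helper names and the checkpoint keys are my own choices; saving the optimizer state as well is an assumption that lets Adam continue where it stopped.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, epoch):
    """Store everything needed to continue training later."""
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer in place and return the last finished epoch."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]


# Example usage (file name is a placeholder):
# save_checkpoint("checkpoint.pt", model, optimizer_scratch, epoch=5)
# start_epoch = load_checkpoint("checkpoint.pt", model, optimizer_scratch)
```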
In total, I trained the model for 15 epochs using the settings described above. The results of the final 5 epochs:
Accuracy is not always a good indicator when the classes are imbalanced, but this was just a first look at working with such a dataset!
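A toy example of why accuracy can mislead here. The labels are made up (75 healthy patients, 25 infected), and the "model" simply predicts "not infected" for everyone.

```python
def accuracy_and_recall(y_true, y_pred):
    """Return (accuracy, recall for the positive class 1)."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    true_pos = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    positives = sum(t == 1 for t in y_true)
    accuracy = correct / len(y_true)
    recall = true_pos / positives if positives else 0.0
    return accuracy, recall


# Made-up labels: 75 healthy patients (0) and 25 infected ones (1).
y_true = [0] * 75 + [1] * 25
# A useless model that always predicts "not infected".
y_pred = [0] * 100

accuracy, recall = accuracy_and_recall(y_true, y_pred)
print(f"accuracy={accuracy:.2f}, recall={recall:.2f}")  # accuracy=0.75, recall=0.00
```

The accuracy looks respectable even though not a single infected patient is found, which is why recall or a confusion matrix is worth reporting alongside it.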
There is a lot more to be done to increase the performance. We can increase the performance by:
- Training the model for more than 15 epochs.
- Using a deeper network.
- Applying a loss-weighting technique (for example, class weights in the loss function) to counter the class imbalance.
- Using Transfer Learning.
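For the loss-weighting idea in the list above, one common option is inverse-frequency class weights passed to nn.CrossEntropyLoss. This is a sketch with made-up class counts, not something that was tried in this project.

```python
import torch
import torch.nn as nn

# Made-up class counts for illustration: index 0 = not infected, index 1 = infected.
class_counts = torch.tensor([750.0, 250.0])

# Inverse-frequency weights, scaled so a perfectly balanced dataset would give 1.0 per class.
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

# Mistakes on the rarer class now cost more than mistakes on the common one.
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights)
```

With these counts the infected class gets a weight of 2.0 and the healthy class about 0.67.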
All the implementation details can be found in this GitHub repository.