Predicting relative location of CT slices on axial axis using deep learning

Original article was published on Deep Learning on Medium

Running PCA shows that 212 principal components are enough to explain 95% of the variance in the dataset. In subsequent examples, we still decided to use all 383 features for training.
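As a rough illustration of how such a number can be obtained, the sketch below counts the principal components needed to reach a given share of the variance. It uses plain NumPy and synthetic stand-in data rather than the CT slices dataset. The helper name `components_for_variance` is hypothetical, and centering without rescaling the features is an assumption, since the preprocessing used for the figure above is not stated.

```python
import numpy as np

def components_for_variance(X, threshold=0.95):
    """Return the smallest number of principal components whose
    cumulative explained variance ratio reaches `threshold`."""
    # Center each feature; PCA is defined on mean-centered data
    X_centered = X - X.mean(axis=0)
    # Squared singular values are proportional to the component variances
    singular_values = np.linalg.svd(X_centered, compute_uv=False)
    explained = singular_values ** 2
    ratio = np.cumsum(explained) / explained.sum()
    # First position where the cumulative ratio reaches the threshold
    n = int(np.searchsorted(ratio, threshold)) + 1
    return min(n, len(ratio))

# Synthetic stand-in data (NOT the CT slices dataset): 20 features
# driven by only 3 latent factors plus a little noise
rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 3))
mixing = rng.normal(size=(3, 20))
X_demo = latent @ mixing + 0.01 * rng.normal(size=(500, 20))

n_components = components_for_variance(X_demo, threshold=0.95)
print(f"Components needed for 95% of the variance: {n_components}")
```

On the real feature matrix, the same call would take the array of input features in place of `X_demo`.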

Modeling

For all the modeling strategies detailed below, I followed these steps:

  1. Define a model
  2. Pick some hyperparameters
  3. Train the model
  4. Make predictions on samples
  5. Evaluate on test dataset
  6. Record the values for the loss function and hyperparameter values using Jovian.ml
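The steps above can be sketched as a short PyTorch script. This is a minimal illustration, not the exact code used for the experiments: the data is synthetic, the MSE loss, the SGD optimizer and the hyperparameter values are assumptions, and a plain dictionary (`run_record`, a hypothetical name) stands in for logging to Jovian.ml.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in data (NOT the CT slices dataset)
input_size, output_size = 8, 1
X = torch.randn(256, input_size)
true_w = torch.randn(input_size, output_size)
y = X @ true_w + 0.1 * torch.randn(256, output_size)
train_dl = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
test_dl = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=56)

# 1. Define a model
model = nn.Linear(input_size, output_size)

# 2. Pick some hyperparameters (assumed values, for illustration only)
hyperparams = {"epochs": 20, "lr": 0.05, "batch_size": 32}
loss_fn = nn.MSELoss()  # assumed loss function
optimizer = torch.optim.SGD(model.parameters(), lr=hyperparams["lr"])

def evaluate(model, loader):
    """Average loss over a data loader, without tracking gradients."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            total += loss_fn(model(xb), yb).item() * len(xb)
            count += len(xb)
    return total / count

initial_test_loss = evaluate(model, test_dl)

# 3. Train the model
for epoch in range(hyperparams["epochs"]):
    model.train()
    for xb, yb in train_dl:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()

# 4. Make a prediction on a single held-out sample
with torch.no_grad():
    sample_pred = model(X[200:201])

# 5. Evaluate on the held-out data
final_test_loss = evaluate(model, test_dl)

# 6. Record the loss and hyperparameters (plain dict instead of Jovian.ml)
run_record = {**hyperparams, "test_loss": final_test_loss}
print(run_record)
```

Keeping the hyperparameters in a single dictionary makes it easy to log them next to the resulting loss, whichever tracking tool is used.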

M1. Linear regression

For a regression problem, linear regression is a natural first candidate. PyTorch makes it easy to train a linear regression (LR) model using gradient descent.

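class CTslicesModel(nn.Module):
    def __init__(self):
        super().__init__()
        # one linear layer mapping input_size features to output_size targets
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, xb):
        return self.linear(xb)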

The validation loss after training for 40 epochs is 6.7295. This is quite high.

M2. Neural network with dense hidden layers

A simple neural network of fully connected layers with a decreasing number of neurons (1000, 500, 250, 100) gave a noticeable improvement over LR.

class CTslicesModel(nn.Module):
    def __init__(self):
        super().__init__()
        # fully connected layers with PReLU activations between them
        self.linear = nn.Sequential(
            nn.Linear(input_size, 1000),
            nn.PReLU(),
            nn.Linear(1000, 500),
            nn.PReLU(),
            nn.Linear(500, 250),
            nn.PReLU(),
            nn.Linear(250, 100),
            nn.PReLU(),
            nn.Linear(100, output_size)
        )

    def forward(self, xb):
        return self.linear(xb)
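The decreasing layer widths can also be generated from a list, which makes it easier to experiment with other architectures. The sketch below is one possible way to do that. The helper name `build_mlp` is hypothetical, and `input_size = 383` and `output_size = 1` are assumed values based on the feature count mentioned earlier and on the single regression target.

```python
import torch
import torch.nn as nn

def build_mlp(input_size, hidden_sizes, output_size):
    """Stack a Linear + PReLU pair per hidden width, then a final Linear."""
    layers = []
    prev = input_size
    for width in hidden_sizes:
        layers.append(nn.Linear(prev, width))
        layers.append(nn.PReLU())
        prev = width
    layers.append(nn.Linear(prev, output_size))
    return nn.Sequential(*layers)

# Assumed sizes, for illustration only
input_size, output_size = 383, 1
net = build_mlp(input_size, [1000, 500, 250, 100], output_size)

# Count trainable parameters to get a feel for the model size
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params}")

# Sanity check: a batch of 5 random samples yields 5 predictions
out = net(torch.randn(5, input_size))
print(out.shape)
```

Most of the parameters sit in the first two layers, so changing the first hidden widths has the largest effect on model size.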

The validation loss after training for 30 epochs is 0.6804. With this model, we get a nearly 10x lower validation loss than LR (0.6804 vs. 6.7295) while training for 10 fewer epochs. The performance on the test set is very satisfactory:

Test 1 Target: tensor([54.7135]) Prediction: tensor(54.3984)

Test 2 Target: tensor([59.4413]) Prediction: tensor(60.1773)

Test 3 Target: tensor([81.8738]) Prediction: tensor(82.4381)
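To put these three predictions in perspective, the absolute error of each one can be computed directly from the values printed above. The short sketch below does this in plain Python. Three samples are far too few to characterize the model, so this is only a quick sanity check.

```python
# Target / prediction pairs copied from the three test samples above
samples = [
    (54.7135, 54.3984),
    (59.4413, 60.1773),
    (81.8738, 82.4381),
]

abs_errors = [abs(target - pred) for target, pred in samples]
mean_abs_error = sum(abs_errors) / len(abs_errors)

for i, err in enumerate(abs_errors, start=1):
    print(f"Test {i}: absolute error = {err:.4f}")
print(f"Mean absolute error over these samples: {mean_abs_error:.4f}")
```

All three predictions are within one unit of their targets.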

Conclusions

In this tutorial, we showed that for a complex regression task, a neural network made of fully connected layers clearly outperforms linear regression, reaching a nearly 10x lower validation loss.

I would have liked to try out GANs in this project, but it became apparent that this dataset was not suitable for that. One idea would have been to train a neural network model for regression adversarially. As discussed elsewhere, using a very large and complex model like a GAN for a regression task is like killing mosquitoes with a bazooka.