Should you use FastAI?

Source: Artificial Intelligence on Medium

Photo by Markus Spiske from Pexels

Recently I’ve been studying deep learning with Pytorch and FastAI. I’m always impressed by how fast FastAI models train and how accurate they are with relatively little code, so I wanted to compare what I can get with pure Pytorch versus what I can get with FastAI.

To test empirically whether FastAI is really that good, I chose a Kaggle dataset of chest X-ray images for classifying pneumonia.

1. Pure Pytorch

With the data downloaded from Kaggle, the first thing to do is import all the packages we’ll need.

1.1 Loading the data

Now we have to define the transforms for data augmentation. Here I’ll define some really simple ones.

The data I need to fit my model is organized in folders. There is a folder named “train” and a folder named “validation”. Inside each of those are two folders named “Normal” and “Pneumonia”, which hold the images of each class. This is a really common way to organize data for image classification, and thankfully Pytorch (through torchvision’s ImageFolder class) can load it directly into a Pytorch dataset.

With the datasets and dataloaders created, we can plot a batch of images to check that everything looks right. I’ll define a function called “show_batch” to do so.

1.2 Creating and training a model

With the datasets and dataloaders defined, I now need a model. Pytorch (through torchvision) provides several world-class CNNs pretrained on ImageNet, so I’ll use a ResNet-50 pretrained on ImageNet. I’ll use the CNN only as a feature extractor, so only its fully connected layer will be trained. Training requires four things: the model, a criterion (the loss function that guides training), an optimizer and a learning rate scheduler.

Finally, I have to define the training function.
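A minimal train/validate loop in the style of the torchvision transfer-learning tutorial, sketched under the assumption that `dataloaders` is a dict with “train” and “validation” keys. Stepping the scheduler once per epoch and keeping the weights with the best validation accuracy are design choices, not details confirmed by the post.

```python
import copy
import time

import torch


def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=3, device="cpu"):
    """Train and validate each epoch; return the best-validation model and its accuracy."""
    model = model.to(device)
    best_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    start = time.time()
    for epoch in range(num_epochs):
        for phase in ("train", "validation"):
            is_train = phase == "train"
            model.train(is_train)  # batchnorm/dropout behave differently in each phase
            running_loss, running_correct, n_seen = 0.0, 0, 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Build the autograd graph only when training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_correct += (outputs.argmax(dim=1) == labels).sum().item()
                n_seen += inputs.size(0)
            if is_train:
                scheduler.step()  # one scheduler step per epoch
            epoch_loss = running_loss / n_seen
            epoch_acc = running_correct / n_seen
            print(f"epoch {epoch + 1}/{num_epochs} {phase}: loss {epoch_loss:.4f} acc {epoch_acc:.4f}")
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())
    print(f"training took {time.time() - start:.0f}s")
    model.load_state_dict(best_weights)
    return model, best_acc
```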

Calling this function and training for 3 epochs, the model achieves 93.60% accuracy and takes 3min52s to train. Pretty good, huh? Now let’s see whether a hybrid Pytorch-FastAI model can do better.

2. Hybrid Pytorch-FastAI

The only thing that differs from the pure Pytorch approach is that I’ll use FastAI to train the model. The transforms, the dataset and the model are all the same; only the training function changes.

Now I have everything I need to train this model with FastAI. FastAI learners have a really handy method called “lr_find()”. It trains the model for a short while as the learning rate grows exponentially and records the loss, so you can plot loss against learning rate and pick a good value.
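To make the idea behind lr_find concrete, here is a plain-Pytorch sketch of a learning rate range test. This is not fastai’s implementation; the start/end learning rates, the smoothing factor `beta` and the divergence threshold are illustrative assumptions. The model and optimizer state are restored afterwards, so the sweep doesn’t leave behind weights trained at huge learning rates.

```python
import copy
import math


def lr_range_test(model, loader, criterion, optimizer, start_lr=1e-7, end_lr=10.0,
                  num_iter=100, beta=0.98, device="cpu"):
    """Raise the LR exponentially per mini-batch and record a smoothed training loss."""
    saved_model = copy.deepcopy(model.state_dict())
    saved_opt = copy.deepcopy(optimizer.state_dict())
    model.to(device)
    model.train()
    mult = (end_lr / start_lr) ** (1 / max(num_iter - 1, 1))  # constant per-step LR factor
    lr, avg_loss, best_loss = start_lr, 0.0, math.inf
    lrs, losses = [], []
    batches = iter(loader)
    for i in range(num_iter):
        try:
            inputs, labels = next(batches)
        except StopIteration:  # restart the loader if it runs out of batches
            batches = iter(loader)
            inputs, labels = next(batches)
        for group in optimizer.param_groups:
            group["lr"] = lr
        optimizer.zero_grad()
        loss = criterion(model(inputs.to(device)), labels.to(device))
        loss.backward()
        optimizer.step()
        # Exponential moving average of the loss, with bias correction for early steps.
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed = avg_loss / (1 - beta ** (i + 1))
        lrs.append(lr)
        losses.append(smoothed)
        best_loss = min(best_loss, smoothed)
        if math.isnan(smoothed) or smoothed > 4 * best_loss:  # stop once the loss blows up
            break
        lr *= mult
    model.load_state_dict(saved_model)
    optimizer.load_state_dict(saved_opt)
    return lrs, losses
```

Plotting `losses` against `lrs` on a log-scaled x axis gives the kind of curve lr_find produces.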

Jeremy Howard, co-founder of fast.ai, suggests choosing a learning rate about one decade (a factor of 10) lower than the learning rate at the minimum of the loss plot. Here, 0.01 seems like a good learning rate.
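The rule of thumb itself is simple arithmetic on the sweep’s output. A minimal sketch, assuming `lrs` and `losses` are parallel lists such as the learning rates and smoothed losses recorded during a range test; `suggest_lr` is a hypothetical helper name.

```python
def suggest_lr(lrs, losses):
    """Return the learning rate at the loss minimum, divided by 10 (one decade lower)."""
    if not lrs or len(lrs) != len(losses):
        raise ValueError("lrs and losses must be non-empty and the same length")
    i_min = min(range(len(losses)), key=losses.__getitem__)  # index of the lowest loss
    return lrs[i_min] / 10
```

With made-up values `lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]` and `losses = [0.70, 0.65, 0.45, 0.30, 2.50]`, the minimum sits at 1e-1, so the suggestion is 1e-2, i.e. 0.01.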

Now we have everything we need to fit the model. FastAI has a fit method called “fit_one_cycle”, which is based on this paper (you can check this link for a simpler explanation). In short, the one cycle policy ramps the learning rate up to a maximum and then anneals it back down over the course of training, which tends to make training faster.
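fit_one_cycle is fastai’s implementation, but plain Pytorch also ships a 1cycle scheduler, `torch.optim.lr_scheduler.OneCycleLR`, which follows a similar policy. The sketch below just records the per-batch learning rates it produces for `max_lr=0.01` over 3 epochs; the 100 steps per epoch is an arbitrary assumption, and fastai’s own defaults may differ in detail. With OneCycleLR’s default settings, the learning rate starts at `max_lr / 25`, rises to `max_lr` over the first 30% of the steps, then anneals to a very small value, while momentum cycles in the opposite direction.

```python
from torch import nn, optim


def one_cycle_lrs(max_lr=0.01, epochs=3, steps_per_epoch=100):
    """Record the per-batch learning rates produced by Pytorch's OneCycleLR."""
    model = nn.Linear(1, 1)  # placeholder model; only the LR schedule matters here
    optimizer = optim.SGD(model.parameters(), lr=max_lr, momentum=0.9)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=steps_per_epoch
    )
    lrs = []
    for _ in range(epochs * steps_per_epoch):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()   # in real training: forward pass and backward pass come first
        scheduler.step()   # OneCycleLR is stepped once per batch, not once per epoch
    return lrs
```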

Calling this method and training for 3 epochs, the model achieves 94.79% accuracy and takes 3min50s to train. The Pytorch-FastAI hybrid takes about the same time to train as the pure Pytorch model, but it reaches a higher accuracy. FastAI is awesome! Finally, let’s see what accuracy a pure FastAI model achieves.

3. Pure FastAI

This time I won’t use the transforms I defined previously; I’ll let FastAI take care of everything. Let’s see how little code we need to get a world-class result.

This is all the code we need to get the data into a suitable format for training and to define a model. Now let’s take the same approach as before: search for the best learning rate and train for 3 epochs.

Again, 0.01 seems like a good learning rate. Now let’s call the fit method.

Calling this method and training for 3 epochs the model achieves 96.58% accuracy and takes about 8min to train.

4. Conclusion

To wrap up, the best model was the pure FastAI one with 96.58% accuracy, the second best was the Pytorch-FastAI hybrid one with 94.79% accuracy and the “worst” was the pure Pytorch one with 93.60% accuracy.

It is really nice to see how FastAI can help you get better results. Now imagine that your problem is not image classification and that it is hard to create a databunch with FastAI’s factory methods. Even so, you can define your own custom Pytorch dataset and dataloader and load them into a databunch. You can also define your own really complicated model, your custom loss function and your custom optimizer, and still train the model with FastAI’s “fit_one_cycle” method, which in this comparison did better than a standard fit function. In this link I’ve done exactly that.
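For the custom-data route, a custom Pytorch dataset only needs `__len__` and `__getitem__`. The class below is a hypothetical minimal example: `load_fn` stands in for whatever reads one sample (an image file, a database row, and so on). Train and validation `DataLoader`s built from such datasets are what you would hand to FastAI; in fastai v1, the version that uses DataBunch, that is `DataBunch(train_dl, valid_dl)`, though it is worth checking the API of the version you have installed.

```python
from torch.utils.data import DataLoader, Dataset


class ListDataset(Dataset):
    """Hypothetical minimal dataset: items and labels in two lists, loaded on demand."""

    def __init__(self, items, labels, load_fn, transform=None):
        if len(items) != len(labels):
            raise ValueError("items and labels must have the same length")
        self.items, self.labels = list(items), list(labels)
        self.load_fn, self.transform = load_fn, transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        x = self.load_fn(self.items[idx])  # e.g. open an image file or parse a record
        if self.transform is not None:
            x = self.transform(x)
        return x, self.labels[idx]


def make_loader(dataset, batch_size=32, shuffle=False):
    """Wrap a dataset in a standard Pytorch DataLoader."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
```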

Finally, I have to admit that I cheated with the pure Pytorch model. I chose its learning rate after training the Pytorch-FastAI hybrid model, so I already knew that 0.01 was a good learning rate. Had I picked the pure Pytorch model’s learning rate without that knowledge, the FastAI models would probably have looked even better.