Multi-Task Learning with Pytorch and FastAI

0. Intro

Following the concepts presented in my post named Should you use FastAI?, I'd like to show here how to train a Multi-Task deep learning model using the hybrid Pytorch-FastAI approach. The basic idea of the Pytorch-FastAI approach is to define the dataset and the model using Pytorch code and then use FastAI to fit your model. This approach gives you the flexibility to build complicated datasets and models while still being able to use high level FastAI functionality.

A Multi-Task Learning (MTL) model is a model that is able to do more than one task. It is as simple as that. In general, as soon as you find yourself optimizing more than one loss function, you are effectively doing MTL.

In this demonstration I’ll use the UTKFace dataset. This dataset consists of more than 30k images with labels for age, gender and ethnicity.

[Sample image from the UTKFace dataset. Age: 61, Gender: Female, Ethnicity: Indian]

If you want to skip all the talking and jump straight to the code, here is the link.

1. Why Multi-Task Learning?

When you look at someone's picture and try to predict age, gender and ethnicity, you're not using completely different parts of your brain, right? What I'm trying to say is that you don't try to understand the image in three different ways, one specific to each task. What you're doing is building a single understanding of the image and then decoding that understanding into age, gender and ethnicity. Besides that, knowledge from gender estimation might help with age estimation, knowledge from ethnicity estimation might help with gender estimation, and so forth.

So, why MTL? The bet is that training a single model to do all the tasks we're interested in yields better results than training a separate model for each task. Rich Caruana summarizes the goal of MTL pretty nicely: "MTL improves generalization by leveraging the domain-specific information contained in the training signals of related tasks".

As this is a practical tutorial for training MTL models I won't dig deeper into theory and intuition, but if you want to read more, check out this amazing post by Sebastian Ruder.

2. Creating the Datasets, Dataloaders and Databunch

When you're working on a Deep Learning problem, usually the first thing to care about is how to feed the data to your model so that it can learn. For this specific MTL problem, the Pytorch dataset definition is pretty straightforward.
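Here's a minimal sketch of what that dataset can look like. It assumes the UTKFace naming convention, where the file name starts with [age]_[gender]_[ethnicity]; the image size and the transform handling are illustrative choices:

```python
import math
import torch
from torch.utils.data import Dataset
from fastai.vision import open_image  # FastAI v1

class UTKFaceDataset(Dataset):
    """UTKFace images with (age, gender, ethnicity) targets."""
    def __init__(self, file_paths, tfms=None, size=128):
        self.file_paths = file_paths  # list of pathlib.Path image paths
        self.tfms = tfms              # FastAI transforms, e.g. from get_transforms()
        self.size = size

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, i):
        path = self.file_paths[i]
        # UTKFace encodes the labels in the file name: [age]_[gender]_[ethnicity]_...
        age, gender, ethnicity = (int(v) for v in path.name.split('_')[:3])
        img = open_image(path).apply_tfms(self.tfms, size=self.size)
        # log-scale the age and squash it into [0, 1]; 4.75 ≈ log(116), see below
        y_age = torch.tensor(math.log(age) / 4.75, dtype=torch.float32)
        return img.data, (y_age, gender, ethnicity)
```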

Dataset class definition

You may be questioning why I took the logarithm of the age and then divided it by 4.75. Taking logs means that errors in predicting high ages and low ages will affect the result equally. I'll explain the division by 4.75 later. Another thing to mention is that I'm using FastAI's data augmentations here (I think they're better than the ones from torchvision.transforms).

Once the dataset class is defined, creating the dataloaders and databunch is really easy.
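Something along these lines; train_files and valid_files are assumed to be lists of image paths split beforehand, and the batch size is illustrative:

```python
from fastai.vision import DataBunch, get_transforms

# get_transforms() returns (train transforms, validation transforms)
train_tfms, valid_tfms = get_transforms()
train_ds = UTKFaceDataset(train_files, tfms=train_tfms)
valid_ds = UTKFaceDataset(valid_files, tfms=valid_tfms)

# DataBunch.create wires both datasets into dataloaders
data = DataBunch.create(train_ds, valid_ds, bs=64, num_workers=4)
```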

Creating the datasets, dataloaders and databunch.

3. Creating the Model

Remember that our goal here is, given an image, to predict age, gender and ethnicity. Predicting age is a regression problem with a single output, predicting gender is a classification problem with two outputs, and predicting ethnicity is a classification problem with five outputs (in this specific dataset). With that in mind, we can create our MTL model.
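Here's a sketch of such a model, with a shared resnet34 encoder and one small head per task; the hidden size of the heads is an illustrative choice:

```python
import torch
import torch.nn as nn
from torchvision import models

class MultiTaskModel(nn.Module):
    """One shared encoder, three task-specific heads."""
    def __init__(self):
        super().__init__()
        resnet = models.resnet34(pretrained=True)
        # keep the convolutional feature extractor, drop avgpool and fc
        self.encoder = nn.Sequential(*list(resnet.children())[:-2],
                                     nn.AdaptiveAvgPool2d(1), nn.Flatten())
        def head(n_out):
            return nn.Sequential(nn.Linear(512, 256), nn.ReLU(),
                                 nn.Linear(256, n_out))
        self.age_head = head(1)        # regression, single output
        self.gender_head = head(2)     # classification, two outputs
        self.ethnicity_head = head(5)  # classification, five outputs

    def forward(self, x):
        feats = self.encoder(x)
        # the sigmoid keeps the age prediction in (0, 1), see section 4
        age = torch.sigmoid(self.age_head(feats))
        return age, self.gender_head(feats), self.ethnicity_head(feats)
```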

You can see that the model uses only one encoder (a feature extractor) and feeds the encodings (features) to the task specific heads. Each head has the appropriate number of outputs so that it can learn its task.

Why am I applying a sigmoid to the output of the ages head? That's related to the division by 4.75 I mentioned previously, and I'll talk about it in the next section.

4. What about the loss function?

The loss function is what guides the training, right? If your loss function is not good, your model won't be good. In an MTL problem, you usually combine the task specific losses somehow. In our problem, for instance, the losses could be Mean Squared Error for predicting age and Cross Entropy for predicting both gender and ethnicity. There are some common ways of combining the task specific losses in an MTL problem. I'll talk about three of them.

The first take is to calculate the losses for each task and then add them together, or take their mean. Although I have read on some discussion forums that this approach works fine, that was not what I concluded from my experiments. As the losses may have very different magnitudes, one of them can take control of the training and you don't get nice results.
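For reference, this naive combination looks something like the sketch below, where preds is the tuple returned by the model above:

```python
import torch.nn.functional as F

def naive_mtl_loss(preds, age, gender, ethnicity):
    loss_age = F.mse_loss(preds[0].squeeze(1), age)
    loss_gender = F.cross_entropy(preds[1], gender)
    loss_ethnicity = F.cross_entropy(preds[2], ethnicity)
    # a raw sum: whichever loss has the largest magnitude dominates training
    return loss_age + loss_gender + loss_ethnicity
```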

The second take is to weight the losses by hand and then sum or average them. It turns out this approach is pretty fiddly, can take too much time, and I also couldn't make it work.

The third take, on the other hand, led me to nice results. It consists of letting the model learn how to weight the task specific losses. This is exactly what the authors propose in the paper Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics:

We propose a principled approach to multi-task deep learning which weighs multiple loss functions by considering the homoscedastic uncertainty of each task. This allows us to simultaneously learn various quantities with different units or scales in both classification and regression settings.
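Here's a minimal sketch of the paper's proposal applied to this specific problem; the per-task losses follow the choices discussed above, and the log_vars parameters play the role of the learned per-task log variances:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskLossWrapper(nn.Module):
    """Uncertainty-weighted combination of the three task losses."""
    def __init__(self, task_num=3):
        super().__init__()
        # one learnable log-variance per task, trained alongside the model
        self.log_vars = nn.Parameter(torch.zeros(task_num))

    def forward(self, preds, age, gender, ethnicity):
        # preds = (age_hat, gender_logits, ethnicity_logits) from the model
        loss_age = F.mse_loss(preds[0].squeeze(1), age)
        loss_gender = F.cross_entropy(preds[1], gender)
        loss_ethnicity = F.cross_entropy(preds[2], ethnicity)

        # each loss is scaled by its learned precision exp(-log_var);
        # the + log_var terms stop the model from inflating the variances
        precisions = torch.exp(-self.log_vars)
        return (precisions[0] * loss_age + self.log_vars[0]
                + precisions[1] * loss_gender + self.log_vars[1]
                + precisions[2] * loss_ethnicity + self.log_vars[2])
```

One detail to watch out for: the log_vars parameters live in the loss module, not in the model, so they need to be registered with the optimizer (for example by adding the wrapper to the learner's layer groups), otherwise they stay at their initial values.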

Now I should explain why I chose to divide the logarithm of the age by 4.75 (in the __getitem__ method of the dataset class) and why I applied a sigmoid to the output of the age head (in the model definition).

In theory the loss function should be able to learn the weights and scale each task's loss. In practice, though, I concluded from my experiments that keeping the task specific losses roughly on the same scale helps a lot in the fitting process. I divided the logarithm of the ages by 4.75 because that is the maximum value of log(age) in the dataset (the maximum age is 116, and log(116) ≈ 4.75). The result of log(age)/4.75 is therefore a number between zero and one, so the MSE loss isn't much bigger than the other losses. I applied a sigmoid after the age head so that I can force my model to always output a prediction in that acceptable range.
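To make that concrete, here's the round trip for the sample image from the intro (encode_age and decode_age are illustrative helpers, not part of the training code):

```python
import math

def encode_age(age):
    # log-scale then squash into [0, 1]; 4.75 ≈ log(116)
    return math.log(age) / 4.75

def decode_age(pred):
    # invert the normalization to get back to years
    return math.exp(pred * 4.75)

print(encode_age(61))     # ≈ 0.866, a nicely scaled regression target
print(decode_age(0.866))  # ≈ 61 years
```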

5. Creating the Learner and Training

With the databunch, model and loss function defined, it's time to create the learner. I'll add some metrics to help me keep track of the model's performance on the validation set.
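A sketch of the learner setup; the metric helpers and the way the model is split into layer groups are my own illustrative choices:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.basic_train import Learner
from fastai.torch_core import flatten_model
from fastai.metrics import accuracy

# each metric receives the model output tuple plus the three targets
def rmse_age(preds, age, gender, ethnicity):
    return torch.sqrt(F.mse_loss(preds[0].squeeze(1), age))

def acc_gender(preds, age, gender, ethnicity):
    return accuracy(preds[1], gender)

def acc_ethnicity(preds, age, gender, ethnicity):
    return accuracy(preds[2], ethnicity)

model = MultiTaskModel()
learn = Learner(data, model,
                loss_func=MultiTaskLossWrapper(3),
                metrics=[rmse_age, acc_gender, acc_ethnicity])

# split the parameters into two layer groups (encoder vs. heads) so that
# freeze() and discriminative learning rates know where to cut the model
heads = nn.Sequential(model.age_head, model.gender_head, model.ethnicity_head)
learn.layer_groups = [nn.Sequential(*flatten_model(model.encoder)),
                      nn.Sequential(*flatten_model(heads))]
```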

With the learner defined, we can now use the FastAI functionality to train our model: the learning rate finder, the fit_one_cycle method, discriminative learning rates training, and so on.
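The whole routine can look like this (the learning rate values are illustrative; the epoch counts match the results below):

```python
learn.freeze()                 # only the heads (and BN layers) will train
learn.lr_find()
learn.recorder.plot()          # pick a learning rate from the plot

learn.fit_one_cycle(15, 1e-2)  # stage 1: heads only

learn.unfreeze()               # stage 2: fine-tune the whole model
learn.fit_one_cycle(100, slice(1e-5, 1e-3))  # discriminative learning rates
```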

As I'm using a resnet34 encoder pretrained on ImageNet, first I train only the heads and the batch norm layers of the encoder. Training for 15 epochs led me to a 0.087 RMSE for age prediction, 89.84% accuracy on gender prediction and 84.15% accuracy on ethnicity prediction.

After that I unfreeze the encoder and train the whole model with discriminative learning rates for 100 epochs, which led me to a 0.058 RMSE for age prediction, 99.42% accuracy on gender prediction and 99.19% accuracy on ethnicity prediction.

6. Conclusion

In this post I presented the basics you need to train an MTL model using a hybrid Pytorch/FastAI approach. Once again I think I've shown that FastAI's high level functionality is really useful and made my life much easier. Being able to use advanced deep learning features such as discriminative learning rates and one cycle scheduling so easily is a thumbs up for FastAI in my opinion.

It is important to mention that in this tutorial my focus was to help you solve an MTL problem. The UTKFace dataset is pretty biased toward white people and people between 20 and 40 years old, so you should keep that in mind. In my final implementation I made some minor adjustments to the loss function which resulted in a less biased performance of the model.

If you liked this post, please give it some claps. You can see more of what I'm doing on my github and get in touch through LinkedIn.

Thank you for reading!