Introduction to Transfer Learning

Source: Deep Learning on Medium

How a powerful ML technique can save you from reinventing the wheel

Image by cytis from Pixabay

How to make an Apple Pie

Ingredients —

  • Piece of land where you can grow corn, wheat, apples and cinnamon
  • Fertilizers, farming tools and a mill
  • Saline water to extract salt from
  • A cow to make butter from the milk
  • A hen for eggs

Recipe — To get started, plough the land to prepare it for sowing. Sow the corn, wheat, apple and cinnamon seeds. Wait a few years and watch them grow, watering and fertilizing regularly to get good-quality produce. Once the corn, wheat and apples are ripe, harvest them, and don’t forget to strip the bark from the cinnamon trees. Process the corn to make sugar; meanwhile, mill and grind the wheat into flour. Boil the saline water over medium heat until all the water has evaporated and only salt remains. You can also start milking the cow to make the butter. Once your hen has laid eggs and you have churned butter from the cow’s milk, you are all set to move to your kitchen and make the pie.

Sounds like a crazy apple pie recipe, doesn’t it?

“If you wish to make an apple pie from scratch, you must first invent the universe”
 — Carl Sagan

Growing the ingredients for an apple pie sounds like a fun pet project, but it is not practical to do it every time you make a pie. Instead, you go to the store, buy apples, flour, sugar, salt and a pie crust, mix the ingredients together, put the pie in the oven and enjoy. Why reinvent the wheel by spending so much time and resources, especially when somebody has already done the work for you?


Transfer Learning

Transfer learning, to some extent, is also about avoiding reinventing the wheel. It is the approach in which knowledge learned on one or more source tasks is transferred and used to improve learning on a related target task. Take neural networks, for instance: after a lot of time and effort, you or someone else has trained a neural network on millions of images to recognize cars. Should you spend the same amount of time again to gather millions of images of trucks and train a new network from scratch to recognize trucks? Or can you use the knowledge the network gained recognizing cars to recognize trucks? It turns out you can!

We know that Convolutional Neural Networks (CNNs) are very good at image-related tasks. They are usually very deep networks with many convolutional layers that extract features from the images. Moreover, these networks are trained on huge datasets consisting of millions of images; ImageNet, for instance, contains 1.2 million images in 1000 categories. It is not always practical to obtain such huge datasets. Further, deep networks can take days to weeks to train on data of this size, even on specialized hardware like a GPU.

This is where we can take the knowledge learned by a network trained for one particular task and deploy it for another task. You don’t have to build a model from scratch.

This article liberally uses jargon related to CNNs. If you are not familiar with CNNs, I would encourage you to get a primer here. A convolutional neural network comprises two main building blocks —

  1. Feature extraction
  2. Classifier

We know that the feature-extraction part of a CNN extracts features from an image, with each successive convolutional layer extracting more and more complex features. Let’s say a neural network is trained on 1.2 million images of 1000 different objects. The initial layers extract low-level features like edges and curves, while the deeper layers capture progressively more complex details, such as body parts, faces and other compositional features. For a new task, we can take the pre-trained convolutional layers of such state-of-the-art networks, use their features off the shelf, and train only the classifier (the fully connected layers) for the task of our choice. Let us try to understand this with the cooking analogy. This time we will make pasta!

Suppose you want to establish a small business to make and sell dried spaghetti. You talk to vendors to source the finest-quality flour and eggs, meticulously plan and establish the processes required to make and package the pasta, and build the supply chain to reach your customers. You start delivering the finest-quality spaghetti, and your customers are happy. But then you find out there is huge demand for dried macaroni as well, so you plan to make macaroni too. It would be foolish to scrap all of your current equipment, vendors and supply chain and start afresh. You know that most of the process required to make macaroni is the same as for spaghetti; you just need a new machine to shape the final dough into macaroni! With a few tweaks to your processes, and after getting a new pasta-shaping machine, you quickly start producing macaroni as well!


What you did with your pasta factory was use your previous knowledge and processes, change a few pieces at the end of the production line, and start producing macaroni. You can do something similar with a CNN. Taking a previously trained CNN and refactoring it by replacing some of its layers (the fully connected ones, in most cases) is what transfer learning is all about.
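A minimal PyTorch sketch of this idea: freeze a backbone’s convolutional layers and train only a new classifier head. To keep the example self-contained, a tiny stand-in network replaces the backbone; in practice you would load a real pre-trained model (e.g. a torchvision ResNet with ImageNet weights), and the 5-class “truck types” target task is purely hypothetical.

```python
import torch
import torch.nn as nn

# Tiny stand-in for a pre-trained feature extractor; in a real project
# these layers would carry pre-trained weights.
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

# Freeze the feature extractor: its weights stay exactly as trained.
for param in features.parameters():
    param.requires_grad = False

# Fresh classifier head for the new task (hypothetical: 5 truck types).
classifier = nn.Linear(8, 5)
model = nn.Sequential(features, classifier)

# Only the new classifier's parameters are handed to the optimizer.
optimizer = torch.optim.SGD(classifier.parameters(), lr=1e-3)

x = torch.randn(2, 3, 32, 32)   # dummy batch of two 32x32 RGB images
logits = model(x)               # one score per class, shape (2, 5)
```

During training, gradients never update the frozen layers, so the “pasta factory” stays intact while only the new shaping machine is tuned.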

Transfer Learning Techniques

We saw how we can take an existing CNN model, remove the final classifier (the fully connected layers), and use the CNN’s features to train a new classifier. Apart from using the convolutional layers as a plain feature extractor, there are a couple of other major strategies.

Fine-tuning the ConvNet: Another strategy is to not only replace and retrain the classifier on top of the ConvNet on the new dataset, but also fine-tune the weights of the pre-trained network by backpropagation. You can choose to fine-tune all the layers of the ConvNet, or keep some of the earlier layers fixed and only fine-tune some higher-level portion of the network. This is motivated by the observation that the earlier layers of a ConvNet contain more generic, simpler features like edges or color blobs that should be useful for many tasks, while the later layers become progressively more specific to the details of the classes contained in the original dataset. In the case of ImageNet, for example, which contains many dog breeds, a significant portion of the representational power of the ConvNet may be devoted to features specific to differentiating between dog breeds.
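One common way to express this in PyTorch is with per-parameter-group learning rates: freeze the earliest, most generic layer, fine-tune the deeper layers gently, and train the fresh classifier faster. The three-stage network below is a small stand-in chosen for the sketch; in practice the stages would carry pre-trained weights.

```python
import torch
import torch.nn as nn

# Stand-in for a pre-trained ConvNet, split into three conv stages plus
# a classifier; real pre-trained weights are assumed in practice.
early = nn.Conv2d(3, 8, kernel_size=3, padding=1)    # generic edges/colors
middle = nn.Conv2d(8, 16, kernel_size=3, padding=1)
late = nn.Conv2d(16, 32, kernel_size=3, padding=1)   # dataset-specific
classifier = nn.Linear(32, 10)

# Keep the earliest, most generic layer frozen ...
for param in early.parameters():
    param.requires_grad = False

# ... fine-tune the deeper layers with a small learning rate, while the
# new classifier head trains with a larger one.
optimizer = torch.optim.SGD([
    {"params": middle.parameters()},
    {"params": late.parameters()},
    {"params": classifier.parameters(), "lr": 1e-2},
], lr=1e-4)
```

The small default learning rate nudges the pre-trained weights without destroying them, while the randomly initialized classifier is free to learn quickly.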

Use the Architecture of the pre-trained model: We can reuse the architecture of a model that was built by others for a different task. We keep the number of layers, kernel sizes and so on the same, initialize all the weights randomly, and train the model on our own dataset. People have also released intermediate checkpoints of their models, which you can use as a starting point for training on your dataset.
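A short sketch of reusing only the architecture: build the same layers (a small stand-in network here) and reset every weight before training, relying on the fact that PyTorch layers expose a `reset_parameters()` method.

```python
import torch
import torch.nn as nn

# Same architecture as the source model (a small stand-in here), but
# every weight will be re-initialized before training on the new data.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)

def reset_weights(module):
    # Re-randomize any layer that provides reset_parameters().
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

old_weight = model[0].weight.clone()  # snapshot before resetting
model.apply(reset_weights)            # fresh random weights, same layers
```

Here nothing learned by the original model is transferred; only the design of the network (its layers and kernel sizes) is reused.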

When and how to use transfer learning?

How do you decide what type of transfer learning you should perform on a new dataset? Should you just use a new classifier or should you also get rid of a few conv layers? Or should you keep the architecture as it is and retrain all the weights? This is decided by several factors, but the two most important ones are the size of the new dataset, and its similarity to the original dataset.

  1. The new dataset is small and similar to original dataset
    If the data is similar to the original data, we expect the higher-level features in the ConvNet to be relevant to the new dataset as well. For instance, if the original network was trained to recognize cars, you might still use it to recognize vans, since higher-level features like wheels and doors would be common. Hence, the best idea might be to train a linear classifier on the features extracted by the CNN.
  2. The new dataset is large and similar to the original dataset
    If your new dataset has many examples, it is a good idea to fine-tune the whole network with them. That means you can take the pre-trained CNN and run backpropagation again with the new examples to fine-tune the weights for better accuracy.
  3. The new dataset is small but very different from the original dataset
    If you have a small dataset which is very different from the one the CNN was trained on, it might not be best to train the classifier on features from the top of the network, which are specific to the original dataset. Instead, it might be better to train the new classifier on features from somewhere earlier in the network, which are simpler and more generic.
  4. The new dataset is large and very different from the original dataset
    If you have a large dataset that is very different from the original one, you can afford to train a ConvNet from scratch, i.e. reinitialize the weights and train the ConvNet on your own data. However, in practice it is often still beneficial to initialize with the weights of a pre-trained model.

You can see how these strategies let us leverage models trained on other datasets and use them to our advantage, saving a lot of time, effort and resources. Transfer learning is a powerful technique, and many prominent figures in the AI and ML industry, including Andrew Ng, hold the opinion that transfer learning will be the next major driver of ML’s commercial success. So what are you waiting for? Go ahead and try applying a pre-trained model to your application!


X8 aims to organize and build a community for AI that is not only open source but also looks at its ethical and political aspects. More such simplified AI concepts will follow. If you liked this, or have feedback or follow-up questions, please comment below.