Source: Deep Learning on Medium
How data size impacts Deep Learning models and how to work with small datasets?
This is Part 2 of the series Breaking the curse of small datasets in Machine Learning. In Part 1, I discussed how the size of the dataset impacts traditional Machine Learning algorithms and a few ways to mitigate those issues. In Part 2, I will discuss how deep learning model performance depends on data size and how to work with smaller datasets to get similar performance.
PS: Thanks to Rachel Thomas for the feedback
Here are the topics I will briefly discuss in this post:
- Key factors influencing the training of Neural Nets
- Ways to overcome optimization difficulties
- Addressing the lack of Generalization
1. Key factors in training Neural Nets
Neural networks are the basic building blocks of deep learning models. However, deep neural networks have millions of parameters to learn, which means we need many iterations before we find good values for them. With a small dataset, running a large number of iterations can result in overfitting. A large dataset helps us avoid overfitting and generalize better, because it captures the underlying data distribution more effectively.
Here are a few important factors which influence the network optimization process:
- Optimization Algorithm: Gradient descent is the most popular optimization algorithm used for neural networks, and its behavior depends directly on the size of the training data. We can update the weights using a small subset of the training data at a time (stochastic gradient descent being the extreme case, where each update uses a single data point), which makes training fast, but the updates fluctuate more. Using the whole dataset for every update makes training computationally expensive and slow. Adam, RMSprop, Adagrad and stochastic gradient descent are a few variations of gradient descent that optimize the gradient update process and improve model performance. Check out this blog for a detailed understanding of the various versions of gradient descent. There are also optimization techniques other than gradient descent, such as Evolutionary Algorithms (EA) and Particle Swarm Optimization (PSO), which might hold huge potential.
- Loss function: The loss function also plays a crucial role in the optimization process, and a carefully selected loss function can improve training. Hinge loss is one such example that can make training with a small dataset possible.
- Parameter initialization: The initial state of the parameters greatly influences the optimization process. Poorly chosen initial values can cause divergence or leave the optimizer stuck at saddle points or local minima. They also increase the amount of training data required.
- Data size: Data size is a crucial part of training neural networks. Larger datasets help us learn model parameters better, improve the optimization process and help the model generalize.
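To make the batch-size trade-off in the optimization bullet concrete, here is a minimal sketch of mini-batch gradient descent on a toy linear regression problem (plain Python, no deep learning library). The data, learning rate and epoch count are illustrative assumptions, not values from this post. Setting `batch_size=1` gives stochastic gradient descent and `batch_size=len(xs)` gives full-batch gradient descent.

```python
import random

def minibatch_gd(xs, ys, batch_size, lr=0.1, epochs=100, seed=0):
    """Fit y ~ w*x + b by mini-batch gradient descent on mean squared error.

    batch_size=1 is stochastic gradient descent; batch_size=len(xs) is
    full-batch gradient descent. Every setting makes `epochs` passes over the data.
    """
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)  # new visiting order each epoch
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            # Gradient of the MSE computed on this batch only
            errs = [(w * xs[i] + b - ys[i], xs[i]) for i in batch]
            grad_w = sum(2 * e * x for e, x in errs) / len(batch)
            grad_b = sum(2 * e for e, _ in errs) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b

def mse(xs, ys, w, b):
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

# Toy data: y = 3x + 2 plus Gaussian noise, x in [0, 1)
data_rng = random.Random(42)
xs = [i / 100 for i in range(100)]
ys = [3 * x + 2 + data_rng.gauss(0, 0.1) for x in xs]

for bs in (1, 10, len(xs)):
    w, b = minibatch_gd(xs, ys, batch_size=bs)
    print(f"batch_size={bs:3d}  w={w:.2f}  b={b:.2f}  mse={mse(xs, ys, w, b):.4f}")
```

With the same number of epochs, `batch_size=1` performs 10,000 noisy updates while full-batch performs only 100 smooth ones, so in this toy setup the full-batch run tends to lag behind unless it is given more passes over the data.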
How to train with small datasets
We discussed a few of these techniques in the last post. In this part, we will discuss the remaining techniques, which are more relevant to deep learning.
2. Ways to overcome optimization difficulties
1. Transfer learning:
Transfer learning refers to applying what a model learned on one task to another task, so that the new model does not have to learn from scratch. It directly addresses the smart parameter initialization point for training neural networks. This technique has been widely used in computer vision and has been instrumental in the wide adoption of deep learning in industry. It is now well established that the initial layers of models such as ResNet trained on ImageNet data learn to identify edges and corners in the image, and later layers build on top of these features to learn more complicated structures. The last layer learns to classify the image into 1 of the 1000 categories. For any new problem whose data looks similar to ImageNet, we can start with a pre-trained ImageNet model, change the final layers and fine-tune it on our dataset. Since the lower-layer features remain relevant, this makes the optimization process fast and reduces the amount of data required to train new models.
Thanks to the Fast.ai library, we can build an image classification model using transfer learning with just a few lines of code and a few hundred training images, and still get state-of-the-art results.
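The "smart initialization" idea behind transfer learning can be seen even in a tiny model. The sketch below is a toy assumption, not the Fast.ai workflow: a two-parameter linear model is first trained on a large "source" dataset (standing in for ImageNet), and its weights are then used as the starting point for a related "target" task with only four examples. Both runs on the target task get the same small budget of gradient steps.

```python
def train(xs, ys, w, b, lr=0.05, steps=100):
    """Full-batch gradient descent on MSE for y ~ w*x + b, starting from (w, b)."""
    n = len(xs)
    for _ in range(steps):
        errs = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = sum(2 * e * x for e, x in zip(errs, xs)) / n
        grad_b = sum(2 * e for e in errs) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return w, b

def mse(xs, ys, w, b):
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

# Source task with plenty of data (plays the role of ImageNet): y = 2x + 1
src_x = [i / 50 for i in range(200)]
src_y = [2 * x + 1 for x in src_x]
pre_w, pre_b = train(src_x, src_y, 0.0, 0.0, steps=2000)

# Related target task with only four examples: y = 2x + 1.5
tgt_x = [0.0, 1.0, 2.0, 3.0]
tgt_y = [2 * x + 1.5 for x in tgt_x]

# Same small update budget for both runs
fine_tuned = train(tgt_x, tgt_y, pre_w, pre_b, steps=3)
scratch = train(tgt_x, tgt_y, 0.0, 0.0, steps=3)
print("fine-tuned MSE:", round(mse(tgt_x, tgt_y, *fine_tuned), 3))
print("from-scratch MSE:", round(mse(tgt_x, tgt_y, *scratch), 3))
```

Because the pre-trained weights already sit close to a good solution for the related task, the fine-tuned run reaches a much lower error in the same number of steps. This is the same effect that makes pre-trained lower layers useful, though a real network would typically also replace and retrain its final layers.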
More recently, the same approach has been applied successfully to Natural Language Processing as well. ULMFiT, BERT, AWD-LSTM and GPT-2 are a few models that can be used for transfer learning in NLP. These are called language models because they learn basic language structure by predicting the next word in a sentence (or, in BERT's case, a masked word) from the previously seen or surrounding context words. This idea is very similar to transfer learning in computer vision, but it has only been tried and refined in NLP in the last couple of years. Check out this awesome blog to get a better understanding of language models.
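The next-word prediction objective that language models learn can be illustrated with the simplest possible language model: a count-based bigram model. This is only a sketch of the objective, assuming a tiny made-up corpus; models such as ULMFiT, AWD-LSTM or GPT-2 learn it with neural networks over much longer contexts.

```python
from collections import Counter, defaultdict

def train_bigram_lm(sentences):
    """Count, for each word, which words follow it in the corpus."""
    counts = defaultdict(Counter)
    for sentence in sentences:
        tokens = ["<s>"] + sentence.lower().split()  # <s> marks the start of a sentence
        for prev, nxt in zip(tokens, tokens[1:]):
            counts[prev][nxt] += 1
    return counts

def next_word_prob(counts, prev_word, word):
    """Estimate P(word | prev_word) from the counts (0.0 if prev_word was never seen)."""
    followers = counts.get(prev_word, Counter())
    total = sum(followers.values())
    return followers[word] / total if total else 0.0

def predict_next(counts, prev_word):
    """Most frequent word seen after prev_word, or None if prev_word is unseen."""
    followers = counts.get(prev_word)
    return followers.most_common(1)[0][0] if followers else None

corpus = ["the cat sat on the mat", "the cat ate the fish", "a dog sat on the rug"]
lm = train_bigram_lm(corpus)
print(predict_next(lm, "the"), next_word_prob(lm, "the", "cat"))  # prints: cat 0.4
```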
2. Problem Reduction:
The problem reduction approach refers to transforming new data or an unknown problem into a known problem that can be solved easily with existing techniques. Suppose we have a lot of voice clips that we want to classify into various classes based on the sound source. Recent advances in deep learning have shown that sequence models such as LSTMs or GRUs are really good at such tasks. However, a small dataset can be a deal breaker, and finding a good model for transfer learning for such a use case is also very difficult. While recently attending the Fast.ai v3 course, I came across a smart approach for solving such problems. We can convert the voice clips into images (spectrograms) using libraries such as LibROSA, which reduces the task to an image classification problem. Now we can use a suitable computer vision architecture with transfer learning, and surprisingly, this approach gives similar performance even with a very small dataset. Please check out this blog for a better understanding and code.
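To show what "converting a clip into an image" means, here is a minimal magnitude spectrogram written in plain Python with a naive DFT. It is a sketch for illustration only: in practice a library such as LibROSA would compute this far more efficiently, and usually on a mel-scaled, log-amplitude axis. The frame size, hop length and test tone are assumptions chosen so the result is easy to check.

```python
import cmath
import math

def spectrogram(signal, frame_size=64, hop=32):
    """Magnitude spectrogram of a 1-D signal.

    Returns a list of frames (time axis); each frame lists the DFT magnitudes
    for frequency bins 0..frame_size//2 (frequency axis).
    """
    # Hann window to reduce spectral leakage at the frame edges
    window = [0.5 - 0.5 * math.cos(2 * math.pi * n / frame_size) for n in range(frame_size)]
    frames = []
    for start in range(0, len(signal) - frame_size + 1, hop):
        chunk = [signal[start + n] * window[n] for n in range(frame_size)]
        # Naive O(N^2) DFT; keep only the non-negative frequencies
        row = [
            abs(sum(chunk[n] * cmath.exp(-2j * math.pi * k * n / frame_size) for n in range(frame_size)))
            for k in range(frame_size // 2 + 1)
        ]
        frames.append(row)
    return frames

def to_db(spec, eps=1e-10):
    """Log-scale magnitudes, which is how spectrograms are usually rendered as images."""
    return [[20 * math.log10(m + eps) for m in row] for row in spec]

# One second of a 100 Hz tone sampled at 800 Hz
sr = 800
tone = [math.sin(2 * math.pi * 100 * t / sr) for t in range(sr)]
spec = spectrogram(tone)
# Bin k corresponds to k * sr / frame_size Hz, so 100 Hz lands in bin 8
peak_bins = {row.index(max(row)) for row in spec}
print(len(spec), "frames x", len(spec[0]), "bins; peak bin(s):", peak_bins)
```

The resulting 2-D grid (time by frequency) is what gets saved as an image and fed to a pre-trained vision model.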
3. Learning with less data
a) One-Shot Learning: Humans can learn from even a single example and still distinguish new objects with very high precision. Deep neural networks, on the other hand, require a huge amount of labeled data to train and generalize. This is a big drawback, and one-shot learning is an attempt to train neural networks even with small datasets. There are two ways to achieve this. The first is to modify the loss function so that the model can identify minute differences and learn a better representation of the data. The Siamese network is one such approach and is commonly used for image verification.
b) Siamese Network: Given two images, a Siamese network tries to determine how similar they are. The network has two identical sub-networks with the same parameters and weights. Each sub-network consists of convolutional blocks followed by a fully connected layer that outputs a feature vector (of size 128). The images to be compared are passed through the network to extract their feature vectors, and we calculate the distance between these vectors. Model performance depends on the training image pairs (the closer the pairs, the better the performance), and the model is optimized so that similar images end up close together in feature space and different images end up far apart. The Siamese network is a good example of how we can modify the loss function and use fewer, but higher-quality, training examples to train deep learning models. Check out the following video for a detailed explanation.
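A minimal sketch of the Siamese idea follows. Both images go through the same embedding function with the same weights, and a contrastive loss pulls similar pairs together while pushing dissimilar pairs at least a margin apart. The contrastive loss is one common choice for Siamese networks and is an assumption here, not necessarily the loss used in the referenced video. The single linear layer with ReLU, the 4-pixel "images" and the weights are hypothetical stand-ins for the convolutional sub-network and its 128-dimensional output.

```python
import math

def embed(image, weights):
    """Shared sub-network: one linear layer followed by ReLU.

    Both images in a pair pass through this same function with the same weights,
    which is what makes the two branches of a Siamese network identical.
    """
    return [max(0.0, sum(w * p for w, p in zip(row, image))) for row in weights]

def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def contrastive_loss(distance, is_similar, margin=1.0):
    """Similar pairs are penalized for being apart; dissimilar pairs only if closer than margin."""
    if is_similar:
        return distance ** 2
    return max(0.0, margin - distance) ** 2

def pair_loss(img_a, img_b, is_similar, weights, margin=1.0):
    d = euclidean(embed(img_a, weights), embed(img_b, weights))
    return contrastive_loss(d, is_similar, margin)

# Hypothetical weights mapping 4-pixel images to 3-d feature vectors
W = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]
a, a_variant, b = [1, 0, 0, 0], [0.9, 0.1, 0, 0], [0, 0, 1, 1]
print(pair_loss(a, a_variant, True, W), pair_loss(a, b, False, W))
```

Training would adjust `W` to lower this loss over many labeled pairs; at test time, a new image is verified by thresholding its feature distance to a reference image.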
The second approach to one-shot learning is to give the model a memory, loosely resembling the human brain. This is probably a source of inspiration behind the Neural Turing Machine proposed by Google DeepMind, a memory-augmented architecture that has also been applied to one-shot learning.
c) Memory-Augmented Neural Networks: The Neural Turing Machine (NTM) belongs to the family of memory-augmented neural networks, which give a neural network an external memory that can help in one-shot learning. An NTM is fundamentally composed of a neural network, called the controller, and a 2D matrix called the memory bank. At each time step, the neural network receives some input from the outside world and sends some output to the outside world. However, the network can also read from and write to memory locations. Note that back-propagation will not work if we access memory by selecting a single index, because that selection is not differentiable. Hence the controller reads and writes using blurry operations, i.e., it assigns a different weight to each location while reading and writing. The controller produces weightings over memory locations that allow it to specify memory locations in a differentiable manner. NTMs have shown great potential and can outperform LSTMs on a variety of tasks. Check out this great blog to get an in-depth understanding of NTMs.
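The "blurry" read and write can be sketched in a few lines. The code below shows only NTM-style content-based addressing (a softmax over cosine similarities, sharpened by a key strength `beta`) together with the weighted read and the erase-then-add write; a full NTM also interpolates, shifts and sharpens these weights, and the controller network that emits the key, erase and add vectors is omitted. The memory contents are illustrative assumptions.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def content_weights(memory, key, beta):
    """Softmax over cosine similarity between the key and each memory row.

    beta (key strength) sharpens the focus, but every location keeps a nonzero
    weight, so the operation stays differentiable.
    """
    scores = [beta * cosine(row, key) for row in memory]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def read(memory, weights):
    """Blurry read: weighted sum of all memory rows."""
    width = len(memory[0])
    return [sum(w * row[j] for w, row in zip(weights, memory)) for j in range(width)]

def write(memory, weights, erase, add):
    """Blurry write: each row is partially erased, then added to, in proportion to its weight."""
    return [
        [m * (1 - w * e) + w * a for m, e, a in zip(row, erase, add)]
        for w, row in zip(weights, memory)
    ]

memory = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
w = content_weights(memory, key=[0, 1, 0], beta=10)
print([round(x, 4) for x in w], [round(x, 4) for x in read(memory, w)])
```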
d) Zero-Shot Learning: Zero-shot learning refers to solving tasks that were not part of the training data. This can help us work with classes we did not see during training and reduces data requirements. There are various ways of formulating zero-shot learning, and I am going to discuss one of them. In this method, we try to predict the semantic representation of a given image, i.e., given an image, we try to predict the word2vec representation of the image's class. In simple terms, we can treat this as a regression problem where a deep neural network tries to predict the vector representation of the class by processing its image. We can use standard architectures like VGG16 or ResNet and modify the last layers to output a word vector. This approach lets us find word vectors for unseen images, and we can then find the image class with a nearest neighbor search. Instead of regression, we can also make the last layer compute a dot product between the image features and the word vectors, which measures their similarity and lets us learn a visual-semantic embedding model. Read the DeViSE paper for more details. Also, check out this blog to see zero-shot learning in action.
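The final nearest neighbor step can be sketched as follows, assuming the image network has already been trained to output a word vector. The 3-dimensional "word vectors" and the predicted vector are hypothetical values chosen for illustration; real word2vec vectors have hundreds of dimensions, and the regression network itself is not shown.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def zero_shot_classify(predicted_vec, class_vectors):
    """Return the class whose word vector is most similar to the predicted vector."""
    return max(class_vectors, key=lambda c: cosine(predicted_vec, class_vectors[c]))

# Hypothetical word vectors for the candidate classes
class_vectors = {
    "dog": [0.9, 0.1, 0.0],     # seen during training
    "cat": [0.8, 0.3, 0.1],     # seen during training
    "zebra": [0.1, 0.2, 0.95],  # no zebra images were used in training
}
# Pretend the image network mapped a zebra photo to this vector
predicted = [0.2, 0.1, 0.9]
print(zero_shot_classify(predicted, class_vectors))  # prints: zebra
```

Because the label set lives in word-vector space, adding a new class only requires its word vector, not new training images.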
4. Better optimization techniques
Meta-Learning (learning to learn): Meta-learning deals with finding the best ways to learn from given data, i.e., learning various optimization settings and hyper-parameters for the model. There are various ways of implementing meta-learning; let's discuss one of them. A meta-learning framework typically consists of two models:
a) A neural network called the Optimizee or Learner, which is treated as the low-level network and is used for prediction.
b) Another neural network, called the Optimizer, Meta-Learner or high-level model, which updates the weights of the low-level network.
This results in a two-level nested training process. Multiple steps of the low-level network form a single step of the meta-learner. At the end of these steps, we also calculate a meta loss and update the weights of the meta-learner accordingly. This process helps us find the best parameters to train with and makes the learning process more efficient. Follow this awesome blog for a detailed understanding and implementation.
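The nested loop can be seen in a deliberately tiny example. This is a toy, not the neural meta-learner described above: the "meta-learner" is a single learnable learning rate, each task is minimizing f(w) = (w - target)^2, and the meta-gradient is estimated with a central finite difference instead of backpropagating through the inner loop. The task targets and step sizes are illustrative assumptions.

```python
def inner_train(target, lr, steps=3):
    """Low-level learner: a few gradient steps on f(w) = (w - target)**2, starting at w = 0."""
    w = 0.0
    for _ in range(steps):
        w -= lr * 2 * (w - target)
    return w

def meta_loss(lr, tasks):
    """Average loss of the low-level learner after its inner steps, across tasks."""
    return sum((inner_train(t, lr) - t) ** 2 for t in tasks) / len(tasks)

def meta_train(tasks, lr=0.05, meta_lr=0.01, meta_steps=100, eps=1e-5):
    """High-level loop: each meta-step wraps complete inner runs and updates lr from the meta loss."""
    for _ in range(meta_steps):
        # Central finite-difference estimate of d(meta_loss)/d(lr)
        grad = (meta_loss(lr + eps, tasks) - meta_loss(lr - eps, tasks)) / (2 * eps)
        lr -= meta_lr * grad
    return lr

tasks = [1.0, -2.0, 0.5]
learned_lr = meta_train(tasks)
print(f"lr: 0.05 -> {learned_lr:.3f}")
print(f"meta loss: {meta_loss(0.05, tasks):.4f} -> {meta_loss(learned_lr, tasks):.4f}")
```

In a full learned-optimizer setup, the single learning rate would be replaced by a neural network that maps gradients to weight updates, but the structure (inner steps, then a meta loss, then a meta update) is the same.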
3. Addressing the lack of Generalization
1. Data Augmentation:
Data augmentation can be an effective tool for working with a small dataset without overfitting. It is also a good technique for making our model invariant to changes in size, translation, viewpoint, illumination, etc. We can augment our data in a few of the following ways:
- Flip the image horizontally or vertically
- Crop and/or zoom images
- Change the brightness/sharpness/contrast of the image
- Rotate the image by some degree
Fast.ai has some of the best transform functions for data augmentation, which make the task very easy with just a few lines of code. Check out this awesome documentation to learn how to implement data augmentation using Fast.ai.
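To show what these transforms do underneath, here is a plain-Python sketch of a few of them on a grayscale image stored as a list of rows. It is not Fast.ai's API: rotation is limited to 90 degrees (arbitrary angles need interpolation), zoom would be a crop followed by a resize, and the probabilities and brightness range in `random_augment` are illustrative assumptions.

```python
import random

def hflip(img):
    """Mirror left to right."""
    return [row[::-1] for row in img]

def vflip(img):
    """Mirror top to bottom."""
    return img[::-1]

def rotate90(img):
    """Rotate 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]

def crop(img, top, left, height, width):
    return [row[left:left + width] for row in img[top:top + height]]

def adjust_brightness(img, factor, max_value=255):
    """Scale pixel values and clip them to [0, max_value]."""
    return [[min(max_value, max(0, round(p * factor))) for p in row] for row in img]

def random_augment(img, rng):
    """Apply a random subset of the transforms above, as done at training time."""
    if rng.random() < 0.5:
        img = hflip(img)
    if rng.random() < 0.5:
        img = adjust_brightness(img, rng.uniform(0.8, 1.2))
    if rng.random() < 0.25:
        img = rotate90(img)
    return img

img = [[1, 2, 3], [4, 5, 6]]
print(hflip(img), rotate90(img), crop(img, 0, 1, 2, 2))
print(random_augment(img, random.Random(0)))
```

Since a new random variant is drawn every time an image is loaded, the model rarely sees exactly the same input twice, which is what helps it avoid overfitting a small dataset.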
2. Data Generation:
a) Semi-Supervised Learning: Often we have a large corpus of data available, but only a small part of it is labeled. The large corpus can be any publicly available dataset or proprietary data. In such scenarios, semi-supervised learning can be a good technique for dealing with the lack of labeled data. One approach is to build a model that learns the patterns in the labeled data and predicts the classes of the unlabeled data; these predictions are called pseudo-labels. Once we have pseudo-labels, we can use both the labeled and the pseudo-labeled data to train the model for our original task. We can use a variety of supervised or unsupervised models to generate pseudo-labels, and we can set thresholds on the predicted probabilities to select suitable pseudo-labeled data to train on. We can also explore active learning, where the model identifies which data points are most useful for training, so that we only need to label a small set of data for the model to learn efficiently. Co-teaching and co-learning are similar approaches that can help in such situations.
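Pseudo-labeling with a confidence threshold can be sketched as below. The "model" is a deliberately simple nearest-centroid classifier whose softmax over negative squared distances acts as a stand-in confidence score (not a calibrated probability); the 2-D points and the 0.9 threshold are illustrative assumptions. Any supervised model with probability outputs could take its place.

```python
import math

def centroids(points, labels):
    """Mean point for each class."""
    sums, counts = {}, {}
    for p, y in zip(points, labels):
        acc = sums.setdefault(y, [0.0] * len(p))
        for i, v in enumerate(p):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}

def predict_proba(point, cents):
    """Softmax over negative squared distances to each class centroid."""
    scores = {y: -sum((a - b) ** 2 for a, b in zip(point, c)) for y, c in cents.items()}
    m = max(scores.values())
    exps = {y: math.exp(s - m) for y, s in scores.items()}
    total = sum(exps.values())
    return {y: e / total for y, e in exps.items()}

def pseudo_label(labeled_x, labeled_y, unlabeled_x, threshold=0.9):
    """Keep only the unlabeled points whose top predicted class clears the threshold."""
    cents = centroids(labeled_x, labeled_y)
    new_x, new_y = [], []
    for p in unlabeled_x:
        probs = predict_proba(p, cents)
        best = max(probs, key=probs.get)
        if probs[best] >= threshold:
            new_x.append(p)
            new_y.append(best)
    return new_x, new_y

labeled_x = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [4.8, 5.1]]
labeled_y = ["a", "a", "b", "b"]
unlabeled_x = [[0.1, 0.3], [5.2, 4.9], [2.5, 2.5]]  # the last point is ambiguous
px, py = pseudo_label(labeled_x, labeled_y, unlabeled_x)
# Combine labeled and confidently pseudo-labeled data for the next training round
train_x, train_y = labeled_x + px, labeled_y + py
print(py)
```

The ambiguous point halfway between the two clusters falls below the threshold and is left out, which is exactly what the threshold is for.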
b) GANs: Generative adversarial networks are a type of generative model that can generate new data that looks very close to real data. A GAN has two components, a generator and a discriminator, which play against each other as we train the model. The generator tries to produce fake data points, and the discriminator tries to identify whether the data it sees is real or fake. Let’s say we want to create new images of dogs. The generator creates fake dog photos, which are fed to the discriminator along with real dog photos. The objective of the discriminator is to correctly identify the real and fake images, while the objective of the generator is to create images that the discriminator cannot identify as fake. After proper training, the generator learns to generate images that look like real dog images and can be used to create new datasets.
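The two opposing objectives can be written down directly. The sketch below shows the standard binary cross-entropy discriminator loss and the commonly used "non-saturating" generator loss, computed from discriminator output probabilities; the networks themselves are not included, and the example probabilities are made up. Training alternates between lowering the discriminator loss and lowering the generator loss.

```python
import math

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy: D should output 1 for real samples and 0 for fakes.

    d_real, d_fake: lists of discriminator probabilities, each strictly between 0 and 1.
    """
    real_term = sum(-math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(-math.log(1 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

def generator_loss(d_fake):
    """Non-saturating generator loss: G wants D to output 1 on its fakes."""
    return sum(-math.log(p) for p in d_fake) / len(d_fake)

# A confident, correct discriminator has a low loss, and the generator's loss is high
print(discriminator_loss([0.9, 0.95], [0.1, 0.05]), generator_loss([0.1, 0.05]))
# A fooled discriminator (fakes scored as real) means a low generator loss
print(generator_loss([0.9]))
```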
We can also use LSTMs as generative models to generate text documents or sounds, which can be used as training data. Check out this blog for a detailed explanation.
In this part, we have discussed various factors that influence the training of deep neural networks and a few techniques that can help us train with a small dataset. In most of our use cases, we generally don’t have a very big dataset, and such techniques open up new avenues for training models and getting satisfactory performance. In this blog series, I have tried to list various commonly used techniques for training with small data. However, this list is not exhaustive and can only serve as a starting point for further exploration. Please check out the links and references to get a detailed understanding and implementation of a few of the techniques discussed above.
About Me: Graduate Student, Master's in Data Science from University of San Francisco; Intern at Trulia working on computer vision & semi-supervised learning; 3+ years of experience solving complex business problems using Data Science and Machine Learning; interested in working with cross-functional groups to derive insights from data and apply Machine Learning knowledge to solve complicated data science problems. LinkedIn, Portfolio, GitHub, Previous Posts
References:
- Posts about language models, from classic approaches to recent pretrained language models
- Great results on audio classification with fastai library
- Types of Optimization Algorithms used in Neural Networks and Ways to Optimize Gradient Descent
- GANs from Scratch 1: A deep introduction. With code in PyTorch and TensorFlow
- Deep Learning Specialization
- Explanation of Neural Turing Machines
- DeViSE: A Deep Visual-Semantic Embedding Model
- Zero-Shot Learning
- From zero to research — An introduction to Meta-learning
- Fast.ai documentation