Original article was published by Jonathan Hui on Artificial Intelligence on Medium
MetaNet has a meta-learner composed of two DNNs, d and m, which are responsible for generating the fast weights for the encoder and the classifier respectively.
Step ① samples T examples from the support.
Then, MetaNet estimates the loss of its encoder using the slow weights only:
The meta-learner d uses the gradient of this embedding loss as input to infer the encoder’s fast weights Q*. This allows d to generate fast weights that are sensitive to the specific task.
Steps ② to ⑤ predict the fast weights W* for the classifier on each query sample. Intuitively, the classifier’s fast weights W* are adapted to the current task by finding the similarities between the support and the specific query.
Let’s have an overview before detailing the steps. For each sample i in the support, we compute its encoded feature rᵢ’ and Wᵢ’* (Wᵢ* in the code). Similar to the previous step, Wᵢ’* is predicted by the meta-learner from a loss gradient, but this time it is the gradient of the classifier loss. Then we iterate through each query sample and compute its attention (similarity) to each support readout rᵢ’. This attention is used to readjust Wᵢ’* into Wᵢ*, i.e. the fast parameters of the classifier are adapted to the similarities between the support and the specific query.
Here are the fine details. In step ②, we iterate over all N samples in the support.
For each sample, we compute the loss of the classifier using the slow weights W only. The meta-learner m will use the gradient of this loss as input to infer the classifier’s fast weights Wᵢ*. But it creates one gradient and one Wᵢ* per support sample. And the result Wᵢ* will be stored in the ith position of memory M for later use.
In step ③,
we extract the features from the input using the encoder with both fast and slow weights. The result r’ᵢ will be stored in the ith position of memory R. Now R stores the representation r’ᵢ of the support with the corresponding fast weights Wᵢ* for the classifier in memory M.
In step ④,
we iterate over L queries and encode the input as features rᵢ with both fast and slow weights. MetaNet computes the similarity between the query and the support and uses it to readjust the fast weights Wᵢ* for the classifier:
- MetaNet computes the soft attention focus (similarity) between rᵢ in the query and R which stores the support features.
- MetaNet uses softmax to convert the result to a probability distribution.
- MetaNet multiplies the result with M to readjust Wᵢ*.
In short, we estimate W* based on the support and then adapt these values to each query based on similarity.
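The three bullet steps above can be sketched in a few lines of Python. This is only a rough illustration under my own assumptions, not MetaNet's code: each stored Wᵢ* is flattened into a short vector, cosine similarity is used as the attention score, and the names `read_fast_weights`, `R` and `M` are picked here for readability.

```python
import math

def cosine(u, v):
    """Cosine similarity between two feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def softmax(scores):
    """Convert similarity scores to a probability distribution."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def read_fast_weights(query_feature, R, M):
    """Blend the per-sample fast weights M[i] by the query's attention over R[i]."""
    attention = softmax([cosine(query_feature, r) for r in R])
    blended = [sum(a * w[j] for a, w in zip(attention, M)) for j in range(len(M[0]))]
    return blended, attention

# Toy memories (assumed values): two support features and their fast weights.
R = [[1.0, 0.0], [0.0, 1.0]]
M = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]
w_star, attention = read_fast_weights([0.9, 0.1], R, M)
```

The query is closer to the first support feature, so the blended fast weights lean towards the first row of M.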
In step ⑤, we compute the loss of the classifier which uses the adjusted fast weights and slow weights in making predictions.
In step ⑥, using the computed loss, we train the slow weights of the encoder and the classifier, as well as the meta-learner’s networks d and m.
In a memory-based system, the meta-learner collects and encodes our experience into a memory structure. Later, memory is recalled by piecing similar objects in memory together.
In other model-based algorithms, the meta-learner generates
- weights of a complete model, or
- the context representing the support of a task.
In the first case above, the weights become the model parameters of a classifier/regressor DNN. In the second case, the context is used as the weights for a part of the DNN, usually the last layer. In either case, the DNNs are trained to make predictions that adapt to a specific task quickly. During meta-testing, we simply repeat the process for an unseen task, hoping that the DNNs have already acquired the generic knowledge and can be readapted with a small dataset.
Let’s study the second type of Meta-learning approach that focuses on the optimizer. Again, we optimize (train) the meta-learners to adapt to specific tasks easily with minimum data samples.
LSTM-based Meta-Learner Optimizer (paper)
The algorithm below uses an LSTM R for the meta-learner. Some of the concepts are similar to what we discussed in MetaNet. R predicts the model parameters for the learner M. But besides the loss gradient, it also takes the loss values and the previous parameters of M as input.
In step 7, it selects T samples from the support and makes label predictions for each sample using the learner M. For each sample, it computes its loss ℒ and the corresponding gradient ∇ℒ. The meta-learner R, parameterized by Θ, uses these computed values as input to make a new model parameter proposal for the learner M. But the adopted parameters for M at time t will be a gated combination of the previous parameters and the current proposals (step 10).
We repeat the process for T iterations and use the last model θT as the learner M. Then, using the query samples, we make predictions with θT and use the loss gradient to update the meta-learner's model parameters Θ (step 16).
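To make the gated combination more concrete, here is a minimal sketch. It assumes the update takes the LSTM cell-state form θₜ = fₜ · θₜ₋₁ + iₜ · (−∇ℒ), with the negative gradient playing the role of the proposal. The function names and the toy loss are mine, and the gates are fixed constants here, whereas a trained meta-learner R would output them from the gradient, the loss and the previous parameters.

```python
def gated_update(theta_prev, grad, forget_gate, input_gate):
    """theta_t = f_t * theta_{t-1} + i_t * proposal, with proposal = -grad."""
    return [f * th + i * (-g)
            for th, g, f, i in zip(theta_prev, grad, forget_gate, input_gate)]

def loss_grad(theta):
    """Gradient of the toy loss sum((theta_j - 3)^2) (assumed for illustration)."""
    return [2.0 * (th - 3.0) for th in theta]

theta = [0.0, 10.0]
for _ in range(50):  # T = 50 iterations
    g = loss_grad(theta)
    # Fixed gates (an assumption): f = 1 and i = 0.1 reduce to plain gradient
    # descent with learning rate 0.1.
    theta = gated_update(theta, g, forget_gate=[1.0, 1.0], input_gate=[0.1, 0.1])
```

With these constant gates the update is ordinary gradient descent, which shows that the learned optimizer contains gradient descent as a special case.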
Model-Agnostic Meta-Learning (MAML)
In Gradient Descent, we use the gradient of the loss or the reward function to update model parameters.
But this learns a particular task rather than finding the fundamental knowledge behind all the tasks. So instead of updating the model immediately, we can wait until a batch of tasks is completed. We later merge all we learned from these tasks for a single update. This approach fulfills the concept of “learn what we learn”.
MAML utilizes this concept to update models. It is simple and it is almost the same as the traditional DL gradient descent with one added line of pseudocode.
MAML collects a batch of weight updates from different tasks and each set of weight updates will propose a new model (step 6). Once this is done, MAML evaluates the loss again for each proposed model with samples from the query (note: the samples come from the query, not the support). MAML sums the loss and computes the gradient w.r.t. θ. Finally, the model parameters θ are updated with gradient descent.
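The procedure above can be sketched on a toy problem where everything is computed by hand. The sketch assumes a one-parameter model y = θx with a mean squared error loss, a single inner gradient step per task, and two regression tasks with slopes 1 and 3. All names and numbers are made up for illustration.

```python
def grad(theta, xs, ys):
    """d/dtheta of mean((theta * x - y)^2)."""
    return sum(2.0 * x * (theta * x - y) for x, y in zip(xs, ys)) / len(xs)

def hessian(xs):
    """Second derivative of the same loss (independent of theta and y)."""
    return sum(2.0 * x * x for x in xs) / len(xs)

def maml_step(theta, tasks, alpha=0.1, beta=0.05):
    meta_grad = 0.0
    for (sx, sy), (qx, qy) in tasks:
        theta_i = theta - alpha * grad(theta, sx, sy)  # proposed model from the support
        d_theta_i = 1.0 - alpha * hessian(sx)          # d(theta_i)/d(theta), second-order term
        meta_grad += grad(theta_i, qx, qy) * d_theta_i # query loss gradient w.r.t. theta
    return theta - beta * meta_grad                    # meta-update

def make_task(slope):
    support_x, query_x = [1.0, 2.0], [0.5, 1.5]
    return ((support_x, [slope * x for x in support_x]),
            (query_x, [slope * x for x in query_x]))

tasks = [make_task(1.0), make_task(3.0)]
theta = 0.0
for _ in range(200):
    theta = maml_step(theta, tasks)
# theta ends close to 2.0, midway between the two task slopes
```

The meta-learned θ fits neither task exactly. It settles where one gradient step on either task's support gives a low query loss.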
Here is the objective that the meta-learner is optimized for. It adapts a model to a specific task using the support and then finds the θ that can most easily adapt to these tasks according to the query samples.
Conceptually, we train specific learners θᵢ below and we later use their proposal models to train the meta-learner parameterized by θ. The solid line below indicates how the meta-learner converges to its optimal.
Once the training is completed, we can use this meta-model to initialize a learner model for handling new tasks. Then, it can be further finetuned. This finetuning only needs a few examples and a single or a few gradient steps. For example, the finetuning for task 3 will move the model towards the optimal θ₃*. MAML is model agnostic. But optimization-based methods seem to be more effective when the DNN f is narrow and deep (AutoMeta).
SNAIL (A Simple Neural Attentive Meta-Learner paper)
SNAIL meta-learners deploy temporal convolutions (TC) to aggregate information from past experiences and use soft attention to learn which part of the current data should be focused on.
SNAIL models are composed of layers of TC and causal attention. Causality is applied such that current states or actions depend on past histories but not the future. If it is applied to supervised learning, SNAIL takes in a sequence of input and labels (except for the last entry) to predict the label for xt. If the input is an image, a CNN network will be applied to extract the features first.
A temporal convolution (TC) layer is composed of many dense blocks as used in DenseNet. In DenseNet, the input of a block comes from all previous layers, not just the last layer.
Here is the code for concatenating all these dense blocks together as the input to the next layer.
A dense block applies a causal 1D-convolution with dilation rate R and D filters. For the next TC layer, the dilation R will be doubled. This expands the receptive field temporally.
Below is how we construct a Dense Block. It contains two convolution components. One serves as a gating function to gate the output of the other convolution output in calculating its activations (step 3).
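Here is a minimal sketch of such a dense block, written for readability rather than speed. It assumes a kernel size of 2 and a single filter (D = 1), and the weights are fixed numbers instead of learned ones. The function names are my own.

```python
import math

def causal_conv(inputs, dilation, w_past, w_now):
    """Kernel-size-2 causal convolution with one output filter.
    The output at time t only sees inputs[t] and inputs[t - dilation]."""
    channels = len(inputs[0])
    out = []
    for t, x_now in enumerate(inputs):
        x_past = inputs[t - dilation] if t >= dilation else [0.0] * channels
        out.append(sum(wp * p + wn * n
                       for wp, wn, p, n in zip(w_past, w_now, x_past, x_now)))
    return out

def dense_block(inputs, dilation, filter_weights, gate_weights):
    xf = causal_conv(inputs, dilation, *filter_weights)
    xg = causal_conv(inputs, dilation, *gate_weights)
    # tanh(xf) * sigmoid(xg): the second convolution gates the first one
    activations = [math.tanh(f) / (1.0 + math.exp(-g)) for f, g in zip(xf, xg)]
    # DenseNet-style: concatenate the new activations to the input channels
    return [x + [a] for x, a in zip(inputs, activations)]

sequence = [[1.0], [2.0], [3.0], [4.0]]  # T = 4 timesteps, 1 channel
out = sequence
for dilation in [1, 2]:                   # the dilation doubles from block to block
    channels = len(out[0])
    w = ([0.5] * channels, [0.5] * channels)  # hypothetical fixed weights
    out = dense_block(out, dilation, w, w)
```

Each block adds one channel, so after two blocks every timestep carries three channels, and no output depends on later timesteps.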
The causal attention layer in SNAIL is responsible for creating self-attention similar to the Transformer used in BERT.
The general principle allows the model to focus (pay attention) on subparts of the input features.
Queries q, keys k, and values v are generated (learned) from the input using different learned affine transformation (linear transformation).
The queries symbolize what we are interested in and the keys encode the information in the values v. We multiply q and k together to estimate where the focus should be. Then we mask out the values v that we should not pay attention to.
This is an extremely rough explanation of self-attention. Please refer to the attention in this article for details. Here is the pseudocode for performing the self-attention.
For causality, CausallyMaskedSoftmax(·) zeros out the appropriate probabilities before normalization such that the query cannot have access to future keys/values.
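As a rough sketch of the causal masking, the snippet below runs a single-head scaled dot-product attention in which timestep t only attends to timesteps 0..t. The queries, keys and values are assumed to be already computed by the affine transformations, and the mask is realized by slicing off the future keys/values, which has the same effect as zeroing their probabilities before normalization.

```python
import math

def causal_attention(queries, keys, values):
    """Scaled dot-product attention where step t cannot see steps after t."""
    scale = math.sqrt(len(keys[0]))
    outputs = []
    for t, q in enumerate(queries):
        # Only keys 0..t take part in the softmax (causal mask).
        logits = [sum(a * b for a, b in zip(q, k)) / scale for k in keys[: t + 1]]
        m = max(logits)
        weights = [math.exp(l - m) for l in logits]
        total = sum(weights)
        weights = [w / total for w in weights]
        outputs.append([sum(w * v[j] for w, v in zip(weights, values[: t + 1]))
                        for j in range(len(values[0]))])
    return outputs

queries = keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0], [2.0], [3.0]]
outputs = causal_attention(queries, keys, values)
```

The first timestep can only attend to itself, so its output is simply its own value.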
Here is the major difference between the optimizer approach and the model-based approach. In the model-based approach, the meta-learner predicts/adapts the model parameters for the learner using samples from the support, i.e. 𝜙ᵢ ← f(support samples) with f parameterized by θ. In optimizer, we use methods like gradient descent to refine θ to become 𝜙ᵢ with the support samples.
For the LSTM-based Meta-Learner, it uses a DNN to adjust θ (with inputs including the loss gradient) instead of performing gradient descent directly.
But this DNN f approach may encounter one issue. If the DNN has not been exposed to tasks similar to Dtr_i, there is no guarantee on the accuracy of its predictions. As shown below, when the input character is smeared further, the accuracy drops for SNAIL and MetaNet (both using recurrent-based DNNs). This is because the input is out-of-distribution relative to how the DNN was trained. They generalize less well than a gradient-based optimizer. In contrast, MAML has a better inductive bias and can generalize and extrapolate to unforeseen tasks better.
A consistent meta-learner will converge to a local optimum on any new task, regardless of the meta-learner model. A gradient-descent-based optimizer is a consistent meta-learner as it uses gradient descent to improve the model. Even if it gets a bad start from the meta-learning, it can still converge to at least a local optimum.
But model-based methods are not consistent. If the meta-learner has not explored the data space near the new tasks properly, we will not reach a local optimum.
In MAML, we apply a derivative in the inner loop for each task and another derivative for the meta-learner, so it involves a second-order derivative. Unfortunately, the second-order derivative may exhibit instabilities in training. FOMAML (First-order MAML) simplifies the gradient calculation by approximating the first gradient term below (the derivative of the adapted parameters w.r.t. the model parameters) with the identity.
Therefore, the gradient of FOMAML will be calculated from the second term only — a first-order loss function derivative using the updated model and testing sample B. This simplification will work well with many meta-learning problems with the exception of reinforcement learning and imitation learning. Other approaches in addressing the instability problem may involve different learning rates or training strategies between the inner loop and the outer loop.
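A small sketch of the first-order update is given below, under the same kind of toy assumptions as before: a one-parameter model y = θx, a mean squared error loss, and two regression tasks with slopes 1 and 3. The names are illustrative only.

```python
def grad(theta, xs, ys):
    """d/dtheta of mean((theta * x - y)^2)."""
    return sum(2.0 * x * (theta * x - y) for x, y in zip(xs, ys)) / len(xs)

def fomaml_step(theta, tasks, alpha=0.1, beta=0.05):
    meta_grad = 0.0
    for (sx, sy), (qx, qy) in tasks:
        theta_i = theta - alpha * grad(theta, sx, sy)  # inner step on the support
        # First-order approximation: d(theta_i)/d(theta) is treated as 1, so the
        # meta-gradient is just the query loss gradient at the updated model.
        meta_grad += grad(theta_i, qx, qy)
    return theta - beta * meta_grad

def make_task(slope):
    support_x, query_x = [1.0, 2.0], [0.5, 1.5]
    return ((support_x, [slope * x for x in support_x]),
            (query_x, [slope * x for x in query_x]))

tasks = [make_task(1.0), make_task(3.0)]
theta = 0.0
for _ in range(200):
    theta = fomaml_step(theta, tasks)
# theta ends close to 2.0 on this toy problem
```

No second derivative of the loss appears anywhere, which is what makes the first-order variant cheaper.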
In Reptile, this gradient calculation is simplified even further. Reptile performs a k-step model update and uses the difference between the last model and the original model as the gradient in the gradient descent.
Consider that the optima for task 1 and task 2 lie on the surfaces W₁ and W₂ respectively. Reptile moves 𝜙 towards the area that is closest to those surfaces.
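The Reptile update is short enough to sketch directly. The snippet assumes a one-parameter model y = θx with a mean squared error loss and two toy tasks (slopes 1 and 3) visited alternately. The step sizes `alpha` and `epsilon` and all names are assumptions for illustration.

```python
def grad(theta, xs, ys):
    """d/dtheta of mean((theta * x - y)^2)."""
    return sum(2.0 * x * (theta * x - y) for x, y in zip(xs, ys)) / len(xs)

def reptile_step(theta, task, k=5, alpha=0.05, epsilon=0.1):
    xs, ys = task
    phi = theta
    for _ in range(k):                       # k-step model update on one task
        phi = phi - alpha * grad(phi, xs, ys)
    # (phi - theta) plays the role of the gradient direction.
    return theta + epsilon * (phi - theta)

xs = [1.0, 2.0]
tasks = [(xs, [1.0 * x for x in xs]), (xs, [3.0 * x for x in xs])]
theta = 0.0
for step in range(200):
    theta = reptile_step(theta, tasks[step % 2])
```

The parameter ends up between the two task optima, without any gradient being backpropagated through the inner loop.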
MAML is model agnostic. Without further proof here, MAML with gradient descent and early stopping implies a Gaussian prior for p(𝜙ᵢ | θ) with its mean around θ.
This finding implies the possibility that we can implicitly define the type of prior and the corresponding learner’s ML model. For example, ALPaCA learns a feature encoding, as well as a prior p(W) generated from a meta-learner’s NN. This NN outputs the mean and variance of a Gaussian for each weight to be used in the Bayesian linear regression. So, once features are extracted, ALPaCA computes the posterior (the prediction distribution) by applying Bayesian linear regression with the prior. For your reference, below is a general description of the Bayesian linear regression.
Splitting the process into feature extraction followed by a Bayesian linear regression makes the calculation tractable.
To demonstrate the idea of creating a meta-learner to generate model parameters for a specific optimization method (say Bayesian linear regression or SVM), we will detail MetaOptNet. We pick this model because it was state-of-the-art in 2019 on meta-learning datasets and benchmarks.
MetaOptNet extracts features from the input image and, for the last layer, it learns a base learner A to estimate the SVM weights (instead of the Bayesian regression parameters in ALPaCA). These weights are multiplied with the embedded features of the query (a dot product) to make classification predictions. This linear predictor, implemented as an SVM, possesses nice generalization properties that result in state-of-the-art performance.
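To see why the second stage is tractable, here is a deliberately tiny sketch of Bayesian linear regression. It assumes a single extracted feature and a single weight, a Gaussian prior N(prior_mean, prior_var) on that weight, and a known Gaussian noise variance. This is the scalar textbook case, not ALPaCA's actual matrix formulation, and the numbers are made up.

```python
def posterior(prior_mean, prior_var, phis, ys, noise_var):
    """Closed-form Gaussian posterior over the weight w in y = w * phi + noise."""
    precision = 1.0 / prior_var + sum(p * p for p in phis) / noise_var
    mean = (prior_mean / prior_var
            + sum(p * y for p, y in zip(phis, ys)) / noise_var) / precision
    return mean, 1.0 / precision

def predict(post_mean, post_var, phi_new, noise_var):
    """Predictive mean and variance for a new feature value."""
    return post_mean * phi_new, phi_new ** 2 * post_var + noise_var

# Prior supplied by the meta-learner (assumed values), then three support samples.
post_mean, post_var = posterior(0.0, 1.0, [1.0, 2.0, 3.0], [2.0, 4.0, 6.0], 0.25)
pred_mean, pred_var = predict(post_mean, post_var, 4.0, 0.25)
```

The posterior is available in closed form, so adapting to a new task needs no gradient steps once the features are extracted.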
The key objective in MetaOptNet is to learn feature embeddings that generalize well under a linear classifier (SVM). A linear base learner for classification is selected because the objective function for such a learner is convex and differentiable. Optimizing such a function is not only heavily studied but can also be done efficiently.
Below are the objectives for the base learner and the meta-learner. It just looks complicated but it is pretty simple. It uses the SVM penalty (the hinge loss) to train the base learner.
The meta-learner objective is just a re-expression of
which finds the best feature extractor to work with the SVM and the testing data.
Self-Critique and Adapt (SCA)
As discussed before, in the inner loop of many meta-learning algorithms, we use N-step gradient descent with the support samples to improve a learner θᵢ (step ① below).
SCA also learns a critic C, parameterized by W, to judge how good θᵢ is. The input to C is F. F summarizes (or extracts) the model θᵢ (step ②) and may include other information like the predictions on the query and the embedding context of the support. After computing the score in critic C (step ③), its gradient is used to improve θᵢ with gradient descent. Steps ② to ④ are repeated I times to improve the model. In short, the critic’s result is used to improve the one being criticized. Critic C is modeled without the true labels as input and therefore it is called the label-free critic model. So, how can we train critic C if the true labels are not visible from step ③ to step ④?
The short answer is that we just delay the complete training to a later step. Once the improvement on θᵢ is completed, SCA computes the loss of the model θᵢ’s predictions using the true labels of the query. The corresponding gradient will be backpropagated to train all the NNs involved in the process (step ⑤). This will include the critic model.
In meta-testing, we will use the critic C to improve θᵢ after adapting it with support samples in the meta-testing.
Metric Learning/Non-parametric Model
What is the difference between a parametric model and a non-parametric model? Parametric algorithms learn models to capture knowledge. Once a model is built, we make predictions using its parameters and we can throw away the training data. On the other hand, non-parametric models keep the data. To make predictions, we explore the similarity of the input with the accumulated samples. One of the most well-known non-parametric models is KNN (K Nearest Neighbors), which uses the labels of the closest neighbors in making predictions. Non-parametric models may have problems with huge datasets. But they work well with Meta-Learning as the support usually contains very few samples.
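As a quick illustration of a non-parametric model, here is a minimal KNN classifier. It keeps all the samples and predicts by a majority vote among the k closest ones. The Euclidean distance and the toy data are my own choices.

```python
import math
from collections import Counter

def knn_predict(query, samples, labels, k=3):
    """Majority vote among the k stored samples closest to the query."""
    nearest = sorted((math.dist(query, s), y) for s, y in zip(samples, labels))
    votes = Counter(y for _, y in nearest[:k])
    return votes.most_common(1)[0][0]

samples = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (5.0, 5.0), (5.0, 6.0), (6.0, 5.0)]
labels = ["cat", "cat", "cat", "dog", "dog", "dog"]
prediction = knn_predict((0.5, 0.5), samples, labels)
```

Nothing is trained here. Every prediction scans the stored samples, which is why this gets expensive on huge datasets but stays cheap on a small support.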
In Metric Learning, the knowledge in the meta-training dataset is still captured by a parametric model, which is often in the form of feature extractors. But to adapt to a specific task, we explore similarities for the query with the labeled support. In this process, there are two questions that need to be answered: what are we comparing and how are we comparing. If we compare images pixel by pixel with L2 distances, we are going to fail.
Siamese Neural Networks (paper)
One of the most critical tasks in DL is feature extraction. If we want to generalize a classifier, the feature extractor must capture general knowledge that distinguishes classes. A Siamese Neural Network uses two identical networks, sharing the same model parameters, to extract features for two samples. Then we feed the extracted features into a discriminator to tell whether both samples belong to the same class or not.
In the original paper, the L1 distance is used to measure the distance between two feature vectors. Then it is fed into a classifier (say a fully-connected layer) to determine whether they belong to the same class. But other distance metrics like the cosine similarity or L2 can be used. If the objects belong to the same class, the output p should be close to 1, otherwise 0. Therefore, by computing a loss function based on the true labels and the predictions on different tasks, we train the feature extractor to extract basic features that distinguish general objects.
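Here is a minimal sketch of this scoring step. The `encode` function is only a stand-in for the shared CNN, and the weights `alpha` and `bias` of the final layer are hand-picked here rather than learned.

```python
import math

def encode(x, scale=1.0):
    """Stand-in for the shared feature extractor (same parameters for both inputs)."""
    return [scale * v for v in x]

def same_class_probability(x1, x2, alpha, bias):
    h1, h2 = encode(x1), encode(x2)                 # twin networks, shared weights
    l1 = [abs(a - b) for a, b in zip(h1, h2)]       # component-wise L1 distance
    logit = sum(w * d for w, d in zip(alpha, l1)) + bias
    return 1.0 / (1.0 + math.exp(-logit))           # sigmoid output p

alpha, bias = [-4.0, -4.0], 3.0                     # hypothetical trained values
p_same = same_class_probability([1.0, 2.0], [1.1, 2.0], alpha, bias)
p_diff = same_class_probability([1.0, 2.0], [3.0, 0.0], alpha, bias)
```

Negative weights on the distances mean that a small distance gives a p close to 1 and a large distance gives a p close to 0.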
The diagram below shows an example of using a CNN network to extract features from an image.
Matching Network (paper)
Matching Network compares an image from the query with each image in the support. Similarities are measured with the cosine similarity after the images are encoded by f and g respectively. Then the values are normalized by the softmax function into a probability. In each task, we apply one-shot learning to map a query object to one of the classes in the support (say, German Shepherd).
The model is trained to maximize the probability of the true label while minimizing the others. f and g are the feature extractors that we train. One popular realization is to have both f and g share the same parameters and implement them as a CNN.
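The prediction step can be sketched as below, assuming the embeddings from f and g are already computed. The function names and the toy embeddings are mine.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def matching_predict(query_emb, support_embs, support_labels):
    """Class probabilities as attention-weighted votes of the support labels."""
    sims = [cosine(query_emb, s) for s in support_embs]
    m = max(sims)
    exps = [math.exp(s - m) for s in sims]
    total = sum(exps)
    probs = {}
    for e, label in zip(exps, support_labels):
        probs[label] = probs.get(label, 0.0) + e / total  # softmax attention per label
    return probs

support_embs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]       # g(x_i), assumed values
support_labels = ["shepherd", "poodle", "husky"]
probs = matching_predict([0.9, 0.2], support_embs, support_labels)
best = max(probs, key=probs.get)
```

Because the whole pipeline is differentiable, the loss on the true label can be backpropagated into f and g.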
Fully Conditional Embedding
However, both f and g encode samples independently of others and without the context of the support S. This may hurt if images in the support are in similar subcategories, like different breeds of dogs, with many similarities. In this situation, the encoding should be sensitive to the support and extract features that can distinguish them. This requirement will be addressed by the encoding method called Fully Conditional Embedding.
In ① above, we apply g’(xᵢ) to extract image features from xᵢ, image i in the support. Usually, this is done with a CNN network. Then the features are fed into a bi-directional LSTM.
The encoding g for xᵢ will be g(xᵢ, S) instead of g(xᵢ), i.e. the encoding of xᵢ will be sensitive to the support. The encoding g(xᵢ, S) is the sum of g’(xᵢ) and the hidden states hᵢ of LSTM cell i in both the forward and backward directions. Since the cell state cᵢ contains knowledge of the images processed so far, the encoding in hᵢ will be trained to encode data while taking the support into account.
The Fully Conditional Embedding f encodes a query image using attention with LSTM.
f’ extracts the image features from the query image. Again, this can be done with a CNN network. Then we apply a K step “read” using an LSTM on f’, and the context of the support. The key idea is to provide an encoding scheme that pays attention to images in the support such that the scheme can be self-adjusted to handle the subtlety among the support images.
Specifically, the hidden state in each LSTM cell is sensitive to the readout r from the last time step. This readout is based on the similarity between the hidden state and g(xᵢ), where xᵢ is image i in the support. This encoding repeats for K timesteps and the final hidden state is the embedding f for the query image. The equations are a little bit entangled but the general principle is pretty simple. At every timestep, it pays more attention to similar images in the support and moves the extracted features closer to their features. It is repeated K times to refine the process. So the extracted features for the query will be close to those of its similar images in the support.
The details of this K step “Process” block can be found here and these are the general equations.
Relation Network (paper)
In Relation Network, the encoding of each image in the support is concatenated with the query’s encoded features. Then the DNN g is used to score their similarity. Then, it uses the highest score to associate a query to an image in the support.
Prototypical Networks (paper)
The algorithm in Prototypical Networks is very similar to clustering. The key idea is to find the centroid for objects belonging to the same class. We train an embedding function f to encode images. The centroid of objects belonging to the same class is computed by averaging their encoded features. A prediction is made by a softmax over the negative distances from the query to each centroid, so a closer centroid gets a higher probability.
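The centroid-and-softmax idea can be sketched as follows, assuming the embeddings f(x) are already computed and the squared Euclidean distance is used. The function names and the toy embeddings are my own.

```python
import math

def prototypes(embeddings, labels):
    """Average the embeddings of each class to get one centroid per class."""
    sums, counts = {}, {}
    for e, y in zip(embeddings, labels):
        acc = sums.setdefault(y, [0.0] * len(e))
        for j, v in enumerate(e):
            acc[j] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}

def classify(query_emb, protos):
    """Softmax over the negative squared distances to each centroid."""
    logits = {y: -sum((a - b) ** 2 for a, b in zip(query_emb, c))
              for y, c in protos.items()}
    m = max(logits.values())
    exps = {y: math.exp(l - m) for y, l in logits.items()}
    total = sum(exps.values())
    return {y: e / total for y, e in exps.items()}

support = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
protos = prototypes(support, ["a", "a", "b", "b"])
probs = classify([1.0, 1.0], protos)
```

The query sits next to the centroid of class "a", so almost all of the probability goes to that class.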
Model-based vs. optimization-based vs. non-parametric model
Let’s compare all three meta-learning approaches a little bit. The model-based method can be modified to solve problems in other AI domains, like reinforcement learning. But the training often starts from scratch without any hints on where to look first, i.e. no inductive bias, at least in the beginning. Therefore, it is sample inefficient. The optimization-based model can handle varying and large K well (K-shot) and it can make reasonably good extrapolations on out-of-distribution tasks. It may have stability problems because of the 2nd-order optimization, but this issue can be mitigated. Non-parametric models are computationally fast and simple but they have a harder time handling varying or large K, as shown in the empirical results. And they are mainly for classification only.
Next, we will cover Bayesian Meta-learning, Unsupervised, and Weakly supervised Meta-Learning. The Bayesian approach allows the algorithms to handle uncertainty in real life and weakly supervised learning addresses the expensive cost of collecting samples. Stay tuned.