Source: Deep Learning on Medium
In the previous post, we discussed attention-based seq2seq models and the logic behind their inception. The plan was to follow up with a PyTorch implementation, but it turns out the PyTorch documentation already provides an excellent tutorial on exactly that. So here, I move on to the next item in my plan: the transformer, which works on the principle of self-attention.
Let’s do a two-line recap of the attention-based model. Its central idea was that it took an input sequence and all the hidden states associated with it, and at every step of the output, it decided which parts of the input were useful and produced the output based on them. The sequential nature was captured by using RNNs or LSTMs in both the encoder and the decoder.
In the "Attention Is All You Need" paper, the authors showed that sequences can be modelled using only the attention mechanism, without any RNNs or LSTMs.
In this post, we will follow a similar structure to the previous post: we start with the black box and then work through each component one by one, gradually building up a clear picture of the whole architecture. By the end, you will have a good grasp of each component of the transformer network and how it helps achieve the desired result.
Prerequisites: basic knowledge of attention networks and other encoder-decoder models, neural networks, and normalization techniques.
The Transformer Black Box:
Let us first understand the basic similarities and differences between the attention and transformer models. Both aim to achieve the same result using an encoder-decoder approach. The encoder converts the original input sequence into a latent representation in the form of hidden state vectors, and the decoder tries to predict the output sequence from this latent representation. But the RNN-based approach has an inherent flaw: each hidden state depends on the previous one, so the time steps of a single sequence must be computed one after another and cannot be parallelized. This becomes critical for long sequences, where memory constraints also limit the batch size that can be used during training. The transformer alleviates this, and we'll soon see how. So let's dive right into it.
(I’d like to call out that this post has a bit more reading and fewer images, but I’ll keep the reading as simple as possible.)
The transformer architecture continues with the Encoder-Decoder framework that was a part of the original Attention networks — given an input sequence, create an encoding of it based on the context and decode that context-based encoding to the output sequence.
Apart from the inability to parallelize, another important reason to look for an improvement was that the attention-based model would inadvertently give higher weight to the elements in the sequence close to a given position. While this may help in capturing the grammatical structure of nearby parts of a sentence, it makes it hard to find relations between words that are far apart.
Like the last post, we will incrementally arrive at the self-attention model, starting off with the basic architecture and slowly getting to the complete transformer.
As seen, it follows the encoder-decoder design, replacing the LSTMs with self-attention layers and capturing the sequential nature of the data with positional encodings. One important point to remember is that the learnable parts of all these components are fully connected (FC) layers, combined with matrix products and a softmax inside the attention layers. Since no component depends on the output of the previous time step, as an RNN does, the computation for all positions of a sequence can be parallelized during training.
So we are now left with the following doubts:
- How is the input sequence designed?
- How is the sequential nature handled using the positional encodings?
- What’s Self Attention?
- What’s Layer Normalization?
- What is the design of the Feed Forward Neural Net?
- What is the difference between Self Attention and Masked Self Attention Layers?
- What’s the Enc-Dec Attention layer?
As we clarify each of these questions, the model will become clearer.
Given a sequence of tokens x1, x2, x3, …, the input sequence corresponds to an embedding of each of these tokens. This embedding could be something as simple as a one-hot encoding. For example, in the case of a sentence, x1, x2, x3, … would correspond to the words in the sentence, and the input sequence could then be the one-hot encoding of each of these words.
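As a minimal sketch, here is what one-hot input embeddings look like for a toy sentence. The vocabulary below is a hypothetical example chosen only for illustration. Note that the original paper instead uses learned embeddings of dimension d_model, but one-hot vectors are enough to show the idea of "one vector per token".

```python
import numpy as np

def one_hot_embed(tokens, vocab):
    """Map each token to a one-hot vector of length len(vocab) (one row per token)."""
    index = {word: i for i, word in enumerate(vocab)}
    X = np.zeros((len(tokens), len(vocab)))
    for pos, tok in enumerate(tokens):
        X[pos, index[tok]] = 1.0
    return X

# Hypothetical toy vocabulary and sentence.
vocab = ["a", "cat", "chasing", "i", "mouse", "saw", "the"]
sentence = ["i", "saw", "the", "cat"]
X = one_hot_embed(sentence, vocab)
print(X.shape)  # (4, 7): 4 tokens, each a 7-dimensional one-hot vector
```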
As there is no component in this new encoder-decoder architecture that captures the sequential nature of the data, we need to inject some information about the relative or absolute position of the tokens in the sequence. This is the task of the positional encoding module.
Ignoring the mathematical formulation, given an embedding for token x at position i, a positional encoding for the i’th position is added to that embedding. The encodings are designed so that each position’s encoding is distinct from every other. Each dimension of the positional encoding corresponds to a sinusoid of a different wavelength, and the final encoding is the value of each of these sinusoids at the i’th position.
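For readers who do want the formulation, the paper defines PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). Below is a small numpy sketch of that formula. It assumes d_model is even, and the random "token embeddings" are just stand-ins for real ones.

```python
import numpy as np

def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encodings of shape (seq_len, d_model); d_model assumed even."""
    pos = np.arange(seq_len)[:, None]               # positions 0..seq_len-1, shape (seq_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]       # even dimension indices 2i, shape (1, d_model/2)
    angles = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)                    # even dimensions use sine
    pe[:, 1::2] = np.cos(angles)                    # odd dimensions use cosine
    return pe

# Inject position information by adding the encodings to (random stand-in) embeddings.
rng = np.random.default_rng(0)
emb = rng.normal(size=(10, 16))
x = emb + positional_encoding(10, 16)
```

Each column is a sinusoid with its own wavelength, so reading across a row gives a vector that differs from every other row.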
As described by the authors of “Attention is All You Need”,
Self-attention, sometimes called intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence.
This layer aims to encode a word based on all the other words in the sequence. It compares the encoding of the word against the encoding of every other word and uses these comparisons to produce a new encoding. The way this is done is a little complex, so I’ll try to break it down as much as possible.
Given an embedding x, the layer learns three separate, smaller embeddings from it: a query, a key and a value. All three have the same number of dimensions (64 in the paper, compared with an embedding size of 512). By "learns", I mean that during training, the matrices Wq, Wk and Wv are updated based on the back-propagated loss. x is multiplied by these three matrices to get the query, key and value embeddings.
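Here is a minimal sketch of the three projections, assuming toy sizes (d_model = 8, d_k = 4) chosen only for illustration. The matrices are random here, whereas in a real model they are the parameters updated by backpropagation.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 4            # toy sizes, assumed for illustration

X = rng.normal(size=(seq_len, d_model))    # one position-injected embedding per row

# Learned projection matrices (random stand-ins; trained by backpropagation in practice).
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q = X @ W_q   # queries, shape (seq_len, d_k)
K = X @ W_k   # keys,    shape (seq_len, d_k)
V = X @ W_v   # values,  shape (seq_len, d_k)
```

Stacking the embeddings as rows of X lets one matrix product compute the query (or key, or value) of every word at once.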
For the purpose of understanding these terms, let’s treat these words as living entities. Each word wants to know how much every other word values it and then create a better version of itself that represents this valuation.
Say x1 wants to know its value with respect to x2. So it will ‘query’ x2. x2 answers with its own ‘key’, and the dot product of x1’s query with x2’s key gives a score representing how much x2 values x1. Since the query and key have the same size, this is a single number. This step is performed with every word in the sequence, including x1 itself.
Now, x1 takes all these scores and applies a softmax, which turns them into positive weights that sum to 1 while preserving their relative order. (Before the softmax, each score is also divided by the square root of d_k, the dimension of the key vectors. For large d_k, the dot products grow large in magnitude and push the softmax into regions where its gradients are extremely small, so this scaling keeps the gradients stable.)
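The scoring step for a single word can be sketched as follows. Random queries and keys stand in for the learned ones, and the sizes are arbitrary. The softmax subtracts the maximum score first, a standard numerical-stability trick that does not change the result.

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract the max for numerical stability; output is unchanged
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(1)
seq_len, d_k = 4, 8
Q = rng.normal(size=(seq_len, d_k))   # stand-in queries, one row per word
K = rng.normal(size=(seq_len, d_k))   # stand-in keys, one row per word

q1 = Q[0]                                  # x1's query
scores = K @ q1                            # dot product with every key, including x1's own
weights = softmax(scores / np.sqrt(d_k))   # scale by sqrt(d_k), then normalize
```

The resulting weights are all positive, sum to 1, and keep the same ranking as the raw scores.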
This scoring and softmax step is performed by every word against all other words. The above diagram paints a picture of this whole explanation and should be easier to understand now.
x1 now multiplies the ‘value’ of each word by the corresponding score. If a word is not relevant to x1, its score will be small and its value will be scaled down accordingly, while the values of significant words will carry more weight because of their higher scores.
Finally, the word x1 creates a new ‘value’ for itself by summing up these weighted values. This is the new embedding of the word.
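Putting the steps together for every word at once gives the matrix form from the paper, softmax(QKᵀ / √d_k) V. The sketch below is a single-head version with random stand-in weights, and it is meant to mirror the walkthrough rather than be an optimized implementation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over the rows of X."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d_k = K.shape[-1]
    # Row i holds word i's softmax-normalized scores against every word.
    weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)
    # Row i of the result is the weighted sum of all value vectors: word i's new embedding.
    return weights @ V

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 16))
W_q, W_k, W_v = (rng.normal(size=(16, 8)) for _ in range(3))
Z = self_attention(X, W_q, W_k, W_v)
print(Z.shape)  # (5, 8)
```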
But the self-attention layer is a little more complex than just this. The transformer architecture adds multi-head attention to this layer. Once this concept is clear, we will be much closer to the final transformer design.
Multiple such sets of query, key and value matrices are learned from the same original x1, x2, etc. embeddings. The self-attention steps are performed on each set separately, and new embeddings v’1, v’2, etc. are created for each set. These are then concatenated and multiplied by another learned matrix Z (called W^O in the paper, and trained jointly with the rest), which reduces the multiple embeddings to a single embedding x’1, x’2, etc. This is why it is called multi-head attention: each set of v’ embeddings represents one head of the self-attention model.
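A minimal multi-head sketch under the same toy assumptions: each head has its own random (Wq, Wk, Wv), the head outputs are concatenated, and the concatenation is multiplied by the output matrix Z. Following the paper, each head uses d_k = d_model / h, so the concatenation is back to d_model wide. Real implementations usually batch the heads into one tensor, but a list of heads keeps the correspondence with the text obvious.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d_k = K.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V

def multi_head_self_attention(X, heads, Z):
    """heads: list of (W_q, W_k, W_v) tuples, one per head; Z: output projection matrix."""
    outputs = [attention(X @ W_q, X @ W_k, X @ W_v) for W_q, W_k, W_v in heads]
    return np.concatenate(outputs, axis=-1) @ Z   # concatenate the heads, project to d_model

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 16, 4
d_k = d_model // h                                # per-head size, as in the paper
heads = [tuple(rng.normal(size=(d_model, d_k)) for _ in range(3)) for _ in range(h)]
Z = rng.normal(size=(h * d_k, d_model))
X = rng.normal(size=(seq_len, d_model))
out = multi_head_self_attention(X, heads, Z)
print(out.shape)  # (5, 16)
```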
The intuition behind this design is that each head looks at the original embedding in a different context, because each head’s Wq, Wk and Wv matrices are initialized randomly and then modified based on the loss back-propagated during training. So the final embedding is learned by taking various contexts into account at the same time.
So at the end of this section on self-attention, we see that the self-attention layer takes position-injected, naïve embeddings as input and outputs more context-aware embeddings. With this, we’ve cleared up the most difficult of the questions listed above.
The key feature of layer normalization is that it normalizes the inputs across the features, unlike batch normalization, which normalizes each feature across the batch. Batch norm has the drawback that it needs a reasonably large batch to estimate its statistics reliably, which effectively imposes a lower bound on the batch size. In layer norm, the statistics are computed across the features of each individual example and are independent of the other examples. It has also been found to work well experimentally, particularly for sequence models.
In the transformer, layer normalization is combined with a residual connection around each sub-layer: the output is LayerNorm(x + Sublayer(x)), which lets it retain information from the previous layer.
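The difference between the two normalizations is simply which axis the statistics are taken over. Here is a small sketch of layer norm wrapped in the residual "add and norm" step. The sublayer is a hypothetical stand-in function, and gamma/beta are the usual learnable scale and shift, fixed here to 1 and 0.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each row (one position/example) across its features."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def batch_norm_stats(x):
    """For contrast: batch norm computes per-feature statistics across the batch (axis 0)."""
    return x.mean(axis=0), x.var(axis=0)

def add_and_norm(x, sublayer, gamma, beta):
    """Residual connection followed by layer norm: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer(x), gamma, beta)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 8))     # 4 positions, 8 features
gamma, beta = np.ones(8), np.zeros(8)
y = add_and_norm(x, lambda t: 0.1 * t, gamma, beta)  # hypothetical sublayer
bn_mean, bn_var = batch_norm_stats(x)                # one statistic per feature: shape (8,)
```

After layer norm, every row has roughly zero mean and unit variance regardless of the other rows, whereas the batch-norm statistics would change if other examples were added to the batch.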
Feed-Forward Neural Net
Each encoder and decoder block contains a feed-forward neural net. It consists of two linear layers with a ReLU activation between them, and it is applied to each position separately and identically. Hence its input is a set of embeddings x’1, x’2, etc. and its output is another set of embeddings x’’1, x’’2, etc. of the same dimensions, mapped to another latent space that is shared across all positions.
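The paper writes this as FFN(x) = max(0, xW1 + b1)W2 + b2. Below is a numpy sketch with toy sizes (the paper uses d_model = 512 and an inner size d_ff = 2048) and random stand-in weights. Because it is applied row by row, running it on one position alone gives the same result as running it on the whole sequence.

```python
import numpy as np

def feed_forward(X, W1, b1, W2, b2):
    """Position-wise FFN: the same two linear layers with a ReLU, applied to every row of X."""
    hidden = np.maximum(0.0, X @ W1 + b1)   # (seq_len, d_ff)
    return hidden @ W2 + b2                 # back to (seq_len, d_model)

rng = np.random.default_rng(0)
seq_len, d_model, d_ff = 5, 16, 64          # toy sizes; the paper uses 512 and 2048
W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)
X = rng.normal(size=(seq_len, d_model))
out = feed_forward(X, W1, b1, W2, b2)
```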
With this, the encoder side of the network should be clear. Now we are only left with two of the questions we started out with — Masked self attention and Enc-Dec attention.
At any position, a word may depend on both the words before it and the ones after it. For example, in “I saw the ______ chasing a mouse.” we would intuitively fill in “cat”, as that is the most probable word. So while encoding a word, the model needs to know everything in the whole sentence, which is why every word queried every other word in the self-attention layer. But at decoding time, when we are trying to predict the next word in the sentence (which is why the decoder input is the output sequence shifted by one position), the model should not know the words that come after the word it is trying to predict. Hence these future positions are masked: their scores are set to −∞ before the softmax, so their attention weights become 0 and the prediction is based only on the embeddings of the words that came before it.
This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i.
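A minimal sketch of the masking, assuming random stand-in queries, keys and values: the scores above the diagonal (the "future" positions) are replaced by −∞ before the softmax, so their weights come out as exactly 0.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def masked_self_attention(Q, K, V):
    seq_len, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # Position i may only attend to positions <= i, so block everything above the diagonal.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    weights = softmax(scores, axis=-1)   # exp(-inf) = 0, so future positions get weight 0
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_k = 4, 8
Q, K, V = (rng.normal(size=(seq_len, d_k)) for _ in range(3))
out, weights = masked_self_attention(Q, K, V)
print(np.round(weights, 2))   # the upper triangle is all zeros
```

Masking the scores rather than the embeddings keeps each row of weights a proper distribution over the allowed positions. The first position can only see itself, so its output is just its own value vector.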
If you look carefully at the network above, this layer is where the output of the encoder enters the decoder. Before this, the decoder has only used whatever information it could gain from its previous predictions to learn new embeddings for those words. At this layer, it uses the encoder output to get a better understanding of the context and of the original sequence as a whole. How does it do that?
The decoder uses its embeddings of the existing words as queries against the encoded embeddings of the words in the original sequence, which act as keys and values and carry both positional and contextual information. The new embeddings are thus injected with this information, and the output sequence is predicted based on them.
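The only change from self-attention is where Q, K and V come from. In the sketch below, which uses toy sizes and random stand-in weights, the queries are projected from the decoder embeddings D and the keys and values from the final encoder output E. No causal mask is needed here, because the whole source sentence is known.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def encoder_decoder_attention(D, E, W_q, W_k, W_v):
    """D: decoder embeddings (tgt_len, d_model); E: final encoder output (src_len, d_model)."""
    Q = D @ W_q                  # queries come from the decoder
    K = E @ W_k                  # keys and values come from the encoder output
    V = E @ W_v
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]), axis=-1)   # (tgt_len, src_len)
    return weights @ V, weights

rng = np.random.default_rng(0)
D = rng.normal(size=(3, 16))     # 3 target words produced so far
E = rng.normal(size=(5, 16))     # 5 encoded source words
W_q, W_k, W_v = (rng.normal(size=(16, 8)) for _ in range(3))
out, weights = encoder_decoder_attention(D, E, W_q, W_k, W_v)
```

Each target word gets one weight per source word, so every row of the output is a mixture of the encoded source sentence.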
With this, we have now cleared up the transformer architecture, but there is one part which I still haven’t specified.
This is the architecture of the transformer we have so far. Note that the output of the encoder is an improved version of the original embeddings, so we should be able to improve it further by stacking more encoders on top of it (and likewise more decoders). This is the point leveraged in the final design of the transformer network.
The original architecture has a stack of 6 such encoders and 6 such decoders. Only the output of the final encoder is used by the decoders: it is fed into the encoder-decoder attention layer of every decoder in the stack.
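The wiring of the stack can be sketched independently of what happens inside each layer. The toy layer factories below are hypothetical stand-ins (a residual tanh instead of the real attention and feed-forward sub-layers). The point is only the data flow: each encoder refines the previous encoder's output, and every decoder receives the same final encoder output.

```python
import numpy as np

def run_stack(src, tgt, encoder_layers, decoder_layers):
    """Run the encoder stack, then give the final encoder output to every decoder layer."""
    memory = src
    for layer in encoder_layers:
        memory = layer(memory)       # each encoder refines the previous encoder's output
    out = tgt
    for layer in decoder_layers:
        out = layer(out, memory)     # every decoder sees the same final encoder output
    return out

def make_toy_encoder_layer(rng, d):
    # Hypothetical stand-in for self-attention + FFN with residuals.
    W = rng.normal(scale=0.1, size=(d, d))
    return lambda x: x + np.tanh(x @ W)

def make_toy_decoder_layer(rng, d):
    # Hypothetical stand-in for masked self-attention + enc-dec attention + FFN.
    W = rng.normal(scale=0.1, size=(d, d))
    return lambda y, memory: y + np.tanh(y @ W + memory.mean(axis=0))

rng = np.random.default_rng(0)
N, d = 6, 8                          # N = 6 layers, as in the original paper
encoders = [make_toy_encoder_layer(rng, d) for _ in range(N)]
decoders = [make_toy_decoder_layer(rng, d) for _ in range(N)]
src = rng.normal(size=(5, d))        # source sequence embeddings
tgt = rng.normal(size=(3, d))        # shifted target embeddings
out = run_stack(src, tgt, encoders, decoders)
print(out.shape)  # (3, 8)
```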
We started off with the basic design of a transformer with the various layers in a single encoder and decoder set. We then understood the details of how each layer functions. Within this, we also covered the multi-head architecture of the self-attention layer. Once both the encoder and the decoder were clear, we moved on to the final piece in the architecture, the encoder-decoder stack.
This completes the detailed design of the Transformer architecture. I hope this had all the details needed to understand the complete picture. One of the most prevalent examples of a transformer-based network is the BERT model for Natural Language Processing tasks.
I haven’t covered the training and prediction details of the network, as that would take another complete post, and this was easily the more interesting and complicated part for me. In the next post, I’ll cover how we can use transformers for images and for unsupervised learning.
I hope you found this useful and easy to understand. If you have any corrections or any kind of feedback, I’d love to hear from you. Please comment here or email me at email@example.com