Attention Model: Transformers

A typical RNN transduction language model generates a sequence of hidden states h(t), each depending on the previous state h(t-1) and the input at that step x(t). This sequential nature becomes a problem, especially for long sequences, because no parallelization can be done within a training sequence. Other transduction networks use an encoder-decoder architecture, where the encoder maps an input (x1, x2, ... xn) to a representation (z1, z2, ... zn), which the decoder then turns into (y1, y2, ... ym). Here the model tries to capture the entire context of the input x in a fixed-size context vector z, which again is difficult for long sequences: z cannot store all the information of x to pass on to y, resulting in information loss.

Also, for long sentences, translating the entire sentence in one go is a problem, because a word generally depends more on nearby words than on those farther away in the sequence. To solve these problems the attention model was developed: it introduces more parallelization and models dependencies between words regardless of their distance in the sentence. https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf

What is attention?

Attention takes two sentences and turns them into a matrix, where the columns represent one sentence and the rows represent the other; it then matches words, identifying context and relevant relationships between them. These sentences could be anything: sentences in different languages, or the same sentence along both the rows and columns, identifying the relationships between its own words. The latter is called self-attention.

https://github.com/jessevig/bertviz

Let's break down the transformer architecture step by step and understand the role of each component in the model.

Transformer Architecture

1. Embedding:

Word Embedding: We use learned word embeddings to convert input and output tokens into vectors. Here the same embedding matrix is used for both input and output tokens, but separate embedding matrices can be used as well. If the batch size is N and T is the sequence length, then the input/output token matrix has shape (N, T); if each embedding vector has shape (, d_model), the final embedded matrix has shape (N, T, d_model).
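As a minimal NumPy sketch of this lookup (the sizes here are illustrative, not taken from the post):

import numpy as np

vocab_size, d_model = 8000, 512      # illustrative vocabulary and model size
N, T = 64, 40                        # batch size, sequence length

embedding = np.random.randn(vocab_size, d_model) * 0.01   # shared learned table
tokens = np.random.randint(0, vocab_size, size=(N, T))    # (N, T) integer ids
embedded = embedding[tokens]                               # (N, T, d_model)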

Position Embedding: Since the model contains no recurrence or convolution, we use position embeddings to capture the ordering of the sequence. Position embeddings capture the relative or absolute position of each token in the sequence. Here we use trigonometric (sinusoidal) position embeddings. The position embedding vector has the same length as the word embedding vector so that the two can be added, giving shape (N, T, d_model).

Trigonometric position embedding:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
pos --> 0 to T-1
i   --> 0 to (d_model/2)-1

The trigonometric embedding allows the model to easily learn to attend by relative positions, since for any fixed offset k,

PE(pos+k) = linear_function(PE(pos))
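A minimal NumPy sketch of these formulas (function and variable names are my own):

import numpy as np

def positional_encoding(T, d_model):
    pos = np.arange(T)[:, None]                          # positions 0 .. T-1
    i = np.arange(d_model // 2)[None, :]                 # dimension index pairs
    angles = pos / np.power(10000, (2 * i) / d_model)    # (T, d_model/2)
    pe = np.zeros((T, d_model))
    pe[:, 0::2] = np.sin(angles)                         # even indices: sine
    pe[:, 1::2] = np.cos(angles)                         # odd indices: cosine
    return pe        # added to the (N, T, d_model) word embeddings by broadcasting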

2. Encoder:

The summed word and position embeddings go into the encoder stack. The encoder unit consists of two sublayers: a Multi-Head Attention (MHA) layer and a Feed-Forward (FF) layer. This set of sublayers is repeated Nx times to form the encoder stack, each unit taking its input from the previous one. Each sublayer has a residual connection followed by Layer Normalization (LN). To facilitate the residual connection, the output of each sublayer has the same dimension (N, T, d_model).
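A rough sketch of this wrapping (my own simplified NumPy version; the learned scale and shift parameters of layer normalization are omitted):

import numpy as np

def layer_norm(x, eps=1e-6):
    # normalize each position over the feature dimension (d_model)
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def sublayer_connection(x, sublayer):
    # residual connection around the sublayer, followed by layer normalization
    return layer_norm(x + sublayer(x))

# one encoder unit, given functions mha(q, k, v) and ffn(x):
# h = sublayer_connection(x, lambda t: mha(t, t, t))
# out = sublayer_connection(h, ffn)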

3. Multi-head Attention:

The inputs to the multi-head attention (MHA) layer are the query (Q), key (K) and value (V) matrices. Generally these are all the same as the input matrix (the embedding matrix for the first encoder unit), except in the second MHA layer of the decoder stack, where Q comes from the previous layer and K, V come from the encoder stack (discussed later).

a) Linear:

Q --shape--> (N, Tq, d_model)
K --shape--> (N, Tk, d_model)
V --shape--> (N, Tk, d_model)

Dense layers:
Q * Wq + bq = Q' --shape--> (N, Tq, d_model)
K * Wk + bk = K' --shape--> (N, Tk, d_model)
V * Wv + bv = V' --shape--> (N, Tk, d_model)

First, the Q, K and V matrices go through dense layers. If the last dimensions of these matrices differ, the dense layer weights are shaped so that Q`, K` and V` all end up with the same last dimension, V.shape[-1] (d_model here).

b) Split into batches:

Instead of performing attention over the full d_model dimension at once, Q`, K`, V` are split into h heads (folded into the batch dimension) and attention is applied to each head in parallel; a head-splitting sketch follows the shapes below.

Q` --shape-->(N, Tq, d_model)--reshape-->(N*h,Tq, d_model/h) --> Q``
K` --shape-->(N, Tk, d_model)--reshape-->(N*h,Tk, d_model/h) --> K``
V` --shape-->(N, Tk, d_model)--reshape-->(N*h,Tk, d_model/h) --> V``
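A NumPy sketch of this split (my own illustration; note that a transpose is inserted between the reshapes so that each head receives a contiguous d_model/h slice of the features rather than an interleaved one):

import numpy as np

def split_heads(x, h):
    # x: (N, T, d_model) --> (N*h, T, d_model/h)
    N, T, d_model = x.shape
    d_k = d_model // h
    x = x.reshape(N, T, h, d_k)        # split the feature dimension into h heads
    x = x.transpose(0, 2, 1, 3)        # (N, h, T, d_k)
    return x.reshape(N * h, T, d_k)    # fold the heads into the batch dimension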

c) Scaled Dot-Product Attention:

The output of the self-attention layer is a weighted sum of the values, where the weight assigned to each value is determined by the dot product of the query with all the keys.

q_k_atten = softmax(matmul(Q``, transpose(K``)) / (d_k ^ 0.5))
q_k_atten --shape--> (N*h, Tq, Tk)
d_k = d_model/h
Attention(Q``,K``,V``) = matmul(q_k_atten, V``)
Attention(Q``,K``,V``) --shape--> (N*h, Tq, d_k)

For large values of d_k, the dot product of Q`` and K`` can take large values, which pushes the softmax into regions where its gradient becomes very small, so we can run into a vanishing gradient problem. To avoid this, the term is scaled by dividing it by d_k ^ 0.5. Optionally, a mask array of shape (N*h, 1, Tk) can be applied to matmul(Q``, transpose(K``)) to switch off attention to certain key positions (for example padding tokens).
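A minimal NumPy sketch of scaled dot-product attention with an optional mask (my own illustration; the convention here is mask == 1 for allowed positions and 0 for blocked ones, implemented by adding a large negative number before the softmax):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
    # q: (N*h, Tq, d_k), k and v: (N*h, Tk, d_k)
    d_k = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)   # (N*h, Tq, Tk)
    if mask is not None:
        scores = scores + (1.0 - mask) * -1e9          # blocked positions -> ~ -inf
    weights = softmax(scores, axis=-1)                 # attention weights
    return weights @ v                                 # (N*h, Tq, d_k)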

d) Concat batches and dense layer connection:

The output of the self-attention heads is reshaped back into (N, Tq, d_k*h = d_model). This reshaped tensor then goes through a dense layer to give the final MHA output.

Attention(Q``,K``,V``) --reshape--> (N, Tq, d_model) --> Ao

Dense layer:
Ao * Wo + bo = A --shape--> (N, Tq, d_model)
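Continuing the sketch above (again illustrative, with an assumed output weight matrix Wo and bias bo), the heads are merged back and passed through the output projection:

def combine_heads(x, h):
    # x: (N*h, Tq, d_k) --> (N, Tq, h*d_k = d_model)
    Nh, Tq, d_k = x.shape
    N = Nh // h
    x = x.reshape(N, h, Tq, d_k).transpose(0, 2, 1, 3)   # undo the head fold
    return x.reshape(N, Tq, h * d_k)

# A = combine_heads(attention_out, h) @ Wo + bo          # (N, Tq, d_model)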

4. Feed-Forward Layer:

FFN(x) = max(0, xW1 + b1) * W2 + b2
W1 --shape--> (d_model, dff)
W2 --shape--> (dff, d_model)
here x --> output from MHA
and dff --> dim of intermediate dense layer
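A NumPy sketch of this position-wise feed-forward layer (weights are random placeholders; d_model = 512 and dff = 2048 are the sizes used in the paper):

import numpy as np

d_model, dff = 512, 2048
W1, b1 = np.random.randn(d_model, dff) * 0.01, np.zeros(dff)
W2, b2 = np.random.randn(dff, d_model) * 0.01, np.zeros(d_model)

def ffn(x):
    # x: (N, T, d_model) --> (N, T, d_model); same weights applied at every position
    hidden = np.maximum(0, x @ W1 + b1)    # ReLU
    return hidden @ W2 + b2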

5. Decoder:

Unlike the encoder unit, the decoder unit has 3 sublayers: Masked MHA, MHA and FFN. The Masked MHA is the same as the usual MHA, except that the matmul of Q and K is multiplied by a mask matrix of shape (Tq, Tk) before it is softmax-ed and multiplied by the value matrix. This mask matrix ensures that there are no illegal connections, meaning each position in the decoder can attend only to positions in the decoder up to and including that position.

For example:

Consider an English sentence in the decoder part, say “This is a good book”. The attention matrix of this sentence is multiplied by a lower-triangular mask matrix. While attending to the word “This”, the decoder should be unaware of the other words, so they are masked; similarly, while attending to the word “good”, the decoder is aware of the words “This”, “is”, “a”, “good” but unaware of the word “book”, which is therefore masked. This enables appropriate weighting of the value matrix, since each word depends only on the words that occurred before it (and itself).
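A NumPy sketch of such a look-ahead mask for this 5-token example (1 = allowed, 0 = masked, matching the mask convention in the attention sketch above):

import numpy as np

tokens = ["This", "is", "a", "good", "book"]
T = len(tokens)

look_ahead_mask = np.tril(np.ones((T, T)))   # position i may attend to 0 .. i
# [[1. 0. 0. 0. 0.]
#  [1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]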

In the second MHA layer of the decoder, the query comes from the previous masked MHA layer, while the key and value come from the output of the last encoder layer. This is followed by the FF network. Each of the three sublayers is wrapped with a residual connection and LN. This decoder unit is repeated Nx times, each taking the output of the previous unit. Note that while training, the input and final output of the decoder are shifted by one position, as in the example below:

Input: <start> This is a good book. <end> <pad>
Output: This is a good book. <end> <pad> <pad>
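A tiny Python illustration of this one-position shift (tokens taken from the example above):

decoder_input = ["<start>", "This", "is", "a", "good", "book.", "<end>", "<pad>"]
decoder_output = decoder_input[1:] + ["<pad>"]
# ['This', 'is', 'a', 'good', 'book.', '<end>', '<pad>', '<pad>']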

Final Architecture:

Training:

The source and target documents are tokenized and trimmed/padded to a fixed sequence length T before they are passed through the model. The loss function of the model is sparse_categorical_crossentropy, since the targets are integer tokens. The model is validated using the average BLEU score of the translated document against the actual target document.
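A NumPy sketch of sparse categorical cross-entropy over the decoder logits (the masking of <pad> positions is my own addition, a common practice not described above):

import numpy as np

def masked_sparse_categorical_crossentropy(logits, targets, pad_id=0):
    # logits: (N, T, vocab_size), targets: (N, T) integer token ids
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = (targets != pad_id).astype(nll.dtype)       # ignore <pad> positions
    return (nll * mask).sum() / mask.sum()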

Prediction: