Decoding NLP Attention Mechanisms

Source: Deep Learning on Medium

A sequence-to-sequence translation model follows a typical encoder-decoder architecture, where both the encoder and the decoder are generally variants of RNNs (such as LSTMs or GRUs). The encoder RNN reads the input sentence one token at a time. It helps to imagine an RNN as a succession of cells, one for each timestep. At each timestep t, the RNN cell produces a hidden state h(t), based on the input word X(t) at timestep t and the previous hidden state h(t-1). This hidden state is then fed to the next RNN cell.


Eventually, when the whole sentence has been processed, the last hidden state will hopefully capture the gist of all the information contained in the input sentence. This vector, called the context vector, then becomes the input to the decoder RNN, which produces the translated sentence one word at a time.
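The encoder loop described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: it uses a plain tanh cell rather than an LSTM or GRU, and the names `rnn_cell` and `encode`, the weight matrices, and the toy sizes are all made up for the example.

```python
import numpy as np

def rnn_cell(x_t, h_prev, W_x, W_h, b):
    """One vanilla RNN step: combine the current input with the previous hidden state."""
    return np.tanh(W_x @ x_t + W_h @ h_prev + b)

def encode(inputs, W_x, W_h, b):
    """Run the cell over the sequence; return every hidden state and the last one."""
    h = np.zeros(W_h.shape[0])            # initial hidden state, before any word is read
    hidden_states = []
    for x_t in inputs:                    # one word embedding per timestep
        h = rnn_cell(x_t, h, W_x, W_h, b)
        hidden_states.append(h)
    hidden_states = np.stack(hidden_states)
    return hidden_states, hidden_states[-1]

rng = np.random.default_rng(0)
embed_dim, hidden_dim, seq_len = 4, 3, 5           # toy sizes, chosen arbitrarily
W_x = rng.normal(size=(hidden_dim, embed_dim))     # random stand-ins for learned weights
W_h = rng.normal(size=(hidden_dim, hidden_dim))
b = np.zeros(hidden_dim)
sentence = rng.normal(size=(seq_len, embed_dim))   # stand-in word embeddings

hidden_states, context_vector = encode(sentence, W_x, W_h, b)
print(hidden_states.shape)    # (5, 3)
print(context_vector.shape)   # (3,)
```

Note that the context vector has the same size whether the sentence has 5 words or 50, since it is simply the last row of `hidden_states`.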

I can already see that you are a bit skeptical, so let’s get right down to the problem: is it reasonable to assume that a single context vector can retain ALL the needed information from the input sentence? What if the sentence is, say, 50 words long? Probably not. This limitation was aptly dubbed the bottleneck problem.

Enter Attention

So how can we avoid this bottleneck? Why not feed the decoder all the hidden state vectors, rather than only the last one? Remember that the encoder RNN produces one such vector for each input word. We can then concatenate these vectors, average them, or (even better!) weight them so as to give higher importance to the words of the input sentence that are most relevant for decoding the next word of the output sentence. This is what attention is all about.
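To make the three options concrete, here is a minimal NumPy sketch. The hidden state values and the weights are invented for illustration; nothing here says how the weights would be obtained in a real model.

```python
import numpy as np

# Toy encoder output: 4 input words, hidden size 3 (values are made up).
hidden_states = np.array([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
])

# Option 1: concatenate. The result grows with the sentence length.
concatenated = hidden_states.reshape(-1)

# Option 2: average. Fixed size, but every word counts equally.
averaged = hidden_states.mean(axis=0)

# Option 3: weighted sum. Fixed size, and relevant words count more.
weights = np.array([0.1, 0.6, 0.2, 0.1])   # hypothetical weights: non-negative, sum to 1
weighted = weights @ hidden_states

print(concatenated.shape)   # (12,)
print(averaged.shape)       # (3,)
print(weighted.shape)       # (3,)
```

Averaging is just the special case of the weighted sum where every word gets the same weight, which is why the weighted version is the more flexible choice.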

As is now tradition, this paradigm was in fact first applied to images before being replicated on text. The idea was to shift the model’s focus onto specific areas of the image (that is, specific pixels) to help it in its task.

An image captioning application: to generate the next word in the caption, the model shifts its attention to the relevant parts of the image.

The same idea applies to translating text. In order for the decoder to generate the next word, it will first weigh the input words (encoded by their hidden states) according to their relevance at the current phase of the decoding process.

To generate the word “took”, the decoder attends heavily to the equivalent French word “pris” as well as the word “a”, which sets the tense of the verb.

Attention Inner Workings

The rest of this article focuses on the inner workings of this mechanism. We now know that, to generate the next word of the output sentence, the decoder takes as input the previously generated word together with an attention-based weighted sum of all the input hidden state vectors. Now, the question is: how are these weights computed?
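Before turning to the weights themselves, the decoder input just described can be sketched as follows, treating the weights as given. Joining the two pieces by simple concatenation is an assumption made for illustration (models differ on how they combine them), and `decoder_step_input` is a hypothetical helper name.

```python
import numpy as np

def decoder_step_input(prev_word_embedding, encoder_states, weights):
    """Build the decoder input for one timestep from the previous word and the context."""
    context = weights @ encoder_states    # attention-weighted sum of the hidden states
    # Assumed combination: concatenate the previous word with the context.
    return np.concatenate([prev_word_embedding, context])

rng = np.random.default_rng(0)
encoder_states = rng.normal(size=(5, 3))   # 5 input words, hidden size 3 (toy values)
prev_word_embedding = rng.normal(size=4)   # embedding of the previously generated word
weights = np.full(5, 1 / 5)                # placeholder only: uniform weights

step_input = decoder_step_input(prev_word_embedding, encoder_states, weights)
print(step_input.shape)   # (7,)
```

The uniform weights are only a stand-in; with them, the context reduces to a plain average of the hidden states.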

Let’s imagine this situation: the decoder has already generated the words “The little bird,” and it is about to produce the next word at timestep 4.