Neural Machine Translation and the Need for an Attention Mechanism

Source: Deep Learning on Medium


Sequence-to-sequence (Seq2Seq) learning is all about models that take a sequence as input and produce a sequence as output. There are many examples and applications of this, but today I will focus on one specific application: machine translation, e.g. English to Hindi.

There are many approaches you can take to solve this particular task, but the state-of-the-art technique is the encoder-decoder approach (more on this later).

Technically, the encoder is an RNN unit that takes a sequence as input and encodes it into a state vector, which is then fed to a decoder, along with some input, to predict the desired output.

Intuitively, the encoder summarizes the English sequence into a single vector, which is given to the decoder to translate it into the equivalent Hindi output.

Point to be noted: generally, this RNN unit is an LSTM or a GRU.

So our final architecture looks like this.

Instead of describing everything here, let’s do it step by step:

2. Encoder Architecture:

I will go with a single LSTM unit to keep things simple, but in the real world stacked (deep) LSTM layers are used so that the model can learn richer representations.

Just for context, an LSTM unit takes three inputs and returns three outputs.

Here X, h0 and c0 are the inputs, and Y, ht and ct are the outputs. The figure above is the rolled version of an LSTM unit. Basically, it takes one word/character at a time and gets unrolled over time. The figure below will make this clear.

The first input is the sequence that we want to translate, and the other two inputs are vectors: the cell state and the hidden state.
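To make the three-inputs/three-outputs idea concrete, here is a toy, scalar-state LSTM step in plain Python. This is a minimal sketch with made-up shared weights, not a real implementation; it only shows the standard gate equations producing (y, h, c) from (x, h, c):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x, h_prev, c_prev, w):
    # One LSTM time step with scalar states, using the standard gate
    # equations: input (i), forget (f), output (o) gates and candidate g.
    i = sigmoid(w["wi"] * x + w["ui"] * h_prev)
    f = sigmoid(w["wf"] * x + w["uf"] * h_prev)
    o = sigmoid(w["wo"] * x + w["uo"] * h_prev)
    g = math.tanh(w["wg"] * x + w["ug"] * h_prev)
    c = f * c_prev + i * g      # new cell state
    h = o * math.tanh(c)        # new hidden state
    y = h                       # the per-step output (often just h itself)
    return y, h, c

# Hypothetical weights, all set to 0.5 purely for illustration:
w = {k: 0.5 for k in ["wi", "ui", "wf", "uf", "wo", "uo", "wg", "ug"]}
y1, h1, c1 = lstm_step(x=1.0, h_prev=0.0, c_prev=0.0, w=w)
```

In a real model x, h and c are vectors and the w's are weight matrices, but the flow of three inputs in and three outputs out is exactly the same.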

For example, let’s take this sequence: “This is a good phone”. This sequence contains 5 words, so our encoder LSTM will process one word per time step.

In the above diagram, x1 is the input word, ‘This’; h0 and c0 are the input state vectors, which are initialized randomly at the start. The unit outputs three vectors: y1 (the output vector) and h1, c1 (the state vectors).

Intuitively, h1 and c1 contain the information of the word ‘This’, which we fed in at time step t0. The LSTM at time step t1 will then take h1 and c1 as input along with the next word in the sequence, ‘is’. Vectors h3, c3 contain information up to the third word, ‘a’. And so on, until at the last time step, t5, we get h5 and c5, which contain information about the whole input sequence.

So now our input sequence “This is a good phone” gets converted into the vectors h5 and c5. We discard the output vectors (y1 to y5) because we don’t need them here. We only need the final state vectors, as they contain the information about the given input sequence.

We will now initialize our decoder with these final encoder state vectors, h5 and c5, rather than randomly as we did with the encoder LSTM. This also makes sense logically: we don’t want our decoder to start from scratch but to have a sense of what the input sequence is.
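The encoder loop described above can be sketched as follows. Here `toy_step` is a hypothetical stand-in for a real LSTM step (it just folds the word into the running states); the point is only the structure: the per-step outputs y are discarded, and only the final h and c are kept to initialize the decoder:

```python
def encode(words, lstm_step, h0=0.0, c0=0.0):
    # Run the encoder over the whole sequence; keep only the final states.
    h, c = h0, c0
    for x in words:
        y, h, c = lstm_step(x, h, c)   # y is discarded at every step
    return h, c                        # h5, c5 for a 5-word sequence

def toy_step(x, h, c):
    # Hypothetical stand-in for a trained LSTM step, for illustration only.
    h = 0.5 * h + 0.5 * x
    c = c + x
    return h, h, c                     # (y, h, c)

words = [1.0, 2.0, 3.0, 4.0, 5.0]      # stand-ins for "This is a good phone"
h5, c5 = encode(words, toy_step)
```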

3. Decoder Architecture:

The decoder LSTM has the same architecture as the encoder but with different inputs and outputs. There are two phases to consider: the training phase and the inference phase. We will touch on the inference phase later; first, let’s finish the training phase.

We have the vectors encoded by our LSTM encoder. The decoder’s h0 and c0 are not initialized randomly but with the h5 and c5 we got from the encoder.

Also, to make things work, we add a _START_ symbol at the start and an _END_ symbol at the end of the target sequence. The final sequence becomes ‘_START_ यह एक अच्छा फोन है _END_’.

Here X1 = _START_ and Y1 = यह, with h0 = h5 of the encoder and c0 = c5 of the encoder. This returns state vectors h1, c1, which are fed to the decoder at the next time step, along with the output Y1 supplied as ground truth. This continues until the model encounters the _END_ symbol. At the last time step, we ignore the decoder’s final state vectors (h6, c6) because they are of no use to us; we only need the output Y’s.
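The _START_/_END_ shifting described above can be sketched in a few lines. This is a minimal illustration of how the (input, ground-truth output) pairs line up at each decoder time step, not the full training pipeline:

```python
def teacher_forcing_pairs(target_tokens):
    # Decoder input: target shifted right, with _START_ prepended.
    # Decoder output: target with _END_ appended.
    # At each step the decoder receives the ground-truth previous word,
    # not its own prediction from the previous step.
    decoder_input = ["_START_"] + target_tokens
    decoder_output = target_tokens + ["_END_"]
    return list(zip(decoder_input, decoder_output))

pairs = teacher_forcing_pairs(["यह", "एक", "अच्छा", "फोन", "है"])
# The first pair is ("_START_", "यह"): X1 = _START_, Y1 = यह, as above.
```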

This technique is also called “Teacher Forcing”.

The entire training architecture (Encoder + Decoder) can be summarized in the below diagram:

Given the final architecture, we can predict outputs at each time step, compute the training loss from them, and then backpropagate the errors through time to update the parameters of the whole network.

4. The Inference Stage:

The task of applying a trained model to generate a translation is called inference or, more commonly in machine translation, decoding the sequence.

We have a trained model, so we can now generate predictions for a given input sequence. This step is known as inference. You can think of it as the testing phase, while the steps above were the training phase. At this point, we only have the learned weights and the input sequence to be decoded.

The model defined for training has learned the weights for this operation, but its structure is not designed to be called recursively to generate one word at a time. For this, we need to define new models for testing our trained model. There are many ways to perform decoding.

The inference stage uses two different models, an encoder and a decoder, which act as stand-alone models for their respective purposes.

The encoder model is simple: it takes the input layer from the encoder in the trained model and outputs the hidden- and cell-state tensors.

The decoder model is more involved. It requires three inputs: the hidden and cell states from the encoder, which serve as the initial state of the newly defined decoder model, and the translation generated so far. It is called once for every word, i.e. we have to run it inside a loop, and the loop ends once it emits the _END_ symbol.
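That loop can be sketched as follows. Here `fake_step` is a hypothetical stand-in for one call to the trained decoder model (it just replays a fixed translation so the control flow is visible); a real step would take the previous word plus the states and return the predicted word plus the updated states:

```python
def greedy_decode(decoder_step, h, c, max_len=20):
    # Greedy decoding: feed the previously generated word back in each
    # step, and stop as soon as the decoder emits _END_.
    word, out = "_START_", []
    for _ in range(max_len):
        word, h, c = decoder_step(word, h, c)
        if word == "_END_":
            break
        out.append(word)
    return out

# Hypothetical decoder that replays a fixed translation for the demo:
TARGET = ["यह", "एक", "अच्छा", "फोन", "है", "_END_"]

def fake_step(prev_word, h, c):
    idx = int(h)              # abuse h as a step counter, demo only
    return TARGET[idx], h + 1, c

translation = greedy_decode(fake_step, h=0.0, c=0.0)
```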

We’ll use this inference model to finally translate the sequence from one language to another.

5. The Problem:

“Attention” is one of the recent trends in the deep learning community. Ilya Sutskever, one of the people behind the seq2seq architecture for machine translation described above, mentioned that attention mechanisms are one of the most exciting advancements over the encoder-decoder approach, and that they are here to stay. But what is the problem with that approach, and what does attention solve?

To understand what attention can do for us, let’s go back to the same machine translation problem. We wanted to translate the sentence ‘This is a good phone’ to ‘यह एक अच्छा फोन है’. Remember, we used an LSTM encoder to map the English sentence into final state vectors. Let’s see that visually:


We can see that the vectors h5, c5 must encode everything we need to know about the source sentence; they must fully capture its meaning.

But there is a catch.

This captures the meaning of the whole sentence, but only when the sentence is not long. For example, the sentence we took is only 5 words long, and the plain encoder does it justice; but when the sentence is 50 or 100 words long, this single final vector will not be able to capture the whole sequence. Look at it this way: if the sentence is 100 words long, the first word of the source sentence is probably highly correlated with the first word of the target sentence. But that means the decoder has to consider information from 100 steps ago, and that information somehow needs to be encoded in the vector. RNNs have this old problem of long-term dependencies. In theory, architectures like the LSTM should be able to deal with this, but in practice long-term dependencies are still problematic. There are some hacks to make things better, but they are not principled solutions.

And this is where ‘Attention’ comes in.

6. What is Attention:

Attention is a small modification of the previous approach. We no longer try to encode the full source sentence into a single final fixed-length vector. Instead, we use all the intermediate (local) state vectors collectively to decide the next output while decoding the target sentence.

For example, in Fig-8 we will use all the h’s and c’s instead of only h5, c5. So now if our decoder wants to decode ‘This’ into ‘यह’, it can directly access the first state vectors h1, c1. This idea of giving more weight to the current word is why the mechanism is called attention. Intuitively, you can think of the decoder as attending to the first English word when producing the first Hindi word, and so on.

7. How does it work:

Let’s start translating our input English sentence from the first word, i.e. from ‘This’ to ‘यह’.

Right now we have all the encoded state vectors, (h1, c1) to (h5, c5). We feed these vectors to a simple fully connected feed-forward neural network to generate scores, one per state vector: s1 represents (h1, c1), s2 represents (h2, c2), and so on. These scores are then normalized by a softmax function into w1 to w5, which are called the attention weights. The higher the weight, the more attention the decoder gives to that specific word when translating. The model then calculates the context vector as follows:

context_vector = (w1 * h1 + w2 * h2 + w3 * h3 + w4 * h4 + w5 * h5)

We could also do this for the c’s, but for simplicity let’s ignore c and focus on h.

Here the w’s are the weights and the h’s are the local state vectors. If w1 is high, more attention is given to h1, which corresponds to the first word, ‘This’.

Finally, a concatenated vector [context_vector, previous_output] is given to the decoder to predict the next output. Once the decoder outputs the _END_ token, we stop the generation process. This runs in a loop: for every target word a fresh context vector is computed, so each target word is generated with its relevant source words in attention.
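Putting these steps together, here is a minimal sketch of computing the attention weights and the context vector for one decoder step. The scalar states h1..h5 and the scores are assumed toy values for illustration; in a real model the scores come from the learned feed-forward network described above:

```python
import math

def attention_context(scores, hs):
    # Normalize the alignment scores with softmax to get attention
    # weights, then take the weighted sum of the encoder states.
    m = max(scores)                          # subtract max for stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]      # w1..w5, sum to 1
    context = sum(w * h for w, h in zip(weights, hs))
    return weights, context

# Toy scalar states for the 5 source words, scores favoring "This":
hs = [0.9, 0.1, 0.2, 0.3, 0.4]
scores = [3.0, 0.5, 0.2, 0.1, 0.1]
weights, context = attention_context(scores, hs)
# w1 dominates, so the context vector is pulled toward h1.
```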

8. The Drawback of Attention:

As stated above, attention computes a context vector for every target word, which comes at a computational cost: the longer the sequence, the longer it takes to train. Also, human attention is something that is supposed to save computational resources; by focusing on one thing, we can neglect many others. But that’s not really what we’re doing in the model above. We’re essentially looking at everything in detail before deciding what to focus on. Intuitively, that’s equivalent to outputting a translated word and then going back through all of your internal memory of the text in order to decide which word to produce next. That seems like a waste, and not at all what humans do. In fact, it’s more akin to memory access than attention, which in my opinion makes the name somewhat of a misnomer. Still, that hasn’t stopped attention mechanisms from becoming quite popular and performing well on many tasks.