Attention Craving RNNs: A Journey Into Attention Mechanisms

Adding attention to your neural networks is a bit like wanting to take an afternoon nap at work. You know it’s better for you, everyone wants to do it, but everyone’s too scared to.

My goal today is to assume nothing, explain the details with animations, and make math great again (MMGA? ugh…)

Here we’ll cover

  1. Quick RNN review
  2. Quick sequence to sequence review
  3. Attention in RNNs
  4. Improvements to attention
  5. Transformer network introduction

Recurrent Neural Networks (RNN)

RNNs let us model sequences in neural networks. While there are other ways of modeling sequences, RNNs are particularly useful. RNNs most commonly come in two flavors, LSTMs and GRUs. For a deep tutorial, check out Chris Olah’s tutorial.

Let’s look at machine translation for a concrete example of an RNN.

  1. Imagine we have an RNN with 56 hidden units.

rnn_cell = rnn_cell(input_dim=10, output_dim=56)

2. We have a word “NYU” which is the 12th word in the vocab I created, so it’s represented by the integer 11 (indices start at 0).

# 'NYU' is the 12th word in my vocab
word = 'NYU'
word = VOCAB[word]
print(word)
# 11

Except we don’t feed an integer into the RNN; we use a higher-dimensional representation, which we currently obtain through embeddings.

embedding_layer = Embedding(vocab_size=120, embedding_dim=10)

# project our word (the integer 11) to 10 dimensions
x = embedding_layer(word)

An RNN cell takes in two inputs, a word x, and a hidden state from the previous time step h. At every timestep, it outputs a new h.

RNN CELL: next_h = f(x, prev_h)

*Tip: For the first step, h is normally just zeros.

# 1 word, RNN has 56 hidden units
h_0 = np.zeros((1, 56))
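
To make this concrete, here is a minimal sketch of one RNN cell step using PyTorch’s GRUCell (my choice of flavor; the snippets above are framework-agnostic pseudocode), assuming the 10-dimensional embeddings and 56 hidden units from this example.

import torch
import torch.nn as nn

# a GRU-flavored RNN cell: 10-dim embedded words in, 56-dim hidden state out
rnn_cell = nn.GRUCell(input_size=10, hidden_size=56)

x = torch.randn(1, 10)    # one embedded word (batch of 1)
h_0 = torch.zeros(1, 56)  # for the first step, h is just zeros
h_1 = rnn_cell(x, h_0)    # next_h = f(x, prev_h)
print(h_1.shape)          # torch.Size([1, 56])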

This is important: RNN cell is DIFFERENT from an RNN.

There’s a MAJOR point of confusion in RNN terminology. In deep learning frameworks like PyTorch and TensorFlow, the RNN CELL is the unit that performs this computation:

h1 = rnn_cell(x, h0)

while the RNN NETWORK for-loops the cell over the time steps:

def RNN(sentence):
    prev_h = h_0
    all_h = []
    for word in sentence:
        # use the RNN CELL at each time step
        current_h = rnn_cell(embed(word), prev_h)
        all_h.append(current_h)
        # the new h becomes the previous h for the next time step
        prev_h = current_h
    # RNNs output a hidden vector h at each time step
    return all_h

Here’s an illustration of an RNN moving the same RNN cell over time:

The RNN moves the RNN cell over time. For attention, we’ll use ALL the h’s produced at each timestep

Sequence To Sequence Models (Seq2Seq)

Now you’re a pro at RNNs, but let’s take it easy for a minute.

Chill

RNNs can be used as building blocks in larger deep learning systems.

One such system is a Seq2Seq model, which can be used to translate one sequence into another. You can frame a lot of problems as translation:

  1. Translate English to Spanish.
  2. Translate a video sequence into another sequence.
  3. Translate a sequence of instructions into programming code.
  4. Translate user behavior into future user behavior
  5. The only limit is your creativity!

Sequence to sequence models were one of the first methods introduced to help us do translation. A seq2seq model is nothing more than 2 RNNs, an encoder (E) and a decoder (D).

class Seq2Seq(object):
    def __init__(self):
        self.encoder = RNN(...)
        self.decoder = RNN(...)

The seq2seq model has 2 major steps:

Step 1: Encode a sequence:

sentence = ["NYU", "NLP", "rocks", "!"]
all_h = seq2seq.encoder(sentence)  # using an instance: seq2seq = Seq2Seq()
# all_h now has 4 h's (one activation per word)
Encoding

Step 2: Decode the “translation”

This part gets really involved. The encoder in the previous step simply processed the full input sequence in one pass (i.e., it was just the vanilla RNN loop from before).

In this second step, we run the decoder RNN one step at a time to generate predictions autoregressively (this is fancy for using the output of the previous step as the input to the next step).

There are two major ways of doing the decoding:

Option 1: Greedy Decoding

  1. Run 1 step of the decoder.
  2. Pick the highest probability output.
  3. Use this output as the input to the next step.

# you have to seed the first x since there are no predictions yet
# SOS means start of sentence
current_X_token = '<SOS>'

# we also use the last hidden output of the encoder (or set it to zeros)
h_option_1 = all_h[-1]
h_option_2 = zeros(...)

# let's use option 1, the last h produced by the encoder
dec_h = h_option_1

# run greedy search until the RNN generates an End-of-Sentence token
while current_X_token != '<EOS>':
    # keep the output h for the next step
    next_h = decoder(dec_h, current_X_token)

    # use the new h to find the most probable next word using a classifier
    next_token = argmax(softmax(fully_connected_layer(next_h)))

    # *KEY* prepare for the next pass by updating the pointers
    current_X_token = next_token
    dec_h = next_h

It’s called greedy because we always go with the highest probability next word.

Option 2: Beam Search

There’s a better technique called Beam Search, which considers multiple paths through the decoding process. Colloquially, a beam search of width 5 means we consider 5 possible sequences with the maximum log likelihood (math talk for 5 most probable sequences).

At a high-level, instead of taking the highest probability prediction, we keep the top k (beam size = k). Notice below that at each step we keep the 5 options with the highest probability.

Beam search figure found here

This YouTube video has a detailed beam search tutorial!
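
If you want a rough picture of what this looks like in code, here is a minimal sketch of beam search. It assumes a hypothetical decoder_step(h, token) function that returns log probabilities over the vocab plus the next hidden state, and hypothetical SOS/EOS token ids; none of these names come from a specific library.

import numpy as np

SOS, EOS = 0, 1   # hypothetical start / end-of-sentence token ids

def beam_search(decoder_step, h_start, beam_size=5, max_len=20):
    # decoder_step(h, token) -> (log_probs over the vocab, next_h)  [assumed]
    # each beam entry: (total_log_likelihood, token_sequence, hidden_state)
    beams = [(0.0, [SOS], h_start)]
    for _ in range(max_len):
        candidates = []
        for log_lik, seq, h in beams:
            if seq[-1] == EOS:  # finished sequences carry over unchanged
                candidates.append((log_lik, seq, h))
                continue
            log_probs, next_h = decoder_step(h, seq[-1])
            # expand this beam with its beam_size most probable next tokens
            for token in np.argsort(log_probs)[-beam_size:]:
                candidates.append((log_lik + log_probs[token], seq + [token], next_h))
        # keep only the beam_size most probable sequences overall
        beams = sorted(candidates, key=lambda b: b[0], reverse=True)[:beam_size]
    return beams[0][1]   # token ids of the most likely sequence

The key difference from greedy decoding is the sort-and-truncate at the end of each step: every beam proposes its best extensions, and only the beam_size most likely sequences overall survive.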

So, the full seq2seq process with greedy decoding as an animation to translate “NYU NLP is awesome” into Spanish looks like this:

Seq2Seq is made up of 2 RNNs, an encoder and a decoder

This model has various parts:

  1. Blue RNN is the encoder.
  2. Red RNN is the decoder
  3. The blue rectangle on top of the decoder is a fully connected layer with a softmax. This picks the most likely next word.

Attention Mechanism

Ok, now that we’ve covered all the prereqs, let’s get to the good stuff.

If you noticed in the previous animation, the decoder only looked at the last hidden vector generated by the encoder.

Turns out, it’s hard to save everything that happened over the sequence in this single vector. For example, the word “NYU” could have been forgotten by the encoder by the time it finishes processing the input sequence.

Attention solves this problem.

When you give a model an attention mechanism you allow it to look at ALL the h’s produced by the encoder at EACH decoding step.

To do this, we use a separate network, usually a single fully connected layer, which calculates how much of each h the decoder wants to look at.

So imagine that for all the h’s we generated, we’re actually only going to take a bit of each.

The scalars 0.3, 0.2, 0.4, 0.1 are called attention weights; they are generated by a small neural network like this:

# attention is just a fully connected layer and a final projection
attn_dim = 20
attention_mechanism = nn.Linear(h_size + x_size, attn_dim)
final_proj_V = weight_matrix(attn_dim)

# encode the full input sentence to get the hs we want to attend to
all_h = encoder(["NYU", "NLP", "is", "awesome"])

# greedy decoding, 1 step at a time, until the end-of-sentence token
current_token = '<SOS>'
while current_token != '<EOS>':
    # attend to the hs first
    attn_energies = []
    for h in all_h:
        attn_score = attention_mechanism([h, current_token])
        attn_score = tanh(attn_score)
        attn_score = final_proj_V.dot(attn_score)

        # attn_score is now a scalar (called an attn energy)
        attn_energies.append(attn_score)

    # turn the attention energies into weights by normalizing
    attn_weights = softmax(attn_energies)
    # attn_weights = [0.3, 0.2, 0.4, 0.1]

Now that we have the weights, we use them to pull out the h’s which might be relevant for the particular token being decoded:

context_vector = attn_weights.dot(all_h)
# this is now a vector which mixes a bit of all the h's

Let’s break it down into steps:

  1. We encoded the full input sequence and generated a list of h’s.
  2. We started decoding with the decoder using greedy search.
  3. Instead of giving the decoder h4, we gave it a context vector.
  4. To generate the context vector, we used another network and learnable weights V to score how relevant each h was to the current token being decoded.
  5. We normalized those attention energies and used them to mix all the h’s into one h which hopefully captures the relevant parts of all the h’s, i.e., a context vector.
  6. Now we perform the decoding step again, but this time, using the context vector instead of h4.
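
Step 6 is the one part the earlier snippets didn’t show, so here is a minimal sketch of that decoding step in the same pseudocode style. Exactly how the context vector gets combined with the token (here, simple concatenation) varies between papers and is an assumption on my part.

# combine the context vector with the embedded current token
decoder_input = concat([context_vector, embed(current_token)])

# one decoder step, fed the mixed context instead of just h4
next_h = decoder(decoder_input, dec_h)

# classify the most likely next word, exactly as in greedy decoding
next_token = argmax(softmax(fully_connected_layer(next_h)))

# update the pointers for the next pass
current_token = next_token
dec_h = next_h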

Attention Can Get Complicated

Types of Attention

This type of attention used only the h’s generated by the encoder. There’s a ton of research on improving on that process. For example:

  1. Use only some of the h’s, maybe the h’s around the time step you’re currently decoding (local attention).
  2. In addition to the encoder h’s, also use the h’s generated by the decoder, which we were throwing away before.

How to calculate attention energies

Another research area deals with how to calculate the attention scores. Instead of a dot product with V, researchers have also tried the following (see the sketch after this list):

  1. Scaled dot products.
  2. Cosine(s, h)
  3. Not using the V matrix, and applying the softmax directly to the output of the fully connected layer.
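
To make those variants concrete, here is a minimal numpy sketch of the three scoring functions, assuming a decoder state s and a single encoder vector h of the same dimension (the function names are mine):

import numpy as np

def dot_score(s, h):
    # plain dot product between the decoder state and an encoder h
    return s.dot(h)

def scaled_dot_score(s, h):
    # scale by the square root of the dimension to keep scores well-behaved
    return s.dot(h) / np.sqrt(len(h))

def cosine_score(s, h):
    # cosine similarity: ignores magnitudes, only compares directions
    return s.dot(h) / (np.linalg.norm(s) * np.linalg.norm(h))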

What to use when calculating attention energies

This final area of research looked at what exactly should go into the comparison with the h vectors.

To build some intuition about what I mean, think about calculating attention like a key-value dictionary. The key is what you give the attention network to “look up” the most relevant context. The value is the most relevant context.

The method I described here only uses the current token and each h to compute the attention score. That is:

# calculate how relevant each h is
score_1 = attn_network([embed("<SOS>"), h1])
score_2 = attn_network([embed("<SOS>"), h2])
score_3 = attn_network([embed("<SOS>"), h3])
score_4 = attn_network([embed("<SOS>"), h4])

But really we could give it anything we might think is useful to help the attention network make the best decision. Maybe we give it the last context vector also!

score_1 = attn_network([embed("<SOS>"), h1, last_context])

or maybe we give it something different, say a token to let it know it’s decoding Spanish:

score_1 = attn_network([embed("<SOS>"), h1, last_context, embed('<SPA>')])

The possibilities are endless!


Implementation Details

Here are some tips to think about if you decide to implement your own.

  1. Use Facebook’s implementation which is already really optimized.

Ok, fine that was a cop-out. Here are actual tips.

  1. Remember the seq2seq model has two parts, an encoder RNN and a decoder RNN. These two are separate.
  2. The bulk of the work goes into building the decoder. The encoder simply runs over the full input sequence.
  3. Remember the decoder RNN operates one step at a time. This is key!
  4. Remember the decoder RNN operates one step at a time. Worth saying twice ;)
  5. You have two options for decoding algorithms, greedy or beam search. Greedy is easier to implement, but beam search will give you better results most of the time.
  6. Attention is optional! BUT… the impact is huge when you have it…
  7. Attention is a separate network… Think about the network as the dictionary, where the key is a collection of things you want the network to use in deciding how relevant each particular h is.
  8. Remember you are calculating attention for each h. That means you have a for loop for [h1, …, hn].
  9. The attention network embedding dim can be made arbitrarily high. This WILL blow up your RAM. Make sure to put it on a separate GPU or keep the dim small.
  10. A trick to get large models going is to put the encoder on 1 gpu, decoder on a second gpu and the attention network on a third gpu. This way you keep the memory footprint low.
  11. If you actually deploy this model, you’ll need to implement it batched. Everything I explained here was for batch size = 1, but you can scale to bigger batches by switching to tensor products and being smart about your linear algebra (see the sketch right after this list).
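
For that last tip, here is a minimal sketch of the batched attention-mixing step using torch.bmm; the shapes are assumptions on my part, matching the toy example above (4 encoder steps, 56 hidden units).

import torch

batch, seq_len, hidden = 32, 4, 56            # assumed sizes for illustration
all_h = torch.randn(batch, seq_len, hidden)   # encoder h's for the whole batch
attn_energies = torch.randn(batch, seq_len)   # one energy per encoder h

# normalize the energies into weights, then mix the h's in one batched matmul
attn_weights = torch.softmax(attn_energies, dim=1)     # (batch, seq_len)
context = torch.bmm(attn_weights.unsqueeze(1), all_h)  # (batch, 1, hidden)
context = context.squeeze(1)                           # (batch, hidden)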

Again, most of the time, you should just use an open source implementation, but it’s a great learning experience to do your own!


Life After Attention

Turns out… the attention network by itself was shown to be really powerful.

So much so that researchers decided to get rid of the RNNs entirely. They instead created something called a Transformer model.

At a high-level, a transformer still has an encoder and decoder except the layers are fully connected and look at the full input at once. Then as the input moves through the network, attention heads are applied to focus on what’s important.

Transformer illustration from here.
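
To give a rough flavor of what a single attention head computes, here is a minimal numpy sketch of scaled dot-product self-attention. It is a simplification (no learned projections, no multiple heads, no masking), not the full Transformer.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(Q, K, V):
    # every position scores every other position, all at once
    scores = Q @ K.T / np.sqrt(Q.shape[-1])   # (seq_len, seq_len)
    weights = softmax(scores, axis=-1)        # each row is a set of attn weights
    return weights @ V                        # mix the values with the weights

seq_len, dim = 4, 10
X = np.random.randn(seq_len, dim)   # embedded "NYU NLP is awesome"
out = self_attention(X, X, X)       # same shape as X: (4, 10)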

This model has largely replaced RNN-based seq2seq models in translation and is behind the most powerful current models, BERT and OpenAI’s GPT.