Original article was published by Harshith Nadendla on Deep Learning on Medium
Why are LSTMs struggling to match up with Transformers?
This article throws light on the performance of Long Short-Term Memory (LSTM) and Transformer networks. We’ll start by reviewing the basics of LSTMs and Transformers, and then move on to the internal mechanisms by which they work. Let’s understand what’s happening under the hood and how Transformers are able to perform exceptionally well compared to LSTMs.
What is RNN and how does it work?
Before learning about LSTM networks, let’s review how Recurrent Neural Networks (RNNs) work. An RNN is a type of neural network where the output from the previous time step is fed as input to the current time step. Unlike feedforward neural networks, RNNs can use their internal state (memory) to process sequences of inputs. Therefore, RNNs are good at modelling sequence data.
In traditional neural networks, all inputs and outputs are independent of each other. But some tasks require sequential information: to predict the next word of a sentence, the previous words are needed, so the network must remember them. RNNs came into existence to solve this, and they do it with the help of a loop structure. The main and most important feature of an RNN is the “hidden” state, which remembers information about the sequence. This makes RNNs applicable to tasks such as unsegmented, connected handwriting recognition or speech recognition.
This loop structure allows the neural network to take the sequence of input. If you see the unrolled version, you will understand it better.
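To make the loop structure concrete, here is a minimal sketch of an unrolled vanilla RNN in NumPy. The sizes and initialization are made up for illustration; the point is that one shared set of weights is applied at every time step, and the hidden state h carries information forward through the sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size, seq_len = 4, 8, 5

# One shared set of parameters reused at every time step (the "loop").
W_xh = rng.normal(scale=0.1, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
b_h = np.zeros(hidden_size)

def rnn_forward(xs):
    """Unrolled RNN: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b)."""
    h = np.zeros(hidden_size)
    hs = []
    for x in xs:                          # one step per sequence element
        h = np.tanh(W_xh @ x + W_hh @ h + b_h)
        hs.append(h)
    return np.stack(hs)

xs = rng.normal(size=(seq_len, input_size))
hs = rnn_forward(xs)
print(hs.shape)  # (5, 8): one hidden state per time step
```

Notice that each h depends on the entire prefix of the sequence through the recurrence, which is exactly the "memory" described above.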
The main problem with vanilla RNNs is that they can’t handle long-term dependencies: if information from an early time step is required to produce output later in the sequence, that memory is lost because of vanishing or exploding gradients. We’d like RNNs to be able to store information over many time steps and retrieve it when it becomes relevant, but vanilla RNNs often struggle to do this. Please check this out for more explanation.
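The vanishing-gradient problem can be seen with a toy calculation. Backpropagating through a vanilla RNN multiplies the gradient by the recurrent Jacobian at every step; the sketch below (NumPy, made-up numbers, and ignoring the tanh derivative, which only shrinks things further) shows the gradient norm collapsing exponentially as it travels back in time.

```python
import numpy as np

rng = np.random.default_rng(1)
hidden_size = 8

# Small recurrent weights (spectral norm below 1) -> gradients vanish.
W_hh = 0.1 * rng.normal(size=(hidden_size, hidden_size))

grad = np.ones(hidden_size)          # gradient arriving at the last step
norms = []
for t in range(50):                  # propagate 50 steps back in time
    grad = W_hh.T @ grad             # repeated multiplication by the Jacobian
    norms.append(np.linalg.norm(grad))

print(norms[0], norms[-1])           # the norm shrinks toward zero
```

With weights whose spectral norm exceeds 1, the same loop would blow up instead: the exploding-gradient case.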
How do Long Short Term Memory (LSTM) Networks work?
LSTM networks are an advanced version of recurrent neural networks that make it easier to retain past information in memory. LSTMs are a special type of RNN in which the units are connected in a specific way to avoid the vanishing and exploding gradient problems of regular RNNs, which in turn resolves their long-term dependency problem. LSTMs are well-suited to classifying, processing and predicting time series with time lags of unknown duration. The model is trained using backpropagation through time.
Here, ‘A’ is the unit that takes an input ‘X’ and outputs a vector ‘h’. Some vectors are also fed back into A from the previous unit (shown by the arrows from one A to the next). Along with the input and the cell state from the previous unit (the horizontal line running through the top of the cell), there is another vector coming from the previous unit: its output. So, in total, each unit has three inputs and two outputs. Note that all the A’s here refer to a single object in memory; they are drawn in this unrolled way for explanation. What’s really happening is that the outputs coming out of the right of A are fed back in from the left to the same A module.
There are multiple ways in which you can use LSTMs for sequence-to-sequence tasks. As the above diagram shows, you can feed the input sequence into an LSTM and get a sequence of h vectors. You can then process these h vectors further to do classification or regression, as the task requires.
The main idea of LSTMs is the cell state, the horizontal line running through the top of the cell.
The cell state is something like a conveyor belt. It runs straight through the entire chain, with only a few minor linear interactions, so it’s very easy for information to flow along it unchanged. This addresses our long-term dependency problem. For a more detailed explanation of LSTMs, please go through Colah’s blog.
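A single LSTM cell step can be sketched in a few lines of NumPy (sizes and initialization are hypothetical). Note how the cell state c is updated only by element-wise gating and addition, the "conveyor belt" that lets information survive across many steps, while the three gates decide what to forget, what to write, and what to expose as output.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(2)
input_size, hidden_size = 4, 8

# One weight matrix per gate, each acting on [h_{t-1}; x_t].
W_f, W_i, W_o, W_c = [
    rng.normal(scale=0.1, size=(hidden_size, hidden_size + input_size))
    for _ in range(4)
]

def lstm_step(x, h_prev, c_prev):
    z = np.concatenate([h_prev, x])
    f = sigmoid(W_f @ z)          # forget gate: what to erase from c
    i = sigmoid(W_i @ z)          # input gate: what to write to c
    o = sigmoid(W_o @ z)          # output gate: what to expose as h
    c_tilde = np.tanh(W_c @ z)    # candidate values
    c = f * c_prev + i * c_tilde  # cell state: a gated, additive update
    h = o * np.tanh(c)            # hidden state, the unit's output
    return h, c

x = rng.normal(size=input_size)
h, c = lstm_step(x, np.zeros(hidden_size), np.zeros(hidden_size))
print(h.shape, c.shape)  # (8,) (8,)
```

The additive update of c (rather than repeated matrix multiplication) is precisely what mitigates the vanishing-gradient problem described earlier.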
Now here’s the biggest problem. In experiments, LSTM networks and LSTM-based encoder-decoder models do not perform well on long sentences, because the encoder compresses the whole sentence into a single latent vector, and the last LSTM unit may not capture the full essence of the sentence. Since all the words of a long sentence are squeezed into one vector, if an output word depends on a specific input word, a plain LSTM-based encoder-decoder model cannot give that word proper attention (the way we humans naturally do). Attention-based models came into existence to overcome these issues.
What is a Transformer?
Transformer models are essentially attention-based models. What if, instead of relying on just the context vector, the decoder had access to all the past states of the encoder? That’s exactly what attention does. At each decoding step, the decoder gets to look at every state of the encoder.
In this case, a bidirectional encoder is used: the input sequence is processed both forwards and backwards, and the resulting states are concatenated before being passed on to the decoder. Instead of encoding the input sequence into a single fixed context vector, the attention model develops a context vector (Ci) that is filtered specifically for each output time step. The context vector of each output time step changes based on how strongly it depends on each input time step, which is taken care of by the attention weights (alpha).
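The per-step context vector can be sketched as follows (NumPy, hypothetical sizes, and dot-product scoring chosen for simplicity): at each decoder step, every encoder state is scored against the current decoder state, the scores are turned into weights alpha with a softmax, and the context vector Ci is the alpha-weighted sum of the encoder states.

```python
import numpy as np

rng = np.random.default_rng(3)
src_len, hidden_size = 6, 8

encoder_states = rng.normal(size=(src_len, hidden_size))  # one per input step
decoder_state = rng.normal(size=hidden_size)              # current decoder state

def attention_context(dec_h, enc_hs):
    scores = enc_hs @ dec_h                        # alignment score per input step
    scores -= scores.max()                         # for numerical stability
    alpha = np.exp(scores) / np.exp(scores).sum()  # attention weights, sum to 1
    return alpha @ enc_hs, alpha                   # context vector Ci and alpha

context, alpha = attention_context(decoder_state, encoder_states)
print(alpha.sum())    # 1.0: a distribution over input steps
print(context.shape)  # (8,)
```

Because alpha is recomputed at every decoding step, each output word gets its own view of the input sequence, instead of sharing one fixed vector.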
From the figure below, we can see that the BLEU score (one measure of translation quality) of the baseline model (without attention) drops dramatically as the length of the sentences increases. However, with the incorporation of an attention mechanism, this problem is alleviated. The performance of the self-attention model is slightly lower than that of the RNN encoder-decoder with attention, and its ability to deal with long sentences is also slightly lower. These results show that attention-based models have a significant advantage over traditional encoder-decoder models and perform well even on long sentences.
The Transformer network is built entirely upon self-attention mechanisms, without using a recurrent architecture. The Transformer is made of multi-head self-attention blocks, inspired by the neighbourhood concept in CNNs. For a more detailed explanation, read this blog.
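A minimal single-head version of the Transformer's scaled dot-product self-attention can be sketched like this (NumPy, made-up sizes; a real multi-head layer runs several of these in parallel and concatenates the results). Every position attends to every other position in one matrix multiply, so there is no recurrence across time steps at all.

```python
import numpy as np

rng = np.random.default_rng(4)
seq_len, d_model = 5, 8

X = rng.normal(size=(seq_len, d_model))  # input token representations
W_q, W_k, W_v = [rng.normal(scale=0.1, size=(d_model, d_model)) for _ in range(3)]

def self_attention(X):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(d_model)           # (seq_len, seq_len) scores
    scores -= scores.max(axis=-1, keepdims=True)  # for numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=-1, keepdims=True)            # softmax rows: attention weights
    return A @ V                                  # each output mixes all positions

out = self_attention(X)
print(out.shape)  # (5, 8)
```

Because the whole sequence is processed in a few matrix multiplications rather than a step-by-step loop, this computation parallelizes across positions, which is a key reason Transformers train faster than LSTMs.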
As discussed, Transformers are faster to train than RNN-based models because the whole input is ingested at once, so computation can be parallelized across the sequence; an LSTM must process time steps one after another. Transfer learning has also proven far more practical with Transformers, whereas pre-trained LSTM models have been much harder to reuse across tasks. Transformers are now the state-of-the-art architecture for seq2seq models. Hence, we conclude that Transformer networks give the best accuracy while being much easier to train at scale.