Source: Deep Learning on Medium
Understanding RNNs, LSTMs and GRUs
A recurrent neural network (RNN) is a variation of a basic neural network. RNNs are good at processing sequential data such as text and audio, which makes them useful for natural language processing and speech recognition. However, basic RNNs suffer from short-term memory: they struggle to carry information from early in a sequence to later steps. In this post I will try to explain (1) what an RNN is, (2) the vanishing gradient problem, and (3) the solutions to this problem known as long short-term memory (LSTM) and gated recurrent units (GRU).
What is an RNN?
First, let's cover the basic neural network architecture. Neural networks are trained in three basic steps:
(1) A forward pass that makes a prediction.
(2) A comparison of the prediction to the ground truth using a loss function. The loss function outputs an error value.
(3) Using that error value, perform backpropagation, which calculates the gradient of the loss with respect to each weight in the network. The weights are then adjusted in the direction that reduces the error.
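The three training steps can be sketched in a few lines of plain Python. This is a minimal illustration, not how real frameworks work. It assumes a toy model with a single weight (`y_pred = w * x`), a squared-error loss, and plain gradient descent. The function names and the learning rate are hypothetical choices made for this example.

```python
# A minimal sketch of the three training steps, assuming a toy model with a
# single weight (y_pred = w * x) and a squared-error loss, both chosen for
# illustration only.

def forward(w, x):
    """Step (1): the forward pass makes a prediction."""
    return w * x

def loss_fn(y_pred, y_true):
    """Step (2): compare the prediction to the ground truth."""
    return (y_pred - y_true) ** 2

def backward(y_pred, x, y_true):
    """Step (3): gradient of the loss w.r.t. w via the chain rule:
    d(loss)/dw = d(loss)/d(y_pred) * d(y_pred)/dw = 2 * (y_pred - y_true) * x
    """
    return 2.0 * (y_pred - y_true) * x

def train(w, data, learning_rate=0.05, epochs=100):
    loss = None
    for _ in range(epochs):
        for x, y_true in data:
            y_pred = forward(w, x)              # (1) predict
            loss = loss_fn(y_pred, y_true)      # (2) error value
            grad = backward(y_pred, x, y_true)  # (3) gradient
            w -= learning_rate * grad           # adjust the weight to reduce the error
    return w, loss

# The ground truth follows y = 3x, so training should drive w toward 3.
data = [(1.0, 3.0), (2.0, 6.0), (-1.0, -3.0)]
w, final_loss = train(0.0, data)
print(w, final_loss)
```

A real network has many weights, and backpropagation applies the same chain-rule idea to every one of them at once.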
In contrast, an RNN also maintains a hidden state that feeds information from previous time steps into the current one.
The hidden state lets the network integrate sequential data in order to make a more accurate prediction. Consider trying to predict the motion of a ball when all you have is a single still shot of it.
With no sequence information, it is impossible to predict where the ball is moving. If you also know its previous locations, your prediction will be much more accurate. The same logic applies to estimating the next word in a sentence or the next piece of audio in a song. The hidden state carries this information: it is a representation of previous inputs.
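To make the hidden state concrete, here is a minimal sketch of a single-unit (scalar) RNN in plain Python. The weights `W_X`, `W_H` and `B` are made-up illustrative values. A real RNN uses weight matrices and vector-valued states, but the recurrence has the same shape: each new hidden state is computed from the current input and the previous hidden state.

```python
import math

# Hypothetical scalar weights; a real RNN uses weight matrices and vector states.
W_X, W_H, B = 0.5, 0.8, 0.0

def rnn_step(x_t, h_prev):
    """One vanilla RNN step: mix the current input with the previous
    hidden state, then squash the result with tanh."""
    return math.tanh(W_X * x_t + W_H * h_prev + B)

def run_rnn(inputs, h0=0.0):
    """Feed a sequence through the RNN and return every hidden state."""
    h = h0
    states = []
    for x_t in inputs:
        h = rnn_step(x_t, h)
        states.append(h)
    return states

with_history = run_rnn([1.0, 0.0, 0.0])
without_history = run_rnn([0.0, 0.0, 0.0])
print(with_history)     # final state still reflects the first input
print(without_history)  # all zeros
```

The last two inputs are identical in both sequences, yet the final hidden states differ, because the first input is still carried in the state. The printed values also show that influence fading from step to step.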
Vanishing Gradient Problem
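Training is where this becomes problematic. To train an RNN, you use an application of backpropagation called backpropagation through time (BPTT). The gradients are computed via the chain rule, so the gradient flowing back to an earlier time step is a product of one factor per intervening step. When those factors are smaller than 1, the gradient shrinks exponentially as it propagates back through each time step, eventually “vanishing”. The weights then receive almost no learning signal from early inputs, which is why a basic RNN has short-term memory.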
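The shrinking can be observed numerically. The sketch below assumes a hypothetical scalar RNN, h_t = tanh(w_x * x_t + w_h * h_{t-1}), with illustrative weights. It applies the chain rule backwards through time. Each step contributes a factor w_h * (1 - h_t^2), which for these weights is always smaller than 1 in magnitude. As a result, the gradient reaching early time steps decays exponentially.

```python
import math

def bptt_gradient_norms(inputs, w_x=0.5, w_h=0.8):
    """Return |d h_T / d h_t| for every step t of a scalar RNN (hypothetical weights).

    The forward pass stores each hidden state. The backward pass multiplies the
    local derivative d h_t / d h_{t-1} = w_h * (1 - h_t**2) one step at a time.
    """
    # Forward pass: states[t] is the hidden state after reading inputs[t].
    h, states = 0.0, []
    for x_t in inputs:
        h = math.tanh(w_x * x_t + w_h * h)
        states.append(h)
    # Backward pass: chain rule through time, starting from d h_T / d h_T = 1.
    grad = 1.0
    grads = [grad]
    for t in range(len(states) - 1, 0, -1):
        grad *= w_h * (1.0 - states[t] ** 2)  # d states[t] / d states[t-1]
        grads.append(grad)
    grads.reverse()  # grads[t] = d h_T / d states[t]
    return [abs(g) for g in grads]

grads = bptt_gradient_norms([1.0] * 20)
print(grads[-1])  # 1.0: gradient of the last state w.r.t. itself
print(grads[0])   # contribution reaching the first time step
```

Only the hidden-state path is shown here. The gradient with respect to the weights at an early step is scaled by the same product, so it vanishes in the same way.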
To illustrate this phenomenon in an NLP application, consider an RNN reading a question one word at a time.
By output 5, the information from “What” and “time” has all but disappeared. How well do you think you could predict what comes after “is” and “it” without these?
LSTM and GRU as solutions
LSTMs and GRUs were created as a solution to the vanishing gradient problem. They have internal mechanisms called gates that can regulate the flow of information.
The LSTM has a main cell state, which acts like a conveyor belt carrying information across time steps, and several gates that control what information is removed from or added to that belt.
Consider a passage about Bob and Alice, and suppose we want to determine the gender of the speaker in a new sentence. We would have to selectively forget certain things about the previous states, namely who Bob is and whether he likes apples. We would also have to remember other things, namely that Alice is a woman and that she likes oranges.
Zooming in, the gates in an LSTM do this as a four-step process:
(1) Decide what to forget (state)
(2) Decide what to remember (state)
(3) The actual “forgetting” and update of the state
(4) Production of the output
To wrap up, in an LSTM the forget gate (1) decides what is relevant to keep from prior steps, and the input gate (2) decides what information is relevant to add from the current step. Step (3) applies both decisions to the cell state. The output gate (4) then determines what the next hidden state should be.
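The four steps map onto a single LSTM cell update. Below is a minimal scalar sketch in plain Python, assuming the widely used formulation with sigmoid gates, a tanh candidate, and h = o * tanh(c). The parameter dictionary and its `(w_x, w_h, b)` triples are hypothetical. A real LSTM uses weight matrices and vector states.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM step with scalar inputs and states.

    `p` maps each gate name to a hypothetical (w_x, w_h, b) triple.
    """
    def pre(gate):
        w_x, w_h, b = p[gate]
        return w_x * x_t + w_h * h_prev + b

    f = sigmoid(pre("forget"))        # (1) how much of the old cell state to keep
    i = sigmoid(pre("input"))         # (2) how much of the candidate to write
    c_tilde = math.tanh(pre("cell"))  # (2) candidate values from the current step
    c = f * c_prev + i * c_tilde      # (3) forget and update the cell state
    o = sigmoid(pre("output"))        # (4) how much of the cell state to expose
    h = o * math.tanh(c)              # (4) next hidden state
    return h, c

# Forget gate held open, input gate held shut: the "conveyor belt" keeps its value.
keep = {"forget": (0.0, 0.0, 10.0), "input": (0.0, 0.0, -10.0),
        "cell": (1.0, 0.0, 0.0), "output": (0.0, 0.0, 0.0)}
h, c = 0.0, 1.0
for _ in range(50):
    h, c = lstm_step(0.0, h, c, keep)
print(c)  # stays close to 1.0

# Closing the forget gate wipes the cell state in a single step.
forget_all = dict(keep, forget=(0.0, 0.0, -10.0))
h2, c2 = lstm_step(0.0, 0.0, 1.0, forget_all)
print(c2)  # close to 0.0
```

The cell state is updated mostly by multiplication with f and addition, not by repeated squashing. That is why information (and gradient) can survive many steps when the forget gate stays near 1.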
The GRU is a newer generation of recurrent unit and is quite similar to the LSTM, except that GRUs got rid of the cell state and use the hidden state to transfer information. A GRU also has only two gates, a reset gate and an update gate:
(1) The update gate acts similarly to the forget and input gates of an LSTM: it decides what information to throw away and what new information to add.
(2) The reset gate decides how much of the past information to forget.
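For comparison, here is a minimal scalar GRU sketch in plain Python. It follows one common formulation, and the parameter triples are hypothetical. Sources differ on whether the update gate z weights the old state or the new candidate. In this sketch, z near 1 means "overwrite with the candidate", while some libraries swap the roles of z and 1 - z. There is no separate cell state: the hidden state carries everything.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gru_step(x_t, h_prev, p):
    """One GRU step with scalar inputs and state (no separate cell state).

    `p` maps each gate name to a hypothetical (w_x, w_h, b) triple.
    """
    def pre(name, h_in):
        w_x, w_h, b = p[name]
        return w_x * x_t + w_h * h_in + b

    z = sigmoid(pre("update", h_prev))  # (1) update gate
    r = sigmoid(pre("reset", h_prev))   # (2) reset gate
    # The reset gate scales how much past information feeds the candidate.
    h_tilde = math.tanh(pre("candidate", r * h_prev))
    # Blend old state and candidate: z near 0 keeps h_prev, z near 1 overwrites it.
    return (1.0 - z) * h_prev + z * h_tilde

# Update gate held near 0: the hidden state is carried forward almost unchanged.
keep = {"update": (0.0, 0.0, -10.0), "reset": (0.0, 0.0, 0.0),
        "candidate": (1.0, 1.0, 0.0)}
h = 0.9
for _ in range(50):
    h = gru_step(0.0, h, keep)
print(h)  # still close to 0.9

# Update gate near 1 and reset gate near 0: the past is discarded.
overwrite = {"update": (0.0, 0.0, 10.0), "reset": (0.0, 0.0, -10.0),
             "candidate": (1.0, 1.0, 0.0)}
h_reset = gru_step(0.0, 0.9, overwrite)
print(h_reset)  # close to 0.0
```

The single interpolation in the last line of `gru_step` does the work that the LSTM's separate forget and input gates share, which is why the update gate is described as playing both roles.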
Illustrations were taken from:
1. MIT 6.S094: Deep Learning for Self-Driving Cars, taught in Winter 2017 by Lex Fridman
2. Illustrated Guide to LSTM’s and GRU’s: A step by step explanation by Michael Nguyen