Source: Deep Learning on Medium
Have you ever wondered what an LSTM layer learns? Ever wondered whether it is possible to see how each cell contributes to the final output? I was curious to try and visualise this. While satisfying my curious neurons I stumbled upon a blog post by Andrej Karpathy titled The Unreasonable Effectiveness of Recurrent Neural Networks. If you want a more in-depth explanation, I suggest going through his post.
In this article, we are not only going to build a Text Generation model in Keras, but also visualise what some of its cells are looking at while generating text. Just as a CNN learns general features of an image, such as horizontal and vertical edges, lines, and patches, an LSTM trained on text learns features like spaces, capital letters, and punctuation. We are going to see what feature each cell in the LSTM layer is learning.
We will be using the book Alice’s Adventures in Wonderland by Lewis Carroll as training data. The model architecture is a simple one: two blocks of LSTM and Dropout layers, with a Dense layer at the end.
You can download the training data and the trained model weights here.
This is how the activations of a single cell will look. I hope you can make out the pattern in the above image. If you can’t, you’ll find it at the end of the article.
Let’s dive into the code.
Step 1: Import required Libraries
Note: I’ve used CuDNNLSTM in place of LSTM as it trains 15x faster. CuDNNLSTM is backed by cuDNN and can only be run on a GPU.
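The original import cell isn’t shown here; a minimal sketch of the imports the later steps rely on might look like the following. The module paths assume TensorFlow 2’s bundled Keras; the original post used standalone Keras’s CuDNNLSTM, which TF2 no longer exposes because the plain LSTM layer picks the cuDNN kernel automatically when run on a GPU with default settings.

```python
import re
import numpy as np

# TF2's bundled Keras. The original post imported CuDNNLSTM from
# standalone Keras; in TF2 the plain LSTM layer uses the cuDNN
# kernel automatically on GPU (with default activations).
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, LSTM, Dropout, Dense
```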
Step 2: Read training data and Preprocess it
We will replace runs of multiple spaces with a single space using a regex. char_to_int and int_to_char are simply character-to-number and number-to-character mappings.
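A sketch of this preprocessing step, under the assumption that the raw book text has already been loaded into a string (the variable names and the sample string here are illustrative):

```python
import re

raw_text = "Alice was   beginning to get very tired..."  # stands in for the book text

# Lower-case the text and collapse runs of whitespace into a single space
text = re.sub(r"\s+", " ", raw_text.lower())

# Build the character vocabulary and the two lookup tables
chars = sorted(set(text))
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = {i: c for i, c in enumerate(chars)}
```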
Step 3: Prepare data for training
It is important to prepare our data such that every input is a sequence of characters and the output is the character that follows.
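A sketch of this sliding-window preparation, with the characters one-hot encoded (the window length and sample text are illustrative; the post’s actual values may differ):

```python
import numpy as np

text = "alice was beginning to get very tired of sitting by her sister"
chars = sorted(set(text))
char_to_int = {c: i for i, c in enumerate(chars)}
vocab = len(chars)

seq_length = 10  # illustrative; the post's window length may differ
X_ints, y_ints = [], []
for i in range(len(text) - seq_length):
    # Input: a window of characters; output: the character that follows it
    X_ints.append([char_to_int[c] for c in text[i:i + seq_length]])
    y_ints.append(char_to_int[text[i + seq_length]])

# One-hot encode: inputs (samples, seq_length, vocab), targets (samples, vocab)
X = np.zeros((len(X_ints), seq_length, vocab), dtype=np.float32)
for s, seq in enumerate(X_ints):
    X[s, np.arange(seq_length), seq] = 1.0
y = np.eye(vocab, dtype=np.float32)[y_ints]
```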
Step 4: Building Model Architecture
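The post describes two blocks of LSTM and Dropout layers followed by a Dense layer, and Step 6 implies the second LSTM has 512 units. A sketch of such an architecture might look like this (the first LSTM’s size, the dropout rate, and the input sizes are assumptions; swap in your data’s actual sequence length and vocabulary size):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, LSTM, Dropout, Dense

seq_length, vocab = 100, 45  # illustrative; use your data's actual sizes

model = Sequential([
    Input(shape=(seq_length, vocab)),
    LSTM(512, return_sequences=True),    # first LSTM block
    Dropout(0.2),
    LSTM(512),                           # second LSTM block
    Dropout(0.2),
    Dense(vocab, activation="softmax"),  # next-character distribution
])
model.compile(loss="categorical_crossentropy", optimizer="adam")
```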
Step 5: Train Model
I could not train my model for 300 epochs in one go, as I was using Google Colab. I had to train it over three days, 100 epochs per day, saving the weights at the end of each session and loading them again to resume from the point where I had left off.
If you have a powerful GPU with you, you can train the model for 300 epochs in one go. If you don’t, I would suggest using Colab as it is free.
You can load the model using the code below and resume training from the last checkpoint.
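A sketch of how the checkpoint-and-resume workflow might look, assuming the model from Step 4 is already built; the filenames here are hypothetical:

```python
from tensorflow.keras.callbacks import ModelCheckpoint

# Save weights after every epoch so a Colab session can be resumed later.
# The filename pattern is illustrative.
checkpoint = ModelCheckpoint("weights-{epoch:03d}.weights.h5",
                             save_weights_only=True)

# In a fresh session: rebuild the same architecture, reload the last
# saved weights, and continue training from that point, e.g.:
# model.load_weights("weights-100.weights.h5")
# model.fit(X, y, epochs=100, batch_size=128, callbacks=[checkpoint])
```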
Now to the most important part of the article — Visualising LSTM activations. We will need some functions to actually make these visualisations understandable. Let’s dive in.
Step 6: Backend Function to get Intermediate Layer Output
As we saw in Step 4 above, the first and third layers are LSTM layers. Our aim is to visualise the output of the second LSTM layer, i.e. the third layer in the whole architecture.
Keras Backend helps us create a function that takes in the input and gives us the outputs of an intermediate layer. We can use it to create a pipeline function of our own. Here, attn_func will return a hidden-state vector of size 512, i.e. the activations of an LSTM layer with 512 units. We can visualise each of these cell activations to understand what they are trying to interpret. To do that, we have to convert them into a range that can denote their importance.
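The post builds attn_func with keras.backend.function; an equivalent sketch that also works under TF2’s eager execution uses a sub-model that shares the trained layers. Layer index 2 is the second LSTM, counting the first LSTM and the Dropout before it (the model definition below repeats the illustrative Step 4 architecture):

```python
import numpy as np
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, LSTM, Dropout, Dense

seq_length, vocab = 100, 45  # illustrative sizes

model = Sequential([
    Input(shape=(seq_length, vocab)),
    LSTM(512, return_sequences=True),
    Dropout(0.2),
    LSTM(512),
    Dropout(0.2),
    Dense(vocab, activation="softmax"),
])

# A sub-model that shares the trained layers and exposes the second
# LSTM's output; model.layers[2] is that LSTM (index counts the Dropout).
attn_model = Model(inputs=model.inputs, outputs=model.layers[2].output)

# attn_func-style call: one input sequence in, a 512-dim hidden state out
activations = attn_model(np.zeros((1, seq_length, vocab), dtype="float32"))
```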
Step 7: Helper Functions
These helper functions will help us visualise the character sequence along with each character’s activation value. We pass the activations through a sigmoid function because we need values on a scale that can denote their importance to the whole output.
The get_clr function returns the appropriate colour for a given value.
The image below shows how each value is denoted with its respective colour.
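A sketch of these helpers: sigmoid squashes a raw activation into (0, 1), and get_clr buckets that value into a blue-to-red palette. The post’s exact colours and rendering method aren’t shown, so the ANSI 256-colour background codes below are an assumption for illustration:

```python
import math

def sigmoid(x):
    # Squash a raw LSTM activation into (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

# A blue-to-red palette; the original post's exact colours are not
# shown, so these ANSI 256-colour background codes are illustrative.
PALETTE = [27, 75, 117, 195, 224, 217, 210, 203, 196]

def get_clr(value):
    # Map a sigmoid value in [0, 1] to one of the palette buckets
    index = min(int(value * len(PALETTE)), len(PALETTE) - 1)
    return PALETTE[index]

def colour_char(ch, value):
    # Render `ch` on a background whose colour encodes `value`
    return f"\x1b[48;5;{get_clr(value)}m{ch}\x1b[0m"
```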
Step 8: Get Predictions
The get_predictions function randomly chooses an input seed sequence and gets the predicted sequence for that seed.
The visualize function takes as input the predicted sequence, the sigmoid values for each character in the sequence, and the number of the cell to visualise. Based on the cell’s value, each character is printed with an appropriate background colour.
After applying the sigmoid to the layer output, the values lie in the range 0 to 1. The closer a value is to 1, the higher its importance; a value close to 0 means the cell does not contribute in any major way to the final prediction. The importance is denoted by colour, where blue stands for lower importance and red for higher importance.
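A sketch of these two functions, assuming the model and the attn_model sub-model from the earlier steps exist; the greedy argmax sampling, the two-colour threshold, and all names here are illustrative rather than the post’s exact code (the post uses a finer colour scale):

```python
import numpy as np

def get_predictions(model, attn_model, seed, char_to_int, int_to_char,
                    seq_length=100, n_chars=200):
    """Generate n_chars characters from a seed string, recording the
    second LSTM's 512-dim hidden state after each predicted character."""
    vocab = len(char_to_int)
    generated, activations = seed, []
    for _ in range(n_chars):
        # One-hot encode the most recent seq_length characters
        window = generated[-seq_length:]
        x = np.zeros((1, seq_length, vocab), dtype="float32")
        for t, ch in enumerate(window):
            x[0, t, char_to_int[ch]] = 1.0
        next_id = int(np.argmax(model.predict(x, verbose=0)))
        activations.append(attn_model(x).numpy()[0])  # 512-dim hidden state
        generated += int_to_char[next_id]
    return generated, activations

def visualize(text, sigmoid_values, cell_no):
    """Render each character on a background colour set by the chosen
    cell's sigmoid value. The post uses a graded blue-to-red scale;
    this sketch thresholds at 0.5 with illustrative ANSI codes."""
    out = []
    for ch, vals in zip(text, sigmoid_values):
        colour = 196 if vals[cell_no] > 0.5 else 27  # red vs blue
        out.append(f"\x1b[48;5;{colour}m{ch}\x1b[0m")
    return "".join(out)
```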
Step 9: Visualise Activations
More than 90% of the cells do not show any understandable pattern. I visualised all 512 cells manually and noticed that three of them (189, 435, 463) show understandable patterns.
Cell Number 189 is activated for text inside quotes, as you can see below. This tells us what the cell is looking for while predicting: it contributes strongly to text between quotes.
Cell Number 435 is activated for a couple of words after a sentence in quotes.
Cell Number 463 is activated for first character in every word.
The results can be further improved with more training or more data. This proves the point that Deep Learning is not a complete black box after all.
You can check out the whole code on my GitHub profile.
This was my first attempt at writing a blog. I hope you learned something from this. Do leave a clap if you liked it.