How I Classified Images With Recurrent Neural Networks

Source: Deep Learning on Medium

Picture this: it’s Monday morning and you trudge your way to the office, sleep deprived and drinking coffee. Your boss is waiting for you at your desk, and hands you three boxes — each of them overflowing with stacks and stacks of papers containing billing information, contracts, confidential information, and everything in between. “Sort these for me, and have them on my desk by 5 p.m.,” she says.

Great. Now you have to add that to your already full plate of work. It’s not like sorting the papers is important to you… it’s just tedious and boring. Anybody could do that, and it’s the last thing you want to spend your time doing. You could be at home with the kids, and your dog!

Nobody at the office is willing to do it for you, either. They say the most creative solutions stem from people being lazy: so, what if you got a machine to do it?

Here’s the thing: machines don’t have real human emotion. They don’t care that the work they’d be doing totally sucks. The technology behind this principle is really quite simple, and would definitely get the job done by 5 p.m. so you can make your boss happy.

Disclaimer: You should know a little about Deep Learning before reading this article. I’ll briefly overview the necessary concepts, but I actually wrote another article covering that information in more detail, which you can read here.

The 411: Recurrent Neural Networks

The technology behind sorting uses a basic Machine Learning framework called a neural network. Basically, an input goes through the network and you obtain an output. The layers between the input and the output are called hidden layers, and they allow the data to be manipulated accordingly. For visual reference, it looks a little something like this:
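To make that concrete, here’s a minimal sketch of a plain feed-forward network in PyTorch (the layer sizes here are illustrative, not from the project): an input layer, one hidden layer, and an output layer.

```python
import torch
import torch.nn as nn

# A minimal feed-forward network (sizes are illustrative):
# input layer -> hidden layer -> output layer.
model = nn.Sequential(
    nn.Linear(4, 8),   # 4 input features mapped to 8 hidden units
    nn.ReLU(),         # non-linearity applied in the hidden layer
    nn.Linear(8, 3),   # 8 hidden units mapped to 3 output classes
)

x = torch.randn(1, 4)   # one sample with 4 input features
out = model(x)
print(out.shape)        # torch.Size([1, 3])
```

The data only flows forward here, input to output, which is exactly what the recurrent networks below change.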

And it’s not just used for sorting either. There are a million and one different ways that you can use neural networks, many of which are at the forefront of technological innovation today.

Some time ago, Gmail implemented this cool new setting that basically allows it to write emails for you by predicting what you’re going to type next. This isn’t some magical technology that reads your mind… it’s actually the work of a specific type of neural net called a Recurrent Neural Network (RNN).

A really common, easy-to-understand way of explaining Recurrent Neural Networks is comparing them to human speech. When we form sentences, we don’t think of each word from scratch; rather, each word builds on the previous ones and informs the later ones.

A visual representation of the difference between an RNN and a Feed-Forward Neural Network.

The difference between an RNN and a Feed-Forward Neural Network is that in an RNN, the nodes contain a loop over the input data, which is what allows information to persist. In the diagram, Xt is the input, while ht is the output. The input undergoes multiple iterations (which you can see at X0, X1, and X2), which is why it’s called recurrent: the output of the first iteration feeds into the input (and therefore affects the output) of the second. Basically, it gives context to the data.

Think of it like you’re baking a cake: you mix each of the dry ingredients (initial input) to make a powdery mixture (first output), then you add the wet ingredients (second iteration input) and get your batter (second iteration output). But RNNs have way more than just two iterations.

This system is really useful for chains, sequences, lists, and sorting because you need to consider and classify the data as a whole.

The looped data “unrolled”.

Breaking Down Long Short-Term Memory Networks (LSTMs)

Recurrent Neural Networks can also be built from units called Long Short-Term Memory (LSTM) units (the name is such a paradox; I know, I’m bothered too) when there are feedback loops or time delays present. In this case, you need to mitigate the flow of the input data through the neural network. This allows the neural network to prioritize between what is ‘important’ vs. ‘non-important’ information.

One LSTM unit is composed of a:

  • Memory cell: where the input data resides. It’s a container where all the action happens. The gates on its perimeter are able to control what information flows through it and how the input is handled.
  • Input Gate: This is where the input enters the cell (obviously). A sigmoid activation decides whether to let the input data in or erase the present state, while a tanh activation creates the candidate values that affect the output. You can see this in the diagram below, represented by the middle two activation functions.
  • Output Gate: This is shown in the diagram by the activation function on the right side. It regulates and filters the output of the cell.
  • Forget Gate: You don’t actually always need the previous input information for the following one. This gate allows you to get rid of information that was previously stored. For example: say you input “Angelica is my friend. Logan is Sam’s cousin.”, then the network can ‘forget’ the data before “friend” by the time it reaches “Logan”. It is represented by the activation function furthest on the left side.
An example of an RNN with LSTM units.

LSTMs are used to classify, identify, or predict output data based on a series of discrete-time input data. They use gradient descent and the backpropagation algorithm to minimize error.
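The gate equations above can be written out by hand in a few lines. This is a from-scratch sketch of a single LSTM step using NumPy, with random placeholder weights (a real network learns these via backpropagation); the sizes and names are my own, not from the project.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Random placeholder weights for the four gates (i, f, o) and
# the candidate values (c); real networks learn these.
rng = np.random.default_rng(0)
n_in, n_hid = 4, 8
W = {g: rng.standard_normal((n_hid, n_in + n_hid)) * 0.1 for g in "ifoc"}

def lstm_step(x, h_prev, c_prev):
    z = np.concatenate([x, h_prev])  # current input + previous output
    i = sigmoid(W["i"] @ z)          # input gate: how much new info to let in
    f = sigmoid(W["f"] @ z)          # forget gate: how much old memory to keep
    o = sigmoid(W["o"] @ z)          # output gate: how much of the cell to expose
    c_tilde = np.tanh(W["c"] @ z)    # candidate values for the memory cell
    c = f * c_prev + i * c_tilde     # update the memory cell
    h = o * np.tanh(c)               # regulated, filtered output of the cell
    return h, c

h, c = lstm_step(rng.standard_normal(n_in), np.zeros(n_hid), np.zeros(n_hid))
print(h.shape)   # (8,)
```

Note how the forget gate `f` scales the old cell state, which is exactly the “forget everything before ‘friend’” behaviour described above.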

If you want to learn more, here’s a broken-down, step-by-step article on understanding LSTMs.

Building The Image Classification Model

The objective of an Image Classification Model is to be accurate. Why would someone use any sort of technology if it had a really large range of error, and was wrong most of the time? It’s like asking your friend who is not very strong in math to do your math test for you — it just makes no sense whatsoever.

Using my knowledge of RNNs, I coded one that classifies images — which iterates, trains, and tests data for higher accuracy. The output of the code is the loss function and percentage accuracy for each epoch, so you can see how it increases with trial and error as the weights and biases are adjusted.

I used Google Colab and PyTorch for this project. I only included the most important sections, so this is not the full code. In this link, you can see the expected outputs of all sections of the code (and accompanying explanations!).

The Basic Neural Network

This basically creates the base framework with initialized weights and biases. You can see how I coded the neurons, inputs, and outputs (taking the theory explained above and putting it into practice). You can see what each part of the code does through the comments.

Note how self.hidden takes into account the final state of the LSTM (ltsm_out) because it is a recurrent neural network.
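Since the full code lives at the link above, here is only a representative sketch of what such a model class can look like in PyTorch: each image is read row by row as a sequence, and the final LSTM state is used for classification. The class name and layer sizes are my own illustrative choices.

```python
import torch
import torch.nn as nn

# A representative image-classifying RNN sketch (sizes illustrative).
# Each 28x28 image is treated as a sequence of 28 rows of 28 pixels.
class ImageRNN(nn.Module):
    def __init__(self, n_inputs=28, n_hidden=64, n_classes=10):
        super().__init__()
        self.lstm = nn.LSTM(n_inputs, n_hidden)   # the recurrent core
        self.fc = nn.Linear(n_hidden, n_classes)  # final state -> class scores

    def forward(self, x):                 # x: (steps, batch, inputs)
        lstm_out, (h_n, c_n) = self.lstm(x)
        return self.fc(h_n[-1])           # classify from the LSTM's final state

model = ImageRNN()
scores = model(torch.randn(28, 5, 28))    # a fake batch of 5 "images"
print(scores.shape)                       # torch.Size([5, 10])
```

Classifying from the final hidden state is what lets the prediction depend on the whole sequence of rows, not just the last one.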

Model Training

The model updates with every iteration, adjusting weights and biases in order to minimize the loss function and improve the accuracy percentage using gradient descent and the backpropagation algorithm. You can read more about that here (shameless self-promo!).

The algorithm has an optimization function (optim.Adam) that ensures the system actually improves the accuracy per epoch, as opposed to randomly adjusting weights and biases so the accuracy fluctuates.
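As a hedged sketch of that training loop (on fabricated data, with a stand-in linear classifier rather than the project’s actual model), here is how optim.Adam, the loss function, and backpropagation fit together:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in classifier on fake data, to show the training-step pattern.
model = nn.Linear(28 * 28, 10)
optimizer = optim.Adam(model.parameters(), lr=0.001)  # the optimization function
criterion = nn.CrossEntropyLoss()                     # the loss function

x = torch.randn(5, 28 * 28)        # a fake batch of 5 flattened images
y = torch.randint(0, 10, (5,))     # fake labels

loss_before = criterion(model(x), y).item()
for _ in range(50):                # repeated gradient-descent updates
    optimizer.zero_grad()          # clear gradients from the last step
    loss = criterion(model(x), y)  # forward pass + compute the loss
    loss.backward()                # backpropagation
    optimizer.step()               # Adam adjusts weights and biases

loss_after = criterion(model(x), y).item()
print(loss_after < loss_before)    # True: the loss steadily decreases
```

Because Adam follows the gradient rather than guessing, the loss goes down each epoch instead of fluctuating randomly.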

Iteration of Data: Training and Testing the Model for Maximal Accuracy

Basically, this puts together the model and its testing and training. The output gives the epoch, the loss function, and the percentage accuracy with every iteration. When it is run, we see how the neural network gradually decreases the loss function and has the accuracy approach 100% thanks to the optimization function and the adjustment of weights and biases. The ‘Test Accuracy’ portion at the bottom gives a final accuracy percentage after 10 epochs (which is just a number I set; it doesn’t always have to be 10).
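The ‘Test Accuracy’ figure boils down to a one-line comparison: take the class with the highest score for each sample and check it against the true label. Here’s a tiny sketch on fabricated scores and labels (not the project’s real outputs):

```python
import torch

# Fabricated model outputs for 4 samples across 3 classes.
scores = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.9, 0.2, 0.1],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 1, 2, 2])         # true classes

preds = scores.argmax(dim=1)                # pick the highest-scoring class
accuracy = (preds == labels).float().mean().item()
print(f"Test Accuracy: {accuracy:.0%}")     # Test Accuracy: 75%
```

Three of the four predictions match the labels here, hence 75%; the real model computes exactly this over the whole test set.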

The Output

This tells us the accuracy of what our model predicted. It’s important to note that every time you train and test the data, you’re going to obtain different loss function costs and training accuracy results, because the algorithm starts from different randomly initialized weights and biases each run.

And there you have it: image classification with recurrent neural networks!

Why You Should Care

This technology doesn’t just classify images. Used in a different setting, you can see how it could scan and sort data, just like those papers your boss gave you. Oh wait, did you forget about those? Because we didn’t, and neither did she.

Here are some examples of other applications that we’ve thought of so far:

  • Robot control
  • Speech, handwriting, and human action recognition
  • Music composition
  • Sign Language Translation
  • Protein Homology Detection
  • Predicting sub-cellular localization of proteins and medical care pathways

Actually, this can be implemented in a thousand different ways across a ton of different industries. As a subset of an emerging technological field, it’s definitely something to look out for. Basically, what I’m trying to say is that the potential RNNs have to disrupt the world is unlike any other, mostly due to their versatile applicability.

This project was inspired by an article on Medium.

Let me know what you think!

Follow me on Linkedin and Medium for more.