Watch this Neural Network Learn to See

Visualizations

The neural network was trained for five epochs with a minibatch size of 1024 images, totalling 290 training steps. After every step, a pre-selected set of ten sample images (one of each digit) was fed into the model and the activations of each convolutional layer were saved. Although it has fallen out of fashion in recent years in favour of the more easily-trainable ReLU function, I decided to use tanh as the activation function in the convolutional layers. Because tanh is bounded between -1 and 1, its activations are simple to visualize. When the activations of the first layer are mapped onto a red-blue colourmap, this is the result:

Conv1: The input images (top row) and the activations of the four channels in convolutional layer 1. Activations range from +1 (blue) to 0 (white) to -1 (red). Frame (top left) is the number of training steps applied.
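The original training code isn't reproduced in the article, but a minimal sketch of how this setup could be recreated in Keras looks roughly like the following. The layer sizes, layer names, and the `sample_digits` array are my assumptions based on the figure descriptions, not code taken from the original repo.

```python
from tensorflow import keras
from tensorflow.keras import layers

# A small tanh CNN in the spirit of the one described above. Only the channel
# counts and resolutions visible in the figures are known; the rest is assumed.
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(4, 3, padding="same", activation="tanh", name="conv1")(inputs)
x = layers.Conv2D(4, 3, padding="same", activation="tanh", name="conv2")(x)
x = layers.MaxPooling2D()(x)                                  # 28x28 -> 14x14
x = layers.Conv2D(8, 3, padding="same", activation="tanh", name="conv3")(x)
x = layers.Conv2D(8, 3, padding="same", activation="tanh", name="conv4")(x)
x = layers.MaxPooling2D()(x)                                  # 14x14 -> 7x7
x = layers.Conv2D(16, 3, padding="same", activation="tanh", name="conv5")(x)
x = layers.Conv2D(16, 3, padding="same", activation="tanh", name="conv6")(x)
x = layers.Flatten()(x)
outputs = layers.Dense(10, activation="softmax")(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=["accuracy"])

# A second model that returns every conv layer's activations at once.
activation_model = keras.Model(
    inputs, [model.get_layer(f"conv{i}").output for i in range(1, 7)]
)

class SaveActivations(keras.callbacks.Callback):
    """After every training step, record activations for ten fixed sample digits."""

    def __init__(self, sample_digits):
        super().__init__()
        self.sample_digits = sample_digits   # shape (10, 28, 28, 1), one per digit
        self.snapshots = []

    def on_train_batch_end(self, batch, logs=None):
        self.snapshots.append(activation_model.predict(self.sample_digits, verbose=0))

# Usage (x_train, y_train, and sample_digits prepared elsewhere):
# history = model.fit(x_train, y_train, epochs=5, batch_size=1024,
#                     callbacks=[SaveActivations(sample_digits)])
```

Turning one channel of one snapshot into a frame like those above is then a single `imshow` call: matplotlib's diverging `RdBu` colourmap runs red through white to blue, so fixing `vmin=-1` and `vmax=1` reproduces the mapping described in the caption.

```python
import numpy as np
import matplotlib.pyplot as plt

# Stand-in for one channel's 2-D tanh activations for one sample digit.
activation = np.tanh(np.random.randn(28, 28))

plt.imshow(activation, cmap="RdBu", vmin=-1.0, vmax=1.0, interpolation="nearest")
plt.axis("off")
plt.savefig("conv1_channel0.png", bbox_inches="tight")
```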

Conv1 appears to have learned to recognize stroke width in the first and second channels, as the insides of each digit are dark red while the outsides are light red. In the third and fourth channels, it appears to have learned the concept of edges, with the digits being blue, the background being pink, and the digit edges being white. These activations are a far cry from what the deep learning canon would suggest, however: that each channel learns a clear and distinct feature, such as vertical or horizontal edges. Instead, Conv1 is largely reproducing the original input with slight annotation.

Conv2: The same setup as Conv1.

Similar to Conv1, Conv2 also appears to be reproducing the original input. Channels one, two, and four are nearly identical to each other and to the edge-highlighting behaviour seen in Conv1, and channel three is simply a fuzzy reproduction of the input.

Conv3: The same setup as Conv1, except with eight channels instead of four. This layer has half the resolution of the original image, so its activations were upscaled without interpolation for visualization.
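Upscaling without interpolation just means repeating each pixel in a block; a one-line sketch of the idea (not the repo's code):

```python
import numpy as np

act = np.tanh(np.random.randn(14, 14))    # stand-in for one 14x14 Conv3 channel
upscaled = np.kron(act, np.ones((2, 2)))  # each pixel becomes a 2x2 block -> 28x28
```

(Passing `interpolation="nearest"` to `imshow`, as in the rendering sketch above, has the same visual effect.)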

In Conv3, we see what may be the first real learned features. In the sixth channel, towards the end of training, we see that the digits are blue, most of the background is pink, and the background directly beneath each part of the digit is red. This suggests that this channel has learned to recognize the bottoms of horizontal edges. Similarly, the seventh channel has red digits, pink background, and white horizontal edges above each digit. The other channels appear to be simple reproductions of the original images, however.

Conv4: The same setup as Conv3.

In Conv4, we see more clearly defined features. In particular, we see edges at different angles. The first, second, and sixth channels identify the tops of horizontal edges. The third, seventh, and eighth channels identify diagonal edges. The other two channels are coarse reproductions of the original.

Conv5: The same setup as Conv1, except with sixteen channels instead of four. This layer has one-quarter the resolution of the original image, so activations were upscaled without interpolation for visualization.

Conv5 has been substantially downsampled, to a resolution of only 7×7 pixels, yet it appears to perform meaningful feature extraction. At the earliest steps of training, each channel is a pink wash, largely devoid of information. By step 70, the layer has learned to produce a blob that vaguely resembles the input. By the end of training, however, the channels have clearly differentiated themselves from one another and show sharp changes in activation. It's unclear exactly what features have been learned here, owing to the low resolution and the entangling of what we would call independent features, but it's clear that each channel has learned something useful.

Conv6: The gif was too large for Medium, so these are the activations after training has completed.

Unfortunately, Conv6 is just over Medium's file size limit, so you'll have to click this link to watch it learn. As with Conv5, the learned features are clearly visible, but it's nearly impossible to tell what they actually correspond to.

Accuracy and loss (categorical_crossentropy) during training
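These curves come straight from the history object that `model.fit` returns; a plotting sketch continuing from the Keras setup above (an assumption about how the original figure was made, not the repo's code):

```python
import matplotlib.pyplot as plt

# `history` is the object returned by model.fit(...) in the sketch above.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(history.history["accuracy"])
ax1.set_xlabel("epoch"); ax1.set_ylabel("accuracy")
ax2.plot(history.history["loss"])
ax2.set_xlabel("epoch"); ax2.set_ylabel("categorical_crossentropy")
plt.tight_layout()
plt.savefig("training_curves.png")
```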

Lessons Learned

So what’s the moral of this story? I propose there are three. First, deep learning outcomes are rarely as clear-cut as the canon suggests. Many textbooks, including Deep Learning (Goodfellow et al.), liken low-level convolutional layers to Gabor filters and other hand-crafted computer vision filters. Despite the model achieving over 95% accuracy on the testing data, the first four convolutional layers did very little in the way of feature extraction. Granted, this was a very simple model for a very simple task, and a more complex model trained for a harder task would likely have learned at least some useful low-level features. But the way deep learning is typically taught (in my experience) suggests that feature refinement and extraction are inevitable, even for simple tasks; this is plainly not the case.
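For readers who haven't met them, a Gabor filter is just a sinusoid windowed by a Gaussian envelope: an oriented edge-and-stripe detector of exactly the kind Conv1 was "supposed" to resemble. A hand-rolled sketch, with all parameters purely illustrative:

```python
import numpy as np

def gabor_kernel(size=9, theta=0.0, sigma=2.0, wavelength=4.0):
    """A hand-crafted Gabor filter: a cosine carrier under a Gaussian envelope."""
    half = size // 2
    y, x = np.mgrid[-half:half + 1, -half:half + 1]
    # Rotate coordinates so the carrier varies along the direction at angle theta.
    x_theta = x * np.cos(theta) + y * np.sin(theta)
    envelope = np.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2))
    carrier = np.cos(2 * np.pi * x_theta / wavelength)
    return envelope * carrier

# A 45-degree oriented kernel; convolving an image with it responds to diagonal edges.
kernel = gabor_kernel(theta=np.pi / 4)
```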

The second lesson is that learned features are unlikely to be intuitive, independent qualities that a human might select for. Conv5 and Conv6 have clearly learned something, and the original images have been encoded in such a way that the dense layers of the network can classify them by digit type, but it isn’t immediately obvious what they’ve learned to detect. This is a common problem in deep learning, and especially in generative modelling, where a model may learn to embed two or more seemingly unrelated qualities as a single feature.

The third lesson here is one that I’m reminded of daily in my work as a data scientist, and that’s that it pays to visualize everything. I went into this project expecting to write a very different article. I was excited to show the network learning and refining features, from low-level edge detection to high-level loops and whirls. Instead, I found a lazy layabout that hardly refined features until the eleventh hour. Most notably, I was surprised to see that, once the layers learned some representation of the input, they hardly changed over the course of the training process. Visualizing this has bolstered my understanding of convolutional neural network training. I hope you’ve learned something here as well.

For those curious, the code used to train this network and produce these visualizations is available in this GitHub repo: