Visualizing Neural Networks using Saliency Maps in PyTorch



Neural networks are used in a growing range of applications, and there is a corresponding need for them to be interpretable to humans. Understanding what a neural network is looking at, and how robust it is to new examples, is useful not only for explaining its decisions but also for scientific curiosity.

What are Saliency Maps?

In this blog post, we'll be discussing saliency maps: heatmaps that highlight the pixels of the input image that contributed most to the output classification.

Suppose we have a ConvNet trained for image classification. For an input image, the ConvNet produces a score for each class, and the class with the maximum score becomes the predicted output.

Class scores, by the way, are the values the network assigns to each class in its output layer before the softmax. They are not probabilities themselves, but they are directly related to probabilities through the softmax function.
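
As a quick illustration (with made-up numbers), applying softmax to a vector of raw class scores turns them into probabilities that sum to one while preserving their ordering:

```python
import torch
import torch.nn.functional as F

# Hypothetical class scores (logits) for a 3-class problem.
scores = torch.tensor([2.0, 0.5, -1.0])

# Softmax maps the scores to probabilities with the same ordering.
probs = F.softmax(scores, dim=0)   # ≈ tensor([0.7856, 0.1753, 0.0391])
print(probs.sum())                 # sums to 1 (up to floating point)
```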

Now, consider the gradient of the output class score with respect to the input image's pixel values. The pixels for which this gradient is large in magnitude (either positive or negative) are the pixels that need to change the least to affect the class score the most. One would expect such pixels to correspond to the object's location in the image. That's the basic idea behind saliency maps.
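
More formally, following the first-order Taylor expansion used by Simonyan et al. [1], the class score S_c can be approximated in the neighbourhood of a given image I_0 by a linear function of the image, and the weights of that linear function are exactly this gradient:

```latex
S_c(I) \approx w^\top I + b,
\qquad
w = \left.\frac{\partial S_c}{\partial I}\right|_{I_0}
```

The magnitude of each element of w indicates how strongly the corresponding pixel influences the class score, which is what the saliency map visualizes.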

Saliency Map Extraction in PyTorch

Firstly, we need a pretrained ConvNet for image classification. Here, we'll be using a pretrained VGG-19, which is available in PyTorch through the torchvision module.

VGG-19 is a convolutional neural network that has been trained on more than a million images from the ImageNet dataset. The network is 19 layers deep and can classify images into 1000 object categories, such as keyboard, mouse, pencil, and many animals. As a result, the network has learned rich feature representations for a wide range of images. The network has an image input size of 224-by-224. [4]

Imports and code for loading the pretrained VGG-19 model. Note that since we don't need gradients with respect to the network's parameters, we set param.requires_grad to False for each of them.
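
A minimal sketch of this step, using torchvision's pretrained VGG-19, could look like this:

```python
import torch
import torchvision.models as models

# Load the VGG-19 model pretrained on ImageNet and switch to evaluation mode.
model = models.vgg19(pretrained=True)
model.eval()

# We only need gradients with respect to the input image, not the weights,
# so freeze all of the network's parameters.
for param in model.parameters():
    param.requires_grad = False
```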

We need an input image from which to extract the saliency map. As an example, we'll be working with this image of an amazingly cute Maltese dog.

Image borrowed from https://specials-images.forbesimg.com/imageserve/5db4c7b464b49a0007e9dfac/960x0.jpg?fit=scale
Downloading and opening the image
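
A minimal sketch of this step, assuming the image URL from the caption above and the requests and PIL libraries, might look like this:

```python
from io import BytesIO

import requests
from PIL import Image

# URL of the Maltese dog image (taken from the caption above).
url = 'https://specials-images.forbesimg.com/imageserve/5db4c7b464b49a0007e9dfac/960x0.jpg?fit=scale'

# Download the image and open it as a PIL image.
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert('RGB')
```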

Now we need to preprocess the opened image so that we can feed it to the ConvNet. We resize the image to the required size of 224 × 224 and then convert it to a PyTorch tensor. Note that the ToTensor transform converts a PIL Image or NumPy array of shape (Height × Width × Channels) with values in the range [0, 255] to a float tensor of shape (Channels × Height × Width) with values in the range [0, 1]. We also need to normalize the tensor using the per-channel mean and standard deviation of the ImageNet images. All of this can be implemented as follows.

Functions for preprocessing, deprocessing and displaying the image.
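
A sketch of these helpers, assuming torchvision transforms and matplotlib for display (the exact function names are illustrative), could look like this:

```python
import torchvision.transforms as T
import matplotlib.pyplot as plt

# Per-channel mean and standard deviation of the ImageNet training images.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def preprocess(pil_img, size=224):
    """Resize, convert to a [0, 1] tensor, normalize, and add a batch dimension."""
    transform = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),                              # HxWxC in [0, 255] -> CxHxW in [0, 1]
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),  # per-channel normalization
    ])
    return transform(pil_img).unsqueeze(0)         # shape: (1, 3, 224, 224)

def deprocess(tensor):
    """Undo the normalization and convert the tensor back to a PIL image."""
    t = tensor.squeeze(0).clone()
    for c, (m, s) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
        t[c] = t[c] * s + m                        # invert (x - mean) / std
    return T.ToPILImage()(t.clamp(0.0, 1.0))

def show(pil_img):
    """Display a PIL image without axis ticks."""
    plt.imshow(pil_img)
    plt.axis('off')
    plt.show()
```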

The saliency map is the gradient of the maximum class score with respect to the input image. But note that the input image has three channels: R, G and B. To derive a single saliency value for each pixel (i, j), we take the maximum gradient magnitude across the colour channels. This can be implemented as follows.
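
Under the assumptions above (model, img, preprocess and deprocess from the earlier sketches), a minimal version of the extraction step looks like this:

```python
# Preprocess the image and ask autograd for gradients with respect to the input.
X = preprocess(img)
X.requires_grad_()

# Forward pass: class scores (logits) for the 1000 ImageNet classes.
scores = model(X)                        # shape: (1, 1000)

# Take the maximum score and backpropagate it to the input pixels.
max_score, _ = scores.max(dim=1)
max_score.backward()

# Saliency: maximum absolute gradient across the three colour channels.
saliency, _ = X.grad.abs().max(dim=1)    # shape: (1, 224, 224)

# Show the input image and its saliency heatmap side by side.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(deprocess(X.detach()))
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(saliency[0].numpy(), cmap='hot')
plt.axis('off')
plt.show()
```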

Results

We get the following heatmaps, which show where in each input image the ConvNet is actually looking when it predicts a class.

We see that the heatmap is concentrated tightly over the dog, which is reassuring: it means the ConvNet is looking at the dog itself, and not at anything else, to make its prediction.

Here, the heatmap is more diffuse than the previous one, and the ConvNet may perform worse on this category if the other regions that light up here don't show up in new images.

This one also looks reassuring, as the ConvNet is looking only at the football helmet to make its prediction.

We see that even some trees in the background and the circular shape of the fountain base in the foreground are driving the ConvNet's decision to predict that this image is of a fountain. This kind of bias may have crept in because the ConvNet was likely trained on many images of fountains with circular bases and trees in the background.

This one is also comforting: the ConvNet is looking at the Siamese cat's facial features to make its classification. Note that the heatmap lights up more around the cat's face than its body, which means the distinct facial features are more important in determining the classifier's output.

Although this one looks fairly reassuring too, we still see some diffusion. This may be because most images of sea lions are taken near water, so the ConvNet may have picked up the pattern that the surrounding water is also predictive of the sea lion class.

This one also looks good: the heatmap is mostly concentrated over the body of this American Eskimo dog.

Conclusion

We saw how saliency maps can tell us where in the input image a neural network is looking when it predicts an output class. It is important to note that saliency maps are extracted from a classification ConvNet trained only on image labels, so no additional annotation is required.

We used image gradients to generate saliency maps in this post. Using the same image gradients, we can also generate adversarial examples by changing the input image in a way that drives the ConvNet's output towards an incorrect class. This can be really unsettling when the changes are so subtle that the image looks unchanged to a human, yet the ConvNet's prediction flips.
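
As a rough illustration of that idea (this is essentially the fast gradient sign method, not code from this post, and the target class index is made up), one could nudge the image against the gradient of the loss for a wrong class:

```python
import torch.nn.functional as F

epsilon = 0.01                               # perturbation size; small enough to be subtle

X = preprocess(img)
X.requires_grad_()

scores = model(X)
wrong_class = torch.tensor([123])            # hypothetical index of an incorrect class
loss = F.cross_entropy(scores, wrong_class)  # loss that is low when the wrong class wins
loss.backward()

# Step the image against the gradient so the wrong-class loss decreases.
X_adv = (X - epsilon * X.grad.sign()).detach()
print(model(X_adv).argmax(dim=1))            # prediction may now flip to the wrong class
```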

It also turns out that, using saliency maps together with a graph-cut algorithm [5], one can perform object segmentation in these images without training dedicated segmentation or detection models. Because only image-level labels are used, this type of object segmentation is called weakly supervised.

References

[1] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. ICLR Workshop, 2014. https://arxiv.org/pdf/1312.6034.pdf

[2]https://github.com/sijoonlee/deep_learning/blob/master/cs231n/NetworkVisualization-PyTorch.ipynb

[3] https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a

[4] https://in.mathworks.com/help/deeplearning/ref/vgg19.html

[5] Y. Boykov and M. P. Jolly. Interactive graph cuts for optimal boundary and region segmentation of objects in N-D images. In Proc. ICCV, volume 2, pages 105–112, 2001.