Visualizing Convolutional Neural Networks using PyTorch

Source: Deep Learning on Medium

A Convolutional Neural Network (CNN) is a type of neural network used for computer vision tasks such as image classification, image recognition, object detection, and instance segmentation. However, neural network models are often called ‘black box’ models because it is difficult to understand how the model learns the complex dependencies present in the input. It is also difficult to analyze why a given prediction is made during inference.

In this article, we will look at two visualization techniques:

  1. Visualizing learned filter weights.
  2. Performing occlusion experiments on the image.

These methods help us answer questions such as: What does a filter learn? What kind of images cause certain neurons to fire? How good are the hidden representations of the input image?

Citation Note: The content and structure of this article are based on the deep learning lectures from One-Fourth Labs (PadhAI). If you are interested, check out their course.

Receptive Field of Neuron

Before we visualize how a Convolutional Neural Network works, we will discuss the receptive field of the filters in a CNN.

Consider a two-layered Convolutional Neural Network that uses 3×3 filters throughout. The center pixel in Layer 2 (marked in yellow) is the result of applying the convolution operation to the 3×3 neighborhood around the center pixel in Layer 1 (using a 3×3 kernel and stride = 1). Similarly, the center pixel in Layer 3 is the result of applying the convolution operation to the 3×3 neighborhood around the center pixel in Layer 2.

The receptive field of a neuron is defined as the region in the input image that can influence the neuron in a convolution layer, i.e., how many pixels in the original image influence that neuron.

It is clear that the central pixel in Layer 3 depends on the 3×3 neighborhood of the previous layer (Layer 2). The 9 pixels in that neighborhood (marked in pink), including the central pixel, together correspond to a 5×5 region in Layer 1. As we go deeper into the network, the pixels in the deeper layers have a larger receptive field, i.e., the region of the original image that influences them is larger.

From the above image, we can observe that the highlighted pixel in the second convolution layer has a large receptive field with respect to the original input image.
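The growth of the receptive field can be checked with a few lines of arithmetic. The following is a minimal sketch, assuming plain convolutions without dilation; the helper name receptive_field is our own and is not part of PyTorch.

```python
def receptive_field(layers):
    """Return the receptive-field size (in input pixels) after each layer.

    `layers` is a list of (kernel_size, stride) pairs, ordered from the
    input towards the output.
    """
    sizes = []
    rf = 1      # a single input pixel sees only itself
    jump = 1    # distance, in input pixels, between neighbouring outputs
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump
        jump *= stride
        sizes.append(rf)
    return sizes


# Two stacked 3x3 convolutions with stride 1, as in the example above.
print(receptive_field([(3, 1), (3, 1)]))  # [3, 5]
```

The second value, 5, matches the 5×5 region in Layer 1 that influences the central pixel of Layer 3.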

Visualizing CNN

To visualize the working of CNN, we will explore two commonly used methods to understand how the neural network learns the complex relationships.

  1. Filter visualization with a pre-trained model.
  2. Occlusion analysis with a pre-trained model.

Run this notebook in Colab

All the code discussed in the article is present on my GitHub. You can open the code notebook with any setup by directly opening my Jupyter Notebook on Github with Colab which runs on Google’s Virtual Machine. Click here, if you just want to quickly open the notebook and follow along with this tutorial.

Don’t forget to upload the input images folder (can be downloaded from the Github Repo) onto Google Colab before executing the code in Colab.

Visualize Input Images

In this article, we will use a small subset of the ImageNet dataset with 1000 categories to visualize the filters of the model. The dataset can be downloaded from my GitHub repo.

To visualize the data set we will implement the custom function imshow.

The function imshow takes two arguments: the image as a tensor and the title of the image. First, we invert the normalization of the image using the ImageNet mean and standard deviation values. After that, we use matplotlib to display the image.
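The notebook's own imshow is not reproduced here, so the following is only a sketch of what such a function could look like. It assumes the image was normalized with the commonly used ImageNet channel statistics and is passed as a channels-first array (for a PyTorch tensor, something like tensor.cpu().numpy()); the helper name denormalize is ours.

```python
import numpy as np

# Commonly used ImageNet channel statistics (RGB order); assumed here.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def denormalize(image_chw):
    """Undo ImageNet normalization and return an H x W x C array in [0, 1]."""
    image = np.asarray(image_chw, dtype=np.float64).transpose(1, 2, 0)  # CHW -> HWC
    image = image * IMAGENET_STD + IMAGENET_MEAN  # inverse of (x - mean) / std
    return np.clip(image, 0.0, 1.0)


def imshow(image_chw, title=None):
    """Display a normalized C x H x W image with an optional title."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting

    plt.imshow(denormalize(image_chw))
    if title is not None:
        plt.title(title)
    plt.axis("off")
    plt.show()
```

Clipping to [0, 1] keeps matplotlib from complaining about values that fall slightly outside the valid range after the inverse transform.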

Sample Input Image

Filter Visualization

By visualizing the filters of the trained model, we can understand how a CNN learns the complex spatial dependencies between pixels in the image.

What does a filter capture?

Consider a 2D input of size 4×4 to which we apply a 2×2 filter (marked in red), starting from the top-left corner of the image. As we slide the kernel over the image from left to right and top to bottom to perform the convolution operation, we get an output that is smaller than the input.
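To see the sliding operation in numbers, here is a small pure-Python sketch. It assumes stride 1 and no padding, and, as is usual in deep learning, it does not flip the kernel. The function name conv2d_valid and the example values are made up for illustration.

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (stride 1, no padding) and return the output grid."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    output = []
    for i in range(out_h):            # top to bottom
        row = []
        for j in range(out_w):        # left to right
            # Element-wise product of the kernel with the patch under it, summed.
            row.append(sum(image[i + a][j + b] * kernel[a][b]
                           for a in range(kh) for b in range(kw)))
        output.append(row)
    return output


image = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12],
         [13, 14, 15, 16]]
kernel = [[1, 0],
          [0, 1]]

result = conv2d_valid(image, kernel)
print(len(result), len(result[0]))  # 3 3
print(result[0])                    # [7, 9, 11]
```

Under these assumptions, a 4×4 input and a 2×2 filter give a 3×3 output.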

The output at each convolution operation (like h₁₄) is equal to the dot product of the input vector and a weight vector. We know that the dot product between the two vectors is proportional to the cosine of the angle between vectors.
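Written out, with X denoting the flattened input patch, W the flattened filter weights, and θ the angle between them (the symbols are our own notation for this sketch), the relationship is:

```latex
h = \mathbf{W} \cdot \mathbf{X}
  = \sum_i W_i X_i
  = \lVert \mathbf{W} \rVert \, \lVert \mathbf{X} \rVert \cos\theta
```

For fixed magnitudes of W and X, the output therefore depends only on cos θ.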

During the convolution operation, certain parts of the input image, such as the portion containing the face of a dog, might give a high value when we apply a filter to them. In the above example, in what kind of scenarios will our output h₁₄ be high?

The output h₁₄ will be high if the cosine of the angle between the vectors is high, with a maximum value of 1. If the cosine is equal to 1, the angle between the vectors is 0°. That means that when the input vector X (a portion of the image) and the weight vector W point in the same direction, the neuron fires maximally.

For a given input magnitude, the neuron h₁₄ fires maximally when the input X (the portion of the image under the filter) is a positive multiple of the unit vector in the direction of the filter vector W.

In other words, we can think of a filter as an image. As we slide the filter over the input from left to right and top to bottom, the neuron fires whenever the filter coincides with a similar portion of the input. For all other parts of the input image that don't align with the filter, the output will be low. This is why we call the kernel or weight matrix a filter: it filters out the portions of the input image that don't align with it.

To understand what kind of patterns a filter learns, we can simply plot the filter, i.e., the weights associated with it. For filter visualization, we will use AlexNet pre-trained on the ImageNet dataset.
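A quick numerical check of this claim, using made-up vectors: three inputs with the same magnitude are compared against one flattened 2×2 filter. The helper names dot and cosine are ours.

```python
import math


def dot(u, v):
    """Dot product of two equally long vectors."""
    return sum(a * b for a, b in zip(u, v))


def cosine(u, v):
    """Cosine of the angle between two non-zero vectors."""
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))


w = [1.0, -1.0, 1.0, -1.0]           # a 2x2 filter, flattened
aligned = [2.0, -2.0, 2.0, -2.0]     # positive multiple of w
orthogonal = [2.0, 2.0, 2.0, 2.0]    # same magnitude, at 90 degrees to w
opposite = [-2.0, 2.0, -2.0, 2.0]    # same magnitude, pointing the other way

for name, x in [("aligned", aligned), ("orthogonal", orthogonal), ("opposite", opposite)]:
    print(name, "output:", dot(w, x), "cosine:", cosine(w, x))
```

The aligned input gives the largest output (cosine 1), the orthogonal one gives zero, and the opposite one gives the most negative output, which a ReLU would clamp to zero.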

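# AlexNet pre-trained on ImageNet, loaded from the torchvision model zoo
import torchvision.models as models
# pretrained=True is deprecated since torchvision 0.13; on newer versions use
# models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1) instead
alexnet = models.alexnet(pretrained=True)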

AlexNet contains 5 convolutional layers and 3 fully connected layers. ReLU is applied after every convolution operation. Remember that when convolving 3D (RGB) images, the kernel does not move along the depth, since the kernel and the image have the same depth. We will visualize these filters (kernels) in two ways.

  1. Visualizing each filter by combining three channels as an RGB image.
  2. Visualizing each channel in a filter independently using a heatmap.

The main function to plot the weights is plot_weights. The function takes 4 parameters:

model: the AlexNet model or any other trained model

layer_num: the index of the convolution layer whose weights we want to visualize

single_channel: the visualization mode; if True, each channel of a filter is plotted separately, otherwise the channels are combined into one RGB image

collated: applicable to single-channel visualization only; if True, the channels are collated into a single heatmap

In the plot_weights function, we take our trained model and read the layer at the given layer number. In AlexNet (PyTorch model zoo), the first convolution layer has a layer index of zero. Once we extract the layer at that index, we check whether it is a convolution layer, since we can only visualize the weights of convolutional layers. After validating the layer index, we extract the learned weight data from that layer.

#getting the weight tensor data
weight_tensor = model.features[layer_num].weight.data

Depending on the input argument single_channel, we can plot the weight data as single-channel or multi-channel images. AlexNet’s first convolution layer has 64 filters of size 11×11. We will plot these filters in two different ways to understand what kind of patterns the filters learn.

Visualizing Filters — Multi-Channel

When single_channel = False, we have 64 filters of depth 3 (RGB). We combine the RGB channels of each filter into one RGB image of size 11×11×3. As a result, we get 64 RGB images as the output.
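The conversion from filter weights to displayable RGB images can be sketched as follows. This is not the notebook's plot_weights code: the function name filters_to_rgb and the per-filter min-max scaling are our own assumptions, and a random array stands in for the real weights (which could be obtained with something like layer.weight.data.cpu().numpy()).

```python
import numpy as np


def filters_to_rgb(weights):
    """Convert conv weights of shape (N, 3, H, W) into N displayable RGB images.

    Each filter is min-max scaled to [0, 1] on its own, then rearranged from
    channels-first (3, H, W) to channels-last (H, W, 3) for plotting.
    """
    weights = np.asarray(weights, dtype=np.float64)
    if weights.ndim != 4 or weights.shape[1] != 3:
        raise ValueError("expected weights of shape (N, 3, H, W)")
    lo = weights.min(axis=(1, 2, 3), keepdims=True)
    hi = weights.max(axis=(1, 2, 3), keepdims=True)
    scaled = (weights - lo) / np.maximum(hi - lo, 1e-12)  # avoid division by zero
    return scaled.transpose(0, 2, 3, 1)                   # (N, H, W, 3)


# Random stand-in with the shape of AlexNet's first conv layer weights.
fake_weights = np.random.default_rng(0).normal(size=(64, 3, 11, 11))
images = filters_to_rgb(fake_weights)
print(images.shape)  # (64, 11, 11, 3)
```

Each of the 64 resulting arrays can be passed directly to matplotlib's imshow, which expects channels-last RGB values in [0, 1].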

#visualize weights for alexnet — first conv layer
plot_weights(alexnet, 0, single_channel = False)

In single-channel mode (single_channel = True), each channel of each of the 64 filters (0–63) is visualized separately. For example, the plot labeled 0,0 shows channel 0 of filter 0, the plot labeled 0,1 shows channel 1 of filter 0, and so on.

Visualizing the filter channels individually gives more intuition about what different filters are trying to learn from the input data. Looking closely at the filter visualizations, it is clear that the patterns found in some of the channels of the same filter are different. That means not all channels in a filter are trying to learn the same information from the input image. As we move deeper into the network, the filter patterns become more complex and tend to capture high-level information like the face of a dog or cat.

As we go deeper into the network, the number of filters used for convolution increases. Because of the large number of filters, it is not practical to visualize them individually, whether as one image per filter or one image per channel. The second convolution layer of AlexNet (indexed as layer 3 in the PyTorch sequential model structure) has 192 filters with 64 channels each, so we would get 192*64 = 12,288 individual filter channel plots. Another way to plot these filters is to concatenate all of these images into a single greyscale heatmap.
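One way to build such a collated image is to tile every filter channel into one large 2D array. The sketch below is an assumption about how this could be done, not the notebook's implementation; collate_filters is a hypothetical helper, and the random array only mimics the weight shape of the second convolution layer in torchvision's AlexNet (192 filters, 64 channels, 5×5 kernels).

```python
import numpy as np


def collate_filters(weights):
    """Tile weights of shape (N, C, H, W) into one 2D array of shape (N*H, C*W).

    Row block i, column block j of the result holds channel j of filter i.
    """
    weights = np.asarray(weights)
    n, c, h, w = weights.shape
    return weights.transpose(0, 2, 1, 3).reshape(n * h, c * w)


# Random stand-in for the second conv layer weights.
fake_weights = np.random.default_rng(0).normal(size=(192, 64, 5, 5))
grid = collate_filters(fake_weights)
print(192 * 64)    # 12288 individual filter channels
print(grid.shape)  # (960, 320)
```

The resulting array can then be drawn as a single greyscale image, for example with matplotlib's imshow and cmap="gray".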

#plotting single channel images
plot_weights(alexnet, 0, single_channel = True, collated = True)
Filters from the first convolution layer in AlexNet — Collated Values
#plotting single channel images - second convolution layer
plot_weights(alexnet, 3, single_channel = True, collated = True)
Filters from the second convolution layer in AlexNet — Collated Values
#plotting single channel images - third convolution layer
plot_weights(alexnet, 6, single_channel = True, collated = True)
Filters from the third convolution layer in AlexNet — Collated Values

As you can see there are some interpretable features like edges, angles, and boundaries in the images from the first convolution layer. But as we go deeper into the network it becomes harder to interpret the filters.