Interpretability in Deep Learning with W&B — CAM and GradCAM

Original article was published on Deep Learning on Medium


This report will review how Grad-CAM counters the common criticism that neural networks are not interpretable.

View interactive report here. All the code is available here.

Training a classification model is interesting, but have you ever wondered how your model makes its predictions? Is your model actually looking at the dog in the image before classifying it as a dog with 98% confidence? Interesting, isn’t it? In today’s report, we will explore why deep learning models need to be interpretable, and some interesting methods to peek under the hood of a deep learning model. Deep learning interpretability is a very exciting area of research, and much progress is already being made in this direction.

So why should you care about interpretability? After all, the success of your business or your project is often judged primarily by your model’s accuracy. But in order to deploy our models in the real world, we need to consider other factors too. For instance, is the model racially biased? Or what if it classifies humans with 97% accuracy overall, but achieves 99% accuracy on men and only 95% on women?

Understanding how a model makes its predictions can also help us debug our network. [Check out this blog post on ‘Debugging Neural Networks with PyTorch and W&B Using Gradients and Visualizations’ for some other techniques that can help].

At this point, we are all familiar with the idea that deep learning models make predictions based on learned representations, which are expressed in terms of other, simpler representations. That is, deep learning allows us to build complex concepts out of simpler concepts. Here’s an amazing Distill Pub post to help you understand this concept better. We also know that these representations are learned while we train the model on our input data and labels, in the case of a supervised learning task like image classification. One of the criticisms of this approach is that the learned features in a neural network are not interpretable.

Today we’ll look at two techniques that address this criticism and shed light on the “black-box” nature of how neural networks learn.

  • Class Activation Map (CAM)
  • Gradient-weighted Class Activation Map (Grad-CAM)

Class Activation Maps

It has been observed that the convolutional units of various layers of a convolutional neural network act as object detectors, even though no prior about the location of the object is provided while training the network for a classification task. Even though convolutions have this remarkable property, it is lost when we use fully connected layers for classification. To avoid fully connected layers, some architectures, like Network in Network (NiN) and GoogLeNet, are fully convolutional neural networks.

Global Average Pooling (GAP) is a very commonly used layer in such architectures. It is mainly used as a regularizer to prevent overfitting during training. The authors of Learning Deep Features for Discriminative Localization found that by tweaking such an architecture, they could extend the advantages of GAP and retain the network’s localization ability until the last layer. Let’s quickly go through the procedure for generating a CAM using GAP.

The class activation map simply indicates the discriminative region in the image that the CNN uses to assign that image to a particular category. For this technique, the network consists of convolutional layers, and just before the softmax layer (for multi-class classification), global average pooling is performed on the convolutional feature maps. The output of this layer is used as the features for a fully connected layer that produces the desired classification output. Given this simple connectivity structure, we can identify the importance of image regions by projecting the weights of the output layer back onto the convolutional feature maps.

Figure 1: The network architecture ideal for CAM (Source)
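To make “projecting back the weights” concrete: with GAP followed by a single dense layer, the class score is a weighted sum of spatially averaged feature maps, so it equals the spatial average of the class activation map M_c(x, y) = Σ_k w_k^c f_k(x, y). Here is a minimal NumPy sketch that checks this identity on random arrays. The shapes (7×7 maps, 512 channels, 2 classes) and the names feature_maps, weights, class_score and class_activation_map are illustrative assumptions, and the dense layer’s bias is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical last-conv-layer output: height x width x channels (K = 512).
feature_maps = rng.standard_normal((7, 7, 512))
# Hypothetical output-layer weights: one column per class (2 classes here).
weights = rng.standard_normal((512, 2))

def class_score(feature_maps, weights, c):
    """GAP over each channel, then the dense layer (bias omitted)."""
    pooled = feature_maps.mean(axis=(0, 1))   # shape (K,)
    return pooled @ weights[:, c]

def class_activation_map(feature_maps, weights, c):
    """M_c(x, y) = sum_k w_k^c * f_k(x, y), computed at every location."""
    return feature_maps @ weights[:, c]        # shape (7, 7)

for c in range(weights.shape[1]):
    cam = class_activation_map(feature_maps, weights, c)
    # The class score is exactly the spatial average of the CAM.
    print(c, np.isclose(class_score(feature_maps, weights, c), cam.mean()))
```

Because the score is just the average of the map, the locations where M_c is large are the ones that push the score for class c up the most.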

Let’s try to implement this. 😄

Step 1: Modify Your Model

Suppose you have built your deep classifier with Conv blocks and a few fully connected layers. We will have to modify this architecture so that there are no fully connected layers other than the output layer. We will place a GlobalAveragePooling2D layer between the last convolutional block and the output layer (softmax/sigmoid).

The CAMmodel function applies the required modification to our cat and dog classifier. Here I am using a pre-trained VGG16 model to simulate my already trained cat-dog classifier.

from tensorflow import keras
from tensorflow.keras.applications import VGG16

def CAMmodel():
    ## Simulating my pretrained dog and cat classifier.
    vgg = VGG16(include_top=False, weights='imagenet')
    vgg.trainable = False  # freeze the convolutional blocks
    ## Flatten the layers so that they are not nested in the Sequential model.
    vgg_flat = flatten_model(vgg)
    ## Insert GAP between the last conv block and the output layer
    vgg_flat.append(keras.layers.GlobalAveragePooling2D())
    vgg_flat.append(keras.layers.Dense(1, activation='sigmoid'))

    model = keras.models.Sequential(vgg_flat)
    return model

A simple utility, flatten_model, returns the list of layers in my pre-trained model. This is done so that the layers are not nested when modified using a Sequential model, and so that the last convolutional layer can be accessed and used as an output. I appended GlobalAveragePooling2D and Dense layers to the list returned by flatten_model. Finally, the Sequential model is returned.

def flatten_model(model_nested):
    '''
    Utility to flatten pretrained model
    '''
    layers_flat = []
    for layer in model_nested.layers:
        try:
            # A nested model exposes its own .layers; unpack them
            layers_flat.extend(layer.layers)
        except AttributeError:
            # A plain layer has no .layers attribute; keep it as is
            layers_flat.append(layer)

    return layers_flat

Next we call model.build() with the appropriate model input shape.

keras.backend.clear_session()
model = CAMmodel()
model.build((None, None, None, 3))  # (batch, height, width, channels); image size left unspecified
model.summary()

Step 2: Retrain your model with CAMLogger callback

Since new layers were introduced, we have to retrain the model. But we don’t need to retrain the entire model: we can freeze the convolutional blocks by setting vgg.trainable = False.

Figure 2: Metrics plot after retraining model.

Observations:

  • There is a decline in model performance in terms of both training and validation accuracy. The best training and validation accuracies that I achieved were 99.01% and 95.67%, respectively.
  • Thus, implementing CAM requires modifying our architecture, which comes at the cost of some model performance.

Step 3: Use CAMLogger to see Class Activation Map

In the __init__ of the CAM class, we initialize cammodel. Notice there are two outputs from this cammodel:

  • Output from the last convolutional layer (block5_conv3 here)
  • The model prediction (softmax/sigmoid).
import numpy as np
import cv2
from tensorflow import keras

class CAM:
    def __init__(self, model, layerName):
        self.model = model
        self.layerName = layerName
        ## Prepare cammodel: returns the chosen conv layer output and the prediction
        last_conv_layer = self.model.get_layer(self.layerName).output
        self.cammodel = keras.models.Model(inputs=self.model.input,
                                           outputs=[last_conv_layer, self.model.output])

    def compute_heatmap(self, image, classIdx):
        ## Get the output of last conv layer and model prediction
        [conv_outputs, predictions] = self.cammodel.predict(image)
        conv_outputs = conv_outputs[0, :, :, :]
        conv_outputs = np.rollaxis(conv_outputs, 2)  # (H, W, K) -> (K, H, W)
        ## Get the weights between GAP and the output unit for classIdx, shape (K,)
        class_weights = self.model.layers[-1].get_weights()[0][:, classIdx]
        ## Create the class activation map: weighted sum over channels.
        caml = np.zeros(shape=conv_outputs.shape[1:3], dtype=np.float32)
        for i, w in enumerate(class_weights):
            caml += w * conv_outputs[i, :, :]
        caml /= np.max(caml)
        ## cv2.resize expects (width, height); image has shape (1, H, W, 3)
        caml = cv2.resize(caml, (image.shape[2], image.shape[1]))
        ## Prepare heat map; clip first so negative values don't wrap around in uint8
        heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(caml, 0, 1)), cv2.COLORMAP_JET)
        heatmap[np.where(caml < 0.2)] = 0
        return heatmap

    def overlay_heatmap(self, heatmap, image):
        ## Assumes image values are in [0, 1]; scale the uint8 heatmap to [0, 1] too
        img = (heatmap / 255.0) * 0.5 + image
        img = np.clip(img, 0, 1) * 255
        img = img.astype('uint8')
        return (heatmap, img)

The compute_heatmap method is responsible for generating the heatmap, which highlights the discriminative region the CNN uses to identify the category (class) of the image.

  • cammodel.predict() on the input image gives the feature map of the chosen convolutional layer. For a 224×224 input, this has shape (1, 14, 14, 512) for block5_conv3, or (1, 7, 7, 512) for block5_pool, the layer that GAP is applied to.
  • We also extract the weights of the output layer, of shape (512, 1).
  • Finally, the feature maps are weighted by the extracted weights and summed over the 512 channels (a dot product at each spatial location) to produce the class activation map.
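The per-channel loop in compute_heatmap is a single contraction over the channel axis. The NumPy sketch below, using random arrays in place of real activations and weights, checks that the loop and np.tensordot agree, then normalizes the map and zeroes values below 0.2, as compute_heatmap does. The helper names are hypothetical, and one choice differs from the code above: normalize_and_threshold clips negative values to zero before dividing by the maximum, so the result always lies in [0, 1].

```python
import numpy as np

def cam_loop(conv_outputs, class_weights):
    """Per-channel loop, mirroring compute_heatmap (channels-first input)."""
    caml = np.zeros(conv_outputs.shape[1:3], dtype=np.float32)
    for i, w in enumerate(class_weights):
        caml += w * conv_outputs[i, :, :]
    return caml

def cam_vectorized(conv_outputs, class_weights):
    """Same weighted sum, as one contraction over the channel axis."""
    return np.tensordot(class_weights, conv_outputs, axes=([0], [0]))

def normalize_and_threshold(caml, threshold=0.2):
    """Clip negatives, scale to a maximum of about 1, zero out weak activations."""
    caml = np.maximum(caml, 0.0)
    caml = caml / (caml.max() + 1e-8)   # epsilon guards against an all-zero map
    caml[caml < threshold] = 0.0
    return caml

rng = np.random.default_rng(42)
conv_outputs = rng.standard_normal((512, 7, 7))   # (K, H, W), as after np.rollaxis
class_weights = rng.standard_normal(512)          # one output unit's weights
heat = normalize_and_threshold(cam_vectorized(conv_outputs, class_weights))
print(np.allclose(cam_loop(conv_outputs, class_weights),
                  cam_vectorized(conv_outputs, class_weights), atol=1e-3))
print(heat.shape)
```

The vectorized form avoids a Python-level loop over 512 channels, which matters when the callback runs on many validation images every epoch.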

Now we wrap everything in a callback. The CAMLogger callback uses the wandb.log() method to log the generated activation maps to the W&B run page. The heatmap returned by compute_heatmap() is finally overlaid on the original image by calling the overlay_heatmap() method.
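The CAMLogger code is not reproduced here. As a hedged illustration of the overlay step only, here is a NumPy alpha blend that assumes the image is a float array in [0, 1] and the heatmap is a uint8 color map of the same H×W×3 shape and channel order (OpenCV’s applyColorMap returns BGR, so a conversion may be needed before logging). The function name overlay is hypothetical. Blending with weights that sum to 1 keeps every pixel in range, so the final cast to uint8 cannot wrap around.

```python
import numpy as np

def overlay(heatmap, image, alpha=0.5):
    """Alpha-blend a uint8 color heatmap onto an image with values in [0, 1].

    Assumes both arrays are H x W x 3 in the same channel order.
    Returns a uint8 image.
    """
    heat = heatmap.astype(np.float32) / 255.0          # bring heatmap to [0, 1]
    blended = alpha * heat + (1.0 - alpha) * image     # convex combination stays in [0, 1]
    return (np.clip(blended, 0.0, 1.0) * 255).astype(np.uint8)

image = np.full((4, 4, 3), 0.5, dtype=np.float32)     # a flat grey test image
heatmap = np.zeros((4, 4, 3), dtype=np.uint8)
heatmap[0, 0] = [255, 0, 0]                           # one "hot" pixel
out = overlay(heatmap, image)
print(out[0, 0], out[1, 1])
```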

Step 4: Draw conclusions from the CAM

We can draw a lot of conclusions from the plots shown below. 👇 Note that the examples chart contains validation images along with their prediction scores. If the prediction score is greater than 0.5, the network classifies the image as a dog; otherwise, as a cat. The CAM charts show the corresponding class activation maps. Let’s go through some observations:

  • The model is classifying the images as dogs by looking at the facial region in the image. For some images it’s able to look at the entire body, except the paws.
  • The model is classifying the images as cats by looking at the ears, paws and whiskers.
Figure 3: Looking at dog’s face and cat’s whiskers, paws and ears.
  • For a misclassified image, the model is not looking where it should be looking. Thus, by using CAM, we are able to interpret the reason behind this misclassification, which is really cool.
Figure 4: CNN was looking at something else.

Why is that? Even though the ears, paws and whiskers are present in the image, why did the model look at something else? One reason I can think of is that since we haven’t fine-tuned our pretrained VGG16 on our cat-dog dataset, the CNN used as a feature extractor is not entirely familiar with the patterns (distributions) appearing in our dataset.

  • When multiple instances of the same class are present in the image, the model looks at only one of them. That is okay, given that we are not concerned with object detection. Note that the confidence is low because of this.
Figure 5: Only looking at one occurrence.

Other use cases:

CAM can be used for a weakly supervised object localization task. The authors of the linked paper tested the ability of the CAM for a localization task on the ILSVRC 2014 benchmark dataset. The technique was able to achieve 37.1% top-5 error for object localization on this dataset, which is close to the 34.2% top-5 error achieved by a fully supervised CNN approach.
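The CAM paper turns a map into a box with a simple rule: keep the regions whose value is above 20% of the map’s maximum, then take the bounding box of the largest connected component. The sketch below is a simplified version under the assumption of a single object region: it boxes all above-threshold pixels rather than only the largest connected component, which would need a labelling routine such as scipy.ndimage.label. The function name cam_to_bbox and the toy map are hypothetical.

```python
import numpy as np

def cam_to_bbox(cam, frac=0.2):
    """Return (row_min, col_min, row_max, col_max) of pixels >= frac * max(cam).

    Simplification: covers every above-threshold pixel, not only the
    largest connected component used in the CAM paper.
    """
    mask = cam >= frac * cam.max()
    rows = np.where(mask.any(axis=1))[0]   # rows containing at least one kept pixel
    cols = np.where(mask.any(axis=0))[0]   # columns containing at least one kept pixel
    return int(rows[0]), int(cols[0]), int(rows[-1]), int(cols[-1])

cam = np.zeros((8, 8))
cam[2:5, 3:7] = 1.0   # a hypothetical "object" region
cam[0, 0] = 0.1       # weak activation, below the 20% threshold
print(cam_to_bbox(cam))  # (2, 3, 4, 6)
```

In practice the map would first be resized to the input resolution, so the box comes out in image pixel coordinates.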

Figure 6: Some more examples.

Gradient-Weighted Class Activation Maps

Even though CAM is amazing, it has some limitations:

  • The model needs to be modified in order to use CAM.
  • The modified model needs to be retrained, which is computationally expensive.
  • Since the fully connected Dense layers are removed, model performance is likely to suffer. This means the prediction score doesn’t give an accurate picture of the model’s ability.
  • Its use is bound by architectural constraints, i.e., it only applies to architectures that perform GAP over the convolutional maps immediately before the output layer.

What makes a good visual explanation?

  • Certainly, the technique should localize the class in the image. We saw this with CAM, and it worked remarkably well.
  • Finer details should be captured, i.e., the activation map should be high resolution.

Thus the authors of Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, a really amazing paper, came up with modifications to CAM and previous approaches. Their approach uses the gradients of any target prediction flowing into the final convolutional layer to produce a coarse localization map highlighting the important regions in the image for predicting the class of the image.
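In the Grad-CAM paper’s notation, each feature map A^k gets a weight α_k^c = (1/Z) Σ_i Σ_j ∂y^c/∂A^k_ij, the global average of the gradients of the class score y^c over the Z spatial positions, and the map is ReLU(Σ_k α_k^c A^k). Here is a minimal NumPy sketch of that final step, assuming the activations and gradients have already been computed (in TensorFlow they would come from a GradientTape). The function name grad_cam and the random placeholder arrays are hypothetical.

```python
import numpy as np

def grad_cam(activations, gradients):
    """Grad-CAM map from last-conv activations and d(score)/d(activations).

    Both arrays are H x W x K (channels last, batch dimension removed).
    """
    alphas = gradients.mean(axis=(0, 1))                         # alpha_k: GAP of the gradients, shape (K,)
    cam = np.tensordot(activations, alphas, axes=([2], [0]))     # sum_k alpha_k * A^k, shape (H, W)
    return np.maximum(cam, 0.0)                                  # ReLU keeps only positive evidence

rng = np.random.default_rng(0)
A = rng.standard_normal((14, 14, 512))       # placeholder activations
dY_dA = rng.standard_normal((14, 14, 512))   # placeholder gradients
heatmap = grad_cam(A, dY_dA)
print(heatmap.shape)
```

The ReLU is there because only features with a positive influence on the class of interest should appear in the map; negative values tend to belong to other classes.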

Thus Grad-CAM is a strict generalization of CAM. Besides overcoming the limitations of CAM, it is applicable to different deep learning tasks involving CNNs, including:

  • CNNs with fully-connected layers (e.g. VGG) without any modification to the network.
  • CNNs used for structured outputs like image captioning.
  • CNNs used in tasks with multi-modal inputs like visual Q&A or reinforcement learning, without architectural changes or re-training.
Figure 7: Grad-CAM overview (Source)
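The Grad-CAM paper shows that, on exactly the GAP architecture that CAM requires, Grad-CAM reduces to CAM up to a positive constant. A sketch of that argument in the paper’s notation, with A^k the k-th feature map, Z the number of spatial positions, w_k^c the output-layer weights, Y^c the pre-softmax class score (bias omitted) and M_c = Σ_k w_k^c A^k the CAM map:

```latex
\begin{aligned}
Y^c &= \sum_k w_k^c \,\frac{1}{Z}\sum_i\sum_j A^k_{ij}
  && \text{(GAP followed by a linear layer)}\\
\frac{\partial Y^c}{\partial A^k_{ij}} &= \frac{w_k^c}{Z}\\
\alpha_k^c &= \frac{1}{Z}\sum_i\sum_j \frac{\partial Y^c}{\partial A^k_{ij}} = \frac{w_k^c}{Z}\\
L^c_{\text{Grad-CAM}} &= \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
  = \frac{1}{Z}\,\mathrm{ReLU}\Big(\sum_k w_k^c A^k\Big)
  = \frac{1}{Z}\,\mathrm{ReLU}(M_c)
\end{aligned}
```

Since Z > 0, scaling by 1/Z does not change where the map is large, so both methods highlight the same regions; Grad-CAM additionally drops negative evidence through the ReLU.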

Let’s implement this 😄

Step 1: Your Deep Learning Task

We will focus on the image classification task. Unlike with CAM, we don’t have to modify or retrain our model for this task.

I have used a VGG16 model pretrained on ImageNet as my base model and I’m simulating Transfer Learning with this.

The layers of the base model are made non-trainable with vgg.trainable = False. Note that I have used fully connected layers in the model.

import tensorflow as tf
from tensorflow import keras

def catdogmodel():
    inp = keras.layers.Input(shape=(224, 224, 3))
    vgg = tf.keras.applications.VGG16(include_top=False, weights='imagenet', input_tensor=inp,
                                      input_shape=(224, 224, 3))
    vgg.trainable = False  # freeze the pretrained convolutional base

    x = vgg.get_layer('block5_pool').output
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(64, activation='relu')(x)  # fully connected layer
    output = keras.layers.Dense(1, activation='sigmoid')(x)
    model = tf.keras.models.Model(inputs=inp, outputs=output)

    return model

You will find the GradCAM class in the linked notebook. It is a modified version of the implementation from Grad-CAM: Visualize class activation maps with Keras, TensorFlow, and Deep Learning, an amazing blog post by Adrian Rosebrook of PyImageSearch.com. I highly suggest checking out the step-by-step implementation of the GradCAM class in that blog post.

I made two modifications to it:

  • While doing transfer learning, that is, if the layers up to your target (last) convolutional layer are non-trainable, tape.gradient(loss, convOutputs) can return None. This is because GradientTape automatically watches only trainable variables, so operations that do not involve a watched tensor are not recorded. To compute gradients with respect to that layer’s output, you need to make GradientTape watch a tensor explicitly with tape.watch(). Watching the input tensor is enough, since every operation downstream of a watched tensor is recorded, including the ones that produce convOutputs. Hence the change,
with tf.GradientTape() as tape:
    inputs = tf.cast(image, tf.float32)
    tape.watch(inputs)  # record every op, even through frozen layers
    (convOutputs, predictions) = self.gradModel(inputs)
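  • The original implementation didn’t account for binary classification, and its authors also talked about softmax-ing the output. So to support a simple cat and dog classifier with a single sigmoid output, I made a small modification inside the same GradientTape block. It checks the number of output units rather than len(predictions), which is the batch size and therefore always 1 for a single image. Hence the change,
if predictions.shape[-1] == 1:
    # Binary classification: a single sigmoid output unit
    loss = predictions[:, 0]
else:
    # Multi-class: score of the target class
    loss = predictions[:, classIdx]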

The GradCAM class can be used after the model is trained or as a callback. Here’s a small excerpt from his blog post.

The third point motivated me to work on this project. I built a custom callback around this GradCAM implementation and used wandb.log() to log the activation maps. By using this callback, you can use GradCAM while training.

Step 3: Use GRADCamLogger and train

Given we’re working with a simple dataset, I have only trained for a few epochs, and the model seems to work well.

Here’s the GradCAM custom callback.

import numpy as np
import cv2
import tensorflow as tf
import wandb
# GradCAM is the class from the linked notebook (not shown here).

class GRADCamLogger(tf.keras.callbacks.Callback):
    def __init__(self, validation_data, layer_name):
        super(GRADCamLogger, self).__init__()
        self.validation_data = validation_data
        self.layer_name = layer_name

    def on_epoch_end(self, epoch, logs=None):  # Keras calls this with (epoch, logs)
        images = []
        grad_cam = []
        ## Initialize GradCAM class with the model being trained
        cam = GradCAM(self.model, self.layer_name)
        for image in self.validation_data:
            image = np.expand_dims(image, 0)
            pred = self.model.predict(image)
            classIDx = np.argmax(pred[0])

            ## Compute Heatmap
            heatmap = cam.compute_heatmap(image, classIDx)

            image = image.reshape(image.shape[1:])
            image = image * 255
            image = image.astype(np.uint8)
            ## Overlay heatmap on original image; cv2.resize expects (width, height)
            heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
            (heatmap, output) = cam.overlay_heatmap(heatmap, image, alpha=0.5)
            images.append(image)
            grad_cam.append(output)

        ## Log both image lists in one call so they share the same step
        wandb.log({"images": [wandb.Image(img) for img in images],
                   "gradcam": [wandb.Image(g) for g in grad_cam]})
Figure 8: Metrics after training.

Step 4: Draw conclusions from the GradCAM

Since GradCAM is a strict generalization of CAM, it should be preferred over CAM. To understand the theoretical underpinnings of this technique, I recommend reading Demystifying Convolutional Neural Networks using GradCam by Divyanshu Mishra, or simply reading the linked paper. A couple of interesting conclusions we can draw include:

  • The model looks at the faces of the dogs to classify them correctly, while I am unsure what it looks at for the cat.
Figure 9: Looking at dog’s face.
  • The model is able to localize multiple instances of the class in an image, i.e., the prediction score accounts for multiple dogs and cats in the image.
Figure 10: Looking at multiple occurrences.
Figure 11: Some more examples.

Conclusion

Class Activation Maps and Grad-CAM are two approaches that introduce some explainability/interpretability into deep learning models, and they are quite widely used. What’s most fascinating about these techniques is their ability to perform object localization, even without training the model with a location prior. Grad-CAM, when used for image captioning, can help us understand which region in the image is used to generate a certain word. When used for a visual Q&A task, it can help us understand why the model came to a particular answer. Even though Grad-CAM is class-discriminative and localizes the relevant image regions, it cannot highlight fine-grained details the way pixel-space gradient visualization methods like Guided Backpropagation and Deconvolution do. Thus the Grad-CAM authors combined it with Guided Backpropagation.
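In the Grad-CAM paper, this combination (Guided Grad-CAM) upsamples the coarse Grad-CAM map to the input resolution and multiplies it element-wise with the Guided Backpropagation map. Here is a minimal NumPy sketch of that fusion step, assuming both maps have already been computed. The function name and input sizes are hypothetical, and the nearest-neighbour upsampling with np.kron is a simplification of the bilinear interpolation used in the paper.

```python
import numpy as np

def guided_grad_cam(grad_cam_map, guided_backprop):
    """Fuse a coarse Grad-CAM map (h x w) with a guided-backprop map (H x W x C).

    Assumes H and W are integer multiples of h and w.
    """
    H, W = guided_backprop.shape[:2]
    h, w = grad_cam_map.shape
    # Nearest-neighbour upsampling: each coarse cell becomes an (H//h) x (W//w) block.
    upsampled = np.kron(grad_cam_map, np.ones((H // h, W // w)))
    # Element-wise product, broadcast over the channel axis.
    return guided_backprop * upsampled[:, :, None]

rng = np.random.default_rng(0)
coarse = np.maximum(rng.standard_normal((7, 7)), 0.0)  # hypothetical Grad-CAM map
gb = rng.standard_normal((224, 224, 3))                # hypothetical guided-backprop map
fused = guided_grad_cam(coarse, gb)
print(fused.shape)
```

The product keeps the fine, pixel-level detail of Guided Backpropagation only where Grad-CAM says the class evidence is, which is what makes the result both high-resolution and class-discriminative.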

Thanks for reading this report until the end. I hope you find the callbacks introduced here helpful for your deep learning wizardry. Please feel free to reach out to me on Twitter (@ayushthakur0) with any feedback on this report. Thank you.