Unboxing Black Box Models: Explainable AI for Deep Learning Models



As the complexity of a machine learning model increases, its accuracy generally increases, but its interpretability decreases.

We need techniques that can provide some level of interpretability for black-box models as well. This is where Explainable AI comes into play.

Deep Learning models are the most difficult to interpret: the deeper and more complex the neural network, the harder it is to offer any kind of explainability. Several approaches have been suggested in various studies to understand feature importance for a given prediction of a Deep Learning model.

We will discuss a few of these methods and their drawbacks below.

Perturbation-Based Forward Propagation Approaches

These approaches make perturbations (changes) to individual inputs or neurons and observe the impact on downstream neurons and on the output. This is repeated for many different perturbations of the input, each time observing the change in the output; for example, a sliding window can occlude different parts of an image while we watch how the prediction changes. LIME also falls into this category: a linear model is built to approximate the local behaviour of the network using data gathered from input perturbations. A common drawback of perturbation-based approaches is computational cost, since every perturbation requires recomputing the output with a forward pass. Another, more subtle, issue is saturation.
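
As an illustration, the sketch below implements a simple sliding-window occlusion map. The image, window size, and predict_fn (a function returning the score of the class of interest) are placeholders for this illustration, not anything from the original article.

# A minimal sketch of occlusion-based perturbation, assuming a single grayscale
# image (H x W numpy array) and a predict_fn that scores the class of interest.
import numpy as np

def occlusion_map(predict_fn, image, window=4, stride=4, fill=0.0):
    """Slide an occluding patch over the image and record the drop in score."""
    base_score = predict_fn(image)
    h, w = image.shape
    heatmap = np.zeros((h, w))
    for top in range(0, h - window + 1, stride):
        for left in range(0, w - window + 1, stride):
            occluded = image.copy()
            occluded[top:top + window, left:left + window] = fill
            # One forward pass per perturbation -- the main cost of this method.
            heatmap[top:top + window, left:left + window] = base_score - predict_fn(occluded)
    return heatmap

# Toy usage: a "model" that simply scores the mean intensity of the image.
toy_image = np.random.rand(28, 28)
scores = occlusion_map(lambda img: img.mean(), toy_image)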

Saturation

Consider the output y such that y is a function of i₁ and i₂.

y = f(i₁, i₂)

When i₁+i₂<1, y increases in a linear fashion. When i₁ + i₂ is greater than 1, the output y saturates to 1.

Consider the point where i₁ = 1 and i₂ = 1. If we change i₁ from 1 to 0 at this point, the output y will still be 1. The perturbation method can therefore give us the misleading answer that changing the input does not affect the output. To avoid this saturation effect, we would have to perturb combinations of inputs, which further increases the computational cost.
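
A small numeric sketch of this effect, assuming f(i₁, i₂) = min(i₁ + i₂, 1), which matches the behaviour described above (linear below 1, saturated at 1):

# A numeric sketch of the saturation effect on an assumed saturating function.
def f(i1, i2):
    return min(i1 + i2, 1.0)

baseline = f(1, 1)          # 1.0
perturbed = f(0, 1)         # still 1.0 -- perturbing i1 alone shows no effect
both_perturbed = f(0, 0)    # 0.0 -- only a joint perturbation reveals the dependence

print(baseline - perturbed)       # 0.0: i1 looks unimportant
print(baseline - both_perturbed)  # 1.0: yet jointly the inputs clearly matter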

Backpropagation-Based Approaches

Unlike perturbation methods, backpropagation approaches propagate an importance signal from an output neuron backwards through the layers to the input in one pass, making them efficient.

Using gradients to interpret neural networks

The equation for a linear regression model can be represented as

y = w₁x₁ + w₂x₂ + … + wₙxₙ + b

where w₁, w₂, …, wₙ are the weights of the features of the model.

This weight can be represented as a partial derivative:

wᵢ = ∂y/∂xᵢ

In other words, the weight assigned to the i-th feature is the gradient of the model's prediction with respect to that feature: how the prediction changes as the feature changes, keeping all other features constant.
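
A quick sketch of this relationship, using a finite-difference approximation and arbitrary illustrative weights (not values from the article):

# For a linear model, the partial derivative of the prediction with respect to
# the i-th feature is exactly the weight w_i.
import numpy as np

w = np.array([0.5, -1.2, 3.0])   # hypothetical weights w1, w2, w3
b = 0.7                          # hypothetical bias

def predict(x):
    return w @ x + b

x = np.array([1.0, 2.0, -0.5])
eps = 1e-6
grad = np.array([(predict(x + eps * np.eye(3)[i]) - predict(x)) / eps for i in range(3)])
print(grad)   # ~[0.5, -1.2, 3.0] -- the gradient recovers the weights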

However, using the gradient of the output with respect to the input works well for a linear model but quickly falls apart for nonlinear models.

Let us consider a neural network consisting only of a ReLU activation, with a baseline input of x=2.

Now, let’s consider a second data point, at x = -2.

ReLU(x=2) = 2, and ReLU(x=-2) = 0.

There is only one input feature, x, in this model. At x = -2, the output y changes to 0, compared with y = 2 at the baseline. This change in the output has to be attributed to the change in x, since x is the only input feature, yet the gradient of ReLU(x) at the point x = -2 is 0. The gradient therefore attributes no importance to the only feature that caused the change.
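
A short sketch of this example, assuming TensorFlow for automatic differentiation (the article does not name a framework):

# Gradient of a single ReLU at x = -2: the output differs from the baseline,
# yet the gradient is zero.
import tensorflow as tf

x = tf.Variable(-2.0)
with tf.GradientTape() as tape:
    y = tf.nn.relu(x)
grad = tape.gradient(y, x)

print(y.numpy())     # 0.0 -- the output changed from 2.0 at the baseline x = 2
print(grad.numpy())  # 0.0 -- yet the gradient attributes no importance to x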

Deep Learning Important FeaTures (DeepLIFT)

DeepLIFT is a Deep Learning explanation method proposed in the paper Learning Important Features Through Propagating Activation Differences.

DeepLIFT recognizes that what we care about is not the gradient, which describes how y changes as x changes at the point x, but the slope, which describes how y changes as x differs from the baseline. In other words, we define a baseline (or "reference") input and measure how y changes from its baseline value as x changes from the baseline. The reference input represents some default or "neutral" input and is chosen according to what is appropriate for the problem at hand.

This slope can be represented as:

slope = (y − y₀) / (x − x₀)

where x₀ = 2 is the baseline input and y₀ = ReLU(x₀) = 2 is the baseline output. For x = −2 this gives (0 − 2) / (−2 − 2) = 0.5, and multiplying the slope by the change in the input, x − x₀ = −4, gives an attribution of 0.5 × (−4) = −2. This method tells us that the input feature x has an importance value of −2.
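
The same computation as a minimal code sketch; for a single ReLU this slope is exactly the multiplier used by DeepLIFT's Rescale rule:

# Difference-from-baseline slope ("multiplier") and contribution for one ReLU.
def relu(x):
    return max(x, 0.0)

x, x0 = -2.0, 2.0                              # input and baseline ("reference") input
multiplier = (relu(x) - relu(x0)) / (x - x0)   # (0 - 2) / (-2 - 2) = 0.5
contribution = multiplier * (x - x0)           # 0.5 * (-4) = -2.0
print(contribution)                            # -2.0, the importance assigned to x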

Let t represent the target output neuron, and let x₁, x₂, …, xₙ represent neurons in some intermediate layer (or set of layers) that are necessary and sufficient to compute t.

Let t₀ represent the reference activation of t. We define the quantity ∆t to be the difference-from-reference, that is

∆t = t−t₀

DeepLIFT assigns a contribution score C∆xᵢ∆t to each ∆xᵢ, where C∆xᵢ∆t can be thought of as the amount of difference-from-reference in t that is attributed to, or "blamed" on, the difference-from-reference of xᵢ. These contributions satisfy the summation-to-delta property, ∑ᵢ C∆xᵢ∆t = ∆t, so the attributions always add up to the total change in the output.

Deep SHAP (DeepLIFT + Shapley values)

Deep SHAP is an algorithm for SHAP values in deep learning models that builds on a connection with DeepLIFT described in Explaining Models by Propagating Shapley Values.

If we interpret the reference value as representing E[x], that is, what the model would predict if we knew none of the input features, then DeepLIFT approximates SHAP values under the assumptions that the input features are independent of one another and that the deep model is linear. Deep SHAP, however, uses a background distribution instead of a single baseline or reference: SHAP values are computed with respect to each baseline in the background distribution and then averaged.
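
A conceptual sketch of this averaging, shown on the single-ReLU toy model from earlier; the background values are arbitrary illustrative numbers, not part of the original article:

# Attributions are computed against each background sample (as DeepLIFT would
# against a single reference) and then averaged.
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def attribution(x, reference):
    """Difference-from-reference contribution for a single-input ReLU model."""
    if x == reference:
        return 0.0
    multiplier = (relu(x) - relu(reference)) / (x - reference)
    return multiplier * (x - reference)

x = -2.0
background = [2.0, 1.0, 0.5, 3.0]   # background distribution instead of one baseline
shap_estimate = np.mean([attribution(x, b) for b in background])
print(shap_estimate)                # the averaged attribution for x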

Implementation in Python

The data file train.csv contains gray-scale images of hand-drawn digits, from zero through nine.

Each image is 28 pixels in height and 28 pixels in width, for a total of 784 pixels. Each pixel has a single pixel value associated with it, indicating the lightness or darkness of that pixel, with higher numbers meaning darker. This pixel value is an integer between 0 and 255, inclusive.

The dataset has 785 columns. The first column, called “label”, is the digit that was drawn by the user. The rest of the columns contain the pixel-values of the associated image.

Our goal is to correctly identify and predict the digits from the handwritten images.
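
A minimal loading and pre-processing sketch, assuming the Kaggle-style train.csv layout described above (the file path and variable names are assumptions):

# Load the 785-column CSV, split labels from pixels, reshape and scale.
import numpy as np
import pandas as pd
from tensorflow.keras.utils import to_categorical

train = pd.read_csv("train.csv")
y = to_categorical(train["label"].values, num_classes=10)   # one-hot labels 0-9
X = train.drop(columns=["label"]).values.astype("float32")
X = X.reshape(-1, 28, 28, 1) / 255.0                        # 784 pixels -> 28x28x1, scaled to [0, 1]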

After performing the required pre-processing, we augment the images to improve model performance and increase the generalizability of the model. For example, we can obtain augmented data from the original images by applying simple random geometric transforms such as translations, rotations, changes in scale, shearing, and horizontal (and, in some cases, vertical) flips.
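
One possible augmentation setup using Keras' ImageDataGenerator, mirroring most of the transforms listed above; the exact ranges are illustrative assumptions, and flips are omitted here because mirrored digits change their meaning:

# Random geometric augmentations; batches are drawn later via datagen.flow().
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    width_shift_range=0.1,    # random horizontal translations
    height_shift_range=0.1,   # random vertical translations
    rotation_range=10,        # random rotations (degrees)
    zoom_range=0.1,           # random changes in scale
    shear_range=0.1,          # random shearing
)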

We have trained a CNN model on the image dataset.
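
A sketch of a small CNN of the kind described, assuming TensorFlow/Keras; the exact architecture and hyperparameters are illustrative assumptions (X, y, and datagen come from the earlier sketches):

# A small CNN trained on the augmented digit images.
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential([
    Conv2D(32, (3, 3), activation="relu", input_shape=(28, 28, 1)),
    Conv2D(32, (3, 3), activation="relu"),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation="relu"),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation="relu"),
    Dropout(0.3),
    Dense(10, activation="softmax"),     # class probabilities for digits 0-9
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(datagen.flow(X, y, batch_size=64), epochs=10)   # trained on the augmented images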

We use the SHAP library to run DeepExplainer on the trained model and analyze the SHAP values of the top 10 images that were wrongly predicted by our model.
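
A sketch of running SHAP's DeepExplainer on the trained model and selecting ten misclassified images; the background size and selection logic are illustrative assumptions (model, X, and y come from the earlier sketches):

# Explain ten misclassified digits with DeepExplainer.
import numpy as np
import shap

background = X[np.random.choice(X.shape[0], 100, replace=False)]   # background distribution
explainer = shap.DeepExplainer(model, background)

preds = np.argmax(model.predict(X), axis=1)
true_labels = np.argmax(y, axis=1)
wrong = np.where(preds != true_labels)[0][:10]       # first ten misclassified images

shap_values = explainer.shap_values(X[wrong])        # one array of SHAP values per class 0-9
shap.image_plot(shap_values, -X[wrong])              # red = positive, blue = negative SHAP values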

The plot above shows the explanations for each class label from 0–9. The percentage values are the predicted class probabilities; because we used a softmax activation function, we obtain predicted probabilities for all of the classes. Note that the explanations are ordered for the classes 0–9 going from left to right along the rows.

The blue colour is used to signify negative SHAP values whereas the red colour is for positive SHAP values.

Red pixels increase the probability of the class while blue pixels decrease it. For example, the image with true label 9 is predicted as 4 rather than 9 mainly because of the blue pixels for label 9 at the top, around the small connecting line. Similarly, the image with true label 5 is predicted as 6: the blue pixels for label 5 decrease the probability of that class.

References:

  1. A Unified Approach to Interpreting Model Predictions by Scott M. Lundberg and Su-In Lee
  2. Learning Important Features Through Propagating Activation Differences by Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje
  3. Explainable CNN using SHAP: https://www.kaggle.com/lomen0857/explaining-cnn-using-shap
  4. Interpretable Neural Networks: https://towardsdatascience.com/interpretable-neural-networks-45ac8aa91411