Justifying Image Classification: What pixels were used to decide?

In this article I’m going to show you a simple way to reason about the predictions made by an image classification neural network model. I have provided you with the code to recreate my findings, but you don’t need to read the code to understand the article (hopefully).

The idea you should take from this article is that you want some understanding of what a neural model “sees,” or “where it is looking,” so that you can trust its predictions with more confidence.

The basics

Let’s back up a second and review the basics at a super high level. A convolutional neural network model can be trained to categorize images based upon the stuff in the images, and we measure the performance of this image classifier with metrics like accuracy, precision, recall, and so forth. These models are so standard that you can basically download a pre-trained neural network model (e.g. Inception-v4, VGG-19, MobileNet, and so on). If the classes (things) you want to recognize are not in the list of stuff recognized by the pre-trained model, then you can usually retrain it to recognize the new stuff, reusing most of the weights and connections of the pre-trained model. That’s called transfer learning.
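If you have never seen it done, transfer learning is only a few lines of Keras. Here is a minimal sketch; the two-class output head and the choice to freeze every convolutional layer are illustrative assumptions, not a recipe taken from this article.

```python
from keras.applications.vgg16 import VGG16
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

# Start from VGG-16 trained on ImageNet, dropping its 1000-class output head.
base = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
for layer in base.layers:
    layer.trainable = False  # keep the pre-trained weights frozen (illustrative choice)

# Bolt a small new head onto the frozen feature extractor for our own classes.
x = GlobalAveragePooling2D()(base.output)
x = Dense(256, activation="relu")(x)
outputs = Dense(2, activation="softmax")(x)  # 2 classes is a placeholder (e.g. wolf vs. dog)

model = Model(inputs=base.input, outputs=outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
# model.fit(new_images, new_labels, ...) then trains only the new head.
```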

So, basically, we have a bunch of models out there that take in images and barf out labels. But why should I trust these models? Let’s have a look at what information one of these models is using to make predictions. This methodology can also be used to test out many other models including custom models.

Trust but verify

We got a nice mention in The Real-World AI Issue of The Verge, talking about how these image classification systems can go wrong if you don’t think about the details of what they are doing (full article here).

A neural network may learn to make predictions from the wrong part of the training data. This is the problem we are trying to eliminate by verifying the basis of the predictions from the model. If you want to build a classifier model to discriminate between wolves and dogs, you want to be sure that the model is not cheating by looking at the background for snow (wolves) and toys (dogs). We want to see what pixels are involved in the prediction. We don’t want dogs in snow to be seen as wolves. Think of this approach as Test-Driven Development (TDD) for Artificial Intelligence (AI). We first define a test (some picture from some obvious class) and then check the model against this picture to see if the model does what we want it to do, using the information we think is relevant. However, rather than writing a test case, we just use our eyes on this one, to see what’s going on.

My original version of the article used a picture of Demi Rose, because why not… But the new version uses a stock photo of a bikini model, because copyright and fair use laws are crushing the creative world. OK. Venting over, and back to the article: We’ll use deepSHAP to see how VGG-16 (which uses the ImageNet class labels) interprets an image.

SHAP builds an explanation for the model prediction using information about the dataset and the activations in the model itself (source: the shap GitHub repo)

A real example

Here is a link to the code for this article, which you can directly throw into Google Colab:

The first part of the code installs deepSHAP and runs part of one of the examples to show that the library is working correctly. Next, we look at the shape of the model input (224, 224, 3) and see that the VGG-16 model takes in square images of height and width 224, with 3 color channels (red, green, and blue).
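In code, that first check is tiny. A minimal sketch (the pip line is what Colab needs, and the printed tuple is where the (224, 224, 3) above comes from):

```python
# !pip install shap   # in Colab: install the shap library, which contains Deep SHAP

from keras.applications.vgg16 import VGG16

model = VGG16(weights="imagenet")  # pre-trained on ImageNet, with the 1000-class head
print(model.input_shape)           # (None, 224, 224, 3): batch, height, width, RGB channels
```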

We now grab the image of the bikini model and crop it to the correct size of 224 by 224 by 3.

The image on the left was cropped to the right shape/size, using the image on the right as a starting point.

Since cv2 decided to be evil and use BGR encoding for color channels, there is an encoding step to switch RGB color data into BGR color data.
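The preprocessing boils down to something like the sketch below. The filename is a placeholder, the center crop is just one simple way to get a square image, and the final cvtColor call is the channel-order flip mentioned above (here going from cv2’s BGR to RGB; flip the constant if your pipeline wants the opposite direction).

```python
import cv2

# "bikini_model.jpg" is a placeholder filename for the stock photo used above.
img = cv2.imread("bikini_model.jpg")  # cv2 reads pixels in BGR channel order

# Square center crop, then resize to the 224 x 224 input that VGG-16 expects.
h, w = img.shape[:2]
side = min(h, w)
top, left = (h - side) // 2, (w - side) // 2
img = cv2.resize(img[top:top + side, left:left + side], (224, 224))

# Flip the channel order so downstream RGB-expecting code sees the right colors.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
print(img.shape)  # (224, 224, 3)
```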

And now that the preprocessing is complete, let’s get to work. What classes (categories) does the model see in this picture? What pixels contribute to, or detract from, each of these classification decisions?
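For the curious, here is a minimal sketch of how a panel of explanations like the one below can be produced with shap’s DeepExplainer. The ten background images, the four ranked outputs, and the class-name lookup URL come from the examples in the shap repository, and the `img` variable is assumed to be the 224 x 224 x 3 RGB crop prepared above; the notebook linked earlier may differ in the details.

```python
import json
import numpy as np
import shap
from keras.applications.vgg16 import VGG16, preprocess_input

model = VGG16(weights="imagenet")

# A few of the ImageNet sample images that ship with shap act as the background set.
background, _ = shap.datasets.imagenet50()
background = preprocess_input(background[:10].copy())

# `img` is the 224 x 224 x 3 RGB crop prepared earlier.
to_explain = preprocess_input(img.astype("float32")[np.newaxis, ...])

# Deep SHAP: attribute each of the top output classes back to the input pixels.
explainer = shap.DeepExplainer(model, background)
shap_values, indexes = explainer.shap_values(to_explain, ranked_outputs=4)

# Turn the ranked ImageNet class indexes into readable labels.
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as f:
    class_names = json.load(f)
labels = np.vectorize(lambda i: class_names[str(i)][1])(indexes)

# One panel per ranked class; red pixels push toward the class, blue push away.
shap.image_plot(shap_values, img[np.newaxis] / 255.0, labels)
```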

From left to right: The input image, the strongest prediction (cowboy_hat), followed by the second strongest prediction (sombrero), and so on. The scale on the bottom of the image shows relative positive or negative strength (the SHAP value). In each prediction image, the contribution of each pixel to the prediction is shown.

Looking at the “cowboy_hat” prediction in the image above, the model clearly used pixels from the hat to make that decision. That’s good. It also used several pixels from the bathing suit, so let’s dig a bit deeper. Why not block out parts of the image with white, black, or random noise, and see how that changes the predictions?
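A small masking helper is all this experiment needs. A sketch along these lines (the function name and fill values are my own, and the masked copy is simply fed back through the same SHAP explanation code as before):

```python
import numpy as np

def mask_top_half(img, mode="white"):
    """Return a copy of an RGB uint8 image with its top half replaced.

    mode is "white", "black", or "noise", the three fills tried below.
    """
    out = img.copy()
    h = img.shape[0] // 2
    if mode == "white":
        out[:h] = 255
    elif mode == "black":
        out[:h] = 0
    elif mode == "noise":
        out[:h] = np.random.randint(0, 256, size=out[:h].shape, dtype=np.uint8)
    return out

# masked = mask_top_half(img, "noise")  # then re-run the SHAP explanation on `masked`
```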

Prediction explanation when the top half of the image is whited out.
Prediction explanation when the top half of the image is blacked out.
Prediction explanation when the top half of the image is filled with random noise pixel values.

Let’s think about the above results. First, it is good to see that the top class prediction (what the model sees) was not sensitive to the type of masking we used (white, black, or noise). This indicates that the model cares about the real, unmasked data. Furthermore, the contribution of the irrelevant masked area to the decision making was very small, with a few exceptions (Band_Aid and sandal). Also, the predicted class, maillot, is correct according to what we would expect as humans. “Maillot” is an interesting word. I had to look it up. It’s pronounced more like “my-yo” than “mail-lot,” and according to Google it’s “a woman’s one-piece swimsuit.” Good to know. Bikini also makes good sense here. The rest of the classes are related to the image, and so that’s a good sign too. Notice that sandals are in there as well. My guess is that beach attire reminds the model of images with sandals. Or maybe the spotted pattern on the bra looks like a sandal? Maybe sarong and miniskirt often accompany skin, and we see that the pixels that triggered those labels are the skin on her face and the shoulder area near the bra strap.

OK! So on we go to the top half of the image. Let’s do the same masking on the bottom half this time and see what happens:

Prediction explanation when the bottom half of the image is whited out.
Prediction explanation when the bottom half of the image is blacked out.
Prediction explanation when the bottom half of the image is filled with random noise pixel values.

We see that the top predicted class cowboy_hat makes sense. There is a hat in the image, and the pixels of the face (especially the eye) probably help the network to know that the hat is on a head.

Conclusion

In this article you followed along to see a simple way to reason about the predictions made by an image classification neural network model. This was just one application of the amazing shap library. We had a look at which pixels in an input image contribute to a neural network model making its predictions. We saw that masking (blocking out) part of the image changes the predictions, but the model’s top prediction did not depend on the type of masking we used. Critically, we observed that the locations of the pixels involved in the various predictions make sense. You are welcome to use this approach to verify and interrogate your own neural network models.

If you liked this article, then have a look at some of my most read past articles, like “How to Price an AI Project” and “How to Hire an AI Consultant.” And hey, join our newsletter!

Until next time!

-Daniel
Lemay.ai
daniel@lemay.ai