Explaining black box models: Ensemble and Deep Learning using LIME and SHAP

Source: Deep Learning on Medium

With data growing at an ever increasing pace, we use all kinds of complex ensemble and deep learning algorithms to achieve the highest possible accuracy. It sometimes feels magical how these models predict, classify, recognize, and track on unseen data. Accomplishing this magic and more has been, and will remain, the goal of intensive research and development in the data science community. But amid all this great work, a question arises: can we always trust these predictions, classifications, recognitions, and tracking results? A variety of factors, such as lack of data, imbalanced datasets, and biased datasets, can affect the decisions rendered by learning models. As a result, model explainability is gaining traction. Financial institutions and law agencies, for example, demand explanations and evidence (SR 11–7 and The FUTURE of AI Act) to support the output of these learning models.

I am going to demonstrate explainability on the decisions made by LightGBM and Keras models when classifying a transaction as fraudulent on the IEEE CIS dataset. I will use two state-of-the-art open source explainability techniques in this article, namely LIME (https://github.com/marcotcr/lime) and SHAP (https://github.com/slundberg/shap), from these research papers (1, 2). I have saved another great open source explainability technique, AIX360 (https://github.com/IBM/AIX360), for my next post, as it proposes 8 novel algorithms for explainability.

I have borrowed the awesome feature engineering techniques on this dataset from here. Additionally, I have used feature scaling to even out the variability in the magnitude of feature values.
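The article does not say which scaler was used, so the following is only a minimal sketch that assumes standardization (zero mean, unit variance per column). The function names and the toy values are hypothetical. The statistics are learned on the training split and reused on the test split, so no information leaks from the test data.

```python
import numpy as np

def fit_standardizer(train):
    """Learn per-column mean and standard deviation from the training split."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std[std == 0] = 1.0  # avoid dividing by zero for constant columns
    return mean, std

def standardize(data, mean, std):
    """Shift and rescale every column with the given statistics."""
    return (data - mean) / std

# Toy columns with very different magnitudes (hypothetical values).
train = np.array([[100.0, 0.1], [2500.0, 0.3], [40.0, 0.2]])
test = np.array([[880.0, 0.25]])

mean, std = fit_standardizer(train)
train_scaled = standardize(train, mean, std)
test_scaled = standardize(test, mean, std)  # reuse the training statistics
print(train_scaled.mean(axis=0), train_scaled.std(axis=0))
```

In practice, scikit-learn's StandardScaler performs the same computation with a fit/transform interface.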

LIME

Intuitively, an explanation is a local linear approximation of the model’s behaviour. While the model may be very complex globally, it is easier to approximate it in the vicinity of a particular instance. Treating the model as a black box, LIME perturbs the instance we want to explain and learns a sparse linear model around it as an explanation. The figure below illustrates the intuition for this procedure. The model’s decision function is represented by the blue/pink background and is clearly nonlinear. The bright red cross is the instance being explained (let’s call it X). We sample instances around X and weight them according to their proximity to X (weight is indicated by size). We then learn a linear model (dashed line) that approximates the model well in the vicinity of X, but not necessarily globally. For more information, read this paper, or take a look at this blog post (https://github.com/marcotcr/lime).

LIME Technique (https://github.com/marcotcr/lime)
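The procedure described above (perturb, weight by proximity, fit a linear model) can be sketched in a few lines of NumPy. This is an illustration of the idea only, not the lime library's implementation: the black-box function, the Gaussian sampling, the kernel width, and all names are assumptions made for the sketch, and it omits the sparsity constraint and feature discretization that LIME applies.

```python
import numpy as np

def black_box(points):
    """Hypothetical nonlinear model: returns a score in (0, 1) per row."""
    return 1.0 / (1.0 + np.exp(-(points[:, 0] ** 2 - points[:, 1])))

def local_linear_explanation(predict_fn, x, n_samples=2000, scale=0.5,
                             kernel_width=0.75, seed=0):
    """Fit a proximity-weighted linear surrogate around the instance x."""
    rng = np.random.default_rng(seed)
    # 1. Perturb: sample points around the instance being explained.
    samples = x + rng.normal(0.0, scale, size=(n_samples, x.shape[0]))
    # 2. Query the black box on the perturbed points.
    targets = predict_fn(samples)
    # 3. Weight each sample by its proximity to x (exponential kernel).
    distances = np.linalg.norm(samples - x, axis=1)
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)
    # 4. Weighted least squares: scale rows by sqrt(weight), then solve.
    design = np.hstack([np.ones((n_samples, 1)), samples])
    root_w = np.sqrt(weights)[:, None]
    coef, *_ = np.linalg.lstsq(design * root_w, targets * root_w[:, 0], rcond=None)
    return coef[0], coef[1:]  # intercept, per-feature local slopes

x = np.array([1.0, 0.5])
intercept, slopes = local_linear_explanation(black_box, x)
print("local slopes:", slopes)
```

The signs of the slopes form the explanation: around this X, increasing the first feature pushes the score up and increasing the second pushes it down, even though the model is not linear globally.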

Ensemble model — LightGBM

Below is my model configuration. The model achieved an AUC score of 0.972832.

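import lightgbm
from sklearn.model_selection import train_test_split

# Create training and validation sets
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=42)

# Wrap the splits in LightGBM datasets
train_data = lightgbm.Dataset(x_train, label=y_train)
test_data = lightgbm.Dataset(x_test, label=y_test)

# Train the model
parameters = {
    'application': 'binary',  # alias of 'objective'
    'objective': 'binary',
    'metric': 'auc',
    'is_unbalance': 'true',
    'boosting': 'gbdt',
    'num_leaves': 31,
    'feature_fraction': 0.5,
    'bagging_fraction': 0.5,
    'bagging_freq': 20,
    'learning_rate': 0.05,
    'verbose': 0
}
# The early-stopping callback replaces early_stopping_rounds=100,
# which lightgbm.train() no longer accepts from version 4.0 onward
model = lightgbm.train(parameters,
                       train_data,
                       valid_sets=test_data,
                       num_boost_round=5000,
                       callbacks=[lightgbm.early_stopping(100)])

# Predicted probability of the positive (fraud) class for each row
y_pred = model.predict(x_test)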
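The AUC quoted above can be read as the probability that a randomly chosen fraudulent transaction receives a higher score than a randomly chosen legitimate one. The sketch below computes it that way on toy labels and scores; the values and the function name are hypothetical and are not taken from the dataset.

```python
import numpy as np

def pairwise_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Compare every positive score with every negative score.
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(labels, scores))
```

This pairwise form needs memory proportional to positives times negatives, so it only suits small examples. On real data, sklearn.metrics.roc_auc_score(y_test, y_pred) is the practical choice.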

This is the classification report of the above model.

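from sklearn.metrics import classification_report
# Turn probabilities into class labels; the 0.5 cutoff is an assumption, as the article does not show how y_pred_bool was built
y_pred_bool = (y_pred > 0.5).astype(int)
print(classification_report(y_test, y_pred_bool))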

Before exploring the formal LIME and SHAP explainability techniques to explain the model’s classification results, I thought, why not use LightGBM’s built-in ‘feature importance’ function to visualize the 20 most important features that helped the model lean towards a classification?

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Pair each feature name with its importance from the trained model
feature_imp = pd.DataFrame({'Value': model.feature_importance(), 'Feature': x.columns})
plt.figure(figsize=(40, 20))
sns.set(font_scale=5)
# Plot the 20 most important features
sns.barplot(x="Value", y="Feature", data=feature_imp.sort_values(by="Value", ascending=False)[0:20])
plt.title('LightGBM Features (top 20)')
plt.tight_layout()
plt.savefig('lgbm_importances-01.png')
plt.show()