Explaining text classification predictions with LIME

A critical look at explaining text classifier predictions with LIME

Interpretability of Machine Learning models is a topic that has received a lot of interest in the last few years. The use cases are plentiful: from detecting confounders in medical predictions and biased algorithms used to select job applicants, to potential future regulation that requires companies to explain why their AI made a particular decision.

This post looks at a particular technique called LIME, which aims to make individual predictions interpretable. The algorithm does so by fitting a locally linear model to the predicted class probability in the neighbourhood of the point in feature space that we want explained. The simplicity of the linear model then makes it possible to explain, locally, how the prediction is affected by the features.

For more background, see the LIME paper. Here, LIME is applied to a simple, classical text classification problem, identical to the one the authors used in their paper to demonstrate their approach.

The goals of this post are:

  • To explain the inner workings of LIME
  • To show how to use it in Python on text, and lastly,
  • To highlight serious pitfalls/shortcomings of the algorithm

The latter are largely ignored in blog posts and in the paper itself, with the exception of the author of this great document, and of this article.

1. Brief explanation of how LIME works

The idea behind LIME is simple: build a local surrogate model in the neighbourhood of a prediction that strikes an optimal trade-off between interpretability (simplicity) and faithfulness (accuracy). This interpretable model can then tell us how the actual, original model behaves locally.

Sampling the feature space to create a local surrogate model. Image from the original LIME paper

The un-faithfulness (loss) is measured as the sum of squared differences between the predicted probabilities of the original and the surrogate model, where each sample is weighted by its proximity to the instance being explained.

In other words, it aims to find a simple, usually linear, model with as few terms as possible that is still representative of the original model, at least locally.
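For reference, the objective from the LIME paper can be written as below, in the paper's notation as I recall it: f is the original model, g a linear surrogate from the family G acting on the binary representation z' (word present or absent), z the corresponding perturbed sample, π_x a proximity kernel of width σ around the instance x with distance D (cosine distance for text), and Ω(g) a complexity penalty such as the number of non-zero coefficients.

```latex
\xi(x) = \operatorname*{arg\,min}_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g),
\qquad
\mathcal{L}(f, g, \pi_x) = \sum_{z, z' \in \mathcal{Z}} \pi_x(z) \, \bigl(f(z) - g(z')\bigr)^2,
\qquad
\pi_x(z) = \exp\!\left(-\frac{D(x, z)^2}{\sigma^2}\right)
```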

For text, new instances are generated by randomly sampling the words that are present in the observation; in other words, words are randomly left out of the input. The resulting nearby points are fed into the classifier, which yields a cloud of predicted probabilities. A linear model is then fitted to these, and its coefficients are used as the explanation.
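The procedure just described can be sketched from scratch in a few lines of numpy. This is a simplified illustration, not the lime library's code: `lime_text_sketch` and the toy classifier are hypothetical. I mirror some of lime's text defaults as I understand them from its source (removing between 1 and all distinct words per sample, cosine distance scaled by 100, kernel width 25, ridge penalty 1), but skip its feature-selection step, so every word gets a coefficient, and the classifier here returns a single probability column.

```python
import numpy as np

def lime_text_sketch(words, classifier_fn, num_samples=500, kernel_width=25.0,
                     ridge_alpha=1.0, seed=0):
    """LIME-style explanation of one document; returns {word: coefficient}.

    words: the document's tokens. classifier_fn: list of strings -> array of
    P(class of interest), shape (n,). A simplified sketch, not the lime library.
    """
    rng = np.random.default_rng(seed)
    vocab = sorted(set(words))
    col = {w: j for j, w in enumerate(vocab)}
    d = len(vocab)

    # Interpretable representation: Z[i, j] = 1 if word vocab[j] is kept in sample i.
    Z = np.ones((num_samples, d))
    for i in range(1, num_samples):              # row 0 stays the unperturbed document
        n_removed = rng.integers(1, d + 1)        # remove between 1 and d distinct words
        Z[i, rng.choice(d, size=n_removed, replace=False)] = 0

    # Rebuild each perturbed text, dropping every occurrence of a removed word.
    texts = [" ".join(w for w in words if Z[i, col[w]] == 1) for i in range(num_samples)]
    y = np.asarray(classifier_fn(texts), dtype=float)

    # Locality: cosine distance to the all-ones original (x100), then an exponential kernel.
    cos_sim = np.sqrt(Z.sum(axis=1) / d)
    dist = (1.0 - cos_sim) * 100.0
    weights = np.sqrt(np.exp(-dist ** 2 / kernel_width ** 2))

    # Weighted ridge regression with an unpenalized intercept, solved in closed form.
    X = np.hstack([np.ones((num_samples, 1)), Z])
    penalty = ridge_alpha * np.eye(d + 1)
    penalty[0, 0] = 0.0
    XtW = X.T * weights                           # equals X.T @ diag(weights)
    beta = np.linalg.solve(XtW @ X + penalty, XtW @ y)
    return dict(zip(vocab, beta[1:]))


# Hypothetical toy model: P(christian) is 0.9 if "god" occurs, else 0.2.
def toy_classifier(texts):
    return np.array([0.9 if "god" in t.split() else 0.2 for t in texts])

coef = lime_text_sketch(["god", "bless", "this", "post"], toy_classifier)
for word, c in sorted(coef.items(), key=lambda kv: -abs(kv[1])):
    print(f"{word:>6s} {c:+.3f}")
```

Because the toy model depends only on "god", that word should receive a large positive coefficient (slightly shrunk by the ridge penalty) and the other words coefficients close to zero.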

Although LIME promises to optimize the trade-off between interpretability/simplicity and faithfulness (in an elegant equation in the paper), the algorithm does not actually do this for us. The number of coefficients (simplicity) is chosen by the user beforehand, and a regularized linear regression is then fitted to the samples. Furthermore, as we will see, the sampling, the choice of the kernel width that defines locality, and the regularization of the linear model can all be problematic.
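To make the kernel-width issue concrete, the sketch below computes the weight a perturbed sample receives as a function of how many of the document's distinct words it keeps. It assumes lime's text defaults as I read them in the library source (cosine distance to the original times 100, kernel sqrt(exp(-d²/w²)) with a default width of 25; note the square root, which the paper's formula does not have); check these against your installed version. `lime_text_weights` is a hypothetical helper.

```python
import numpy as np

def lime_text_weights(n_kept, n_words, kernel_width=25.0):
    """Weight of a perturbed text that keeps n_kept of n_words distinct words."""
    n_kept = np.asarray(n_kept, dtype=float)
    # Cosine similarity between a 0/1 vector with n_kept ones and the all-ones original.
    cos_sim = np.sqrt(n_kept / n_words)
    dist = (1.0 - cos_sim) * 100.0
    return np.sqrt(np.exp(-dist ** 2 / kernel_width ** 2))

n_words = 100
kept = np.array([90, 50, 10])          # mild, moderate and heavy perturbation
for width in (5.0, 25.0, 100.0):
    print(width, np.round(lime_text_weights(kept, n_words, width), 3))
```

A narrow kernel effectively discards all but the mildest perturbations, while a wide one treats heavily truncated texts as almost as "local" as the original; the fitted coefficients can change accordingly.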

For images and tabular data, LIME works quite differently (especially the definition of locality and the sampling), but that is out of the scope of this post.

2. Practical examples: Using LIME for text classification

In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn import metrics
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.utils import Bunch  # sklearn.datasets.base was removed in newer scikit-learn versions
from sklearn import svm
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from lime.lime_text import LimeTextExplainer
from lime import lime_text

import matplotlib 
matplotlib.rc('xtick', labelsize=14) 
matplotlib.rc('ytick', labelsize=14) 
matplotlib.rc('axes', titlesize=16) 
matplotlib.rc('axes', labelsize=16) 
matplotlib.rc('figure', titlesize=20) 

The text data consists of about 1000 training posts to two Usenet newsgroups: alt.atheism and soc.religion.christian. The task is to predict which of the two groups a post comes from. What makes the data so interesting is that both newsgroups discuss similar themes using the same words, but from, obviously, rather different angles.

This should actually be a difficult task for a simplistic "bag-of-words" classifier. However, simple classifiers perform remarkably well on it, because this dataset is "famous" for containing many confounding features, for instance the e-mail domains in the headers. The classifier simply learns that a certain domain occurs only in posts of one of the two classes.

Therefore, the great out-of-the-box performance of simple text classifiers is not indicative of any real-world performance, since they learn to recognize particularities of this data set. Interpretability is thus essential to judge whether the model is any good.

scikit-learn comes with handy functionality for downloading the data. Two things to note here:

  • The remove argument makes it easy to keep or get rid of those features (headers, footers and quotes)
  • Removing them leaves several posts (nearly) empty, and these should be excluded

I will leave everything in, for demonstration purposes. Removing headers, footers and quotes (which might come from the "opposite" newsgroup) is generally advisable, however.

In [2]:
cats = ['alt.atheism', 'soc.religion.christian']
newsgroups_train = fetch_20newsgroups(subset='train', shuffle=True,
                                      categories=cats, remove=[])  # or remove=('headers', 'footers', 'quotes')
newsgroups_test = fetch_20newsgroups(subset='test', shuffle=True,
                                     categories=cats, remove=[])  # or remove=('headers', 'footers', 'quotes')

class_names = ['atheism', 'christianity']
In [3]:
# Delete the empty posts from the train and test sets
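def delete_empty_posts(data):
    """Remove posts with fewer than 3 tokens from data.data and data.target, in place."""
    # Build a mask first: deleting from a list while enumerating it skips elements.
    keep = np.array([len(doc.split()) >= 3 for doc in data.data], dtype=bool)
    numdeleted = int((~keep).sum())
    data.data = [doc for doc, k in zip(data.data, keep) if k]
    data.target = data.target[keep]
    print('Deleted {} posts'.format(numdeleted))

delete_empty_posts(newsgroups_train)
delete_empty_posts(newsgroups_test)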

Deleted 0 posts
Deleted 0 posts
In [4]:
['alt.atheism', 'soc.religion.christian']

We have about 1000 training samples, with a fairly balanced class distribution.

2. a) Simple model: Tf-idf with MultinomialNB

Since we are going to use bag-of-words classifiers, we need a vectorizer: an object that generates a vector for each text instance, indicating the presence/absence or counts of each word in the vocabulary. The tf-idf vectorizer additionally down-weights words that occur in many documents, so that rarer, more distinctive words get more weight.
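To make "more weight to rarer words" concrete, here is a numpy sketch of the weighting that TfidfVectorizer applies with its default settings (smooth_idf=True, sublinear_tf=False, norm='l2'), following the formula in the scikit-learn documentation. Tokenization, stop words and the min_df/max_df filtering are left out, and `tfidf_sketch` is a hypothetical helper.

```python
import numpy as np

def tfidf_sketch(counts):
    """Tf-idf with scikit-learn's default weighting; counts is (n_docs, n_terms)."""
    counts = np.asarray(counts, dtype=float)
    n_docs = counts.shape[0]
    df = (counts > 0).sum(axis=0)                   # document frequency of each term
    idf = np.log((1 + n_docs) / (1 + df)) + 1       # smoothed idf: rare terms weigh more
    tfidf = counts * idf
    norms = np.linalg.norm(tfidf, axis=1, keepdims=True)
    return tfidf / np.where(norms == 0, 1, norms), idf   # each row scaled to unit L2 norm

# Three toy documents over three terms; term 0 occurs in every document.
counts = [[1, 1, 0],
          [2, 0, 0],
          [1, 0, 3]]
X, idf = tfidf_sketch(counts)
print(np.round(idf, 3))
print(np.round(X, 3))
```

A term that appears in every document gets the minimum idf of 1, so it contributes only through its raw count, while terms confined to a single document are boosted.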

In [5]:
vectorizer = TfidfVectorizer(analyzer='word', token_pattern=r'\b[a-zA-Z]{3,}\b', lowercase=False,
                            min_df=5, max_df=0.7, stop_words='english')
# An alternative to play around with
vectorizer_small = TfidfVectorizer(analyzer='word', token_pattern=r'\b[a-zA-Z]{3,}\b', lowercase=True,
                            min_df=10, max_df=0.7, stop_words='english')

<1079x5089 sparse matrix of type '<class 'numpy.float64'>'
	with 107791 stored elements in Compressed Sparse Row format>

An important property of LIME is that it is "model-agnostic": besides the instance to be explained, it only needs a function that takes a list of raw texts and returns an array of predicted probabilities with one column per class, such as a classifier's .predict_proba() method. So we can use the whole family of scikit-learn classifiers, but also pipelines, which is a powerful thing.
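To illustrate the interface, here is a hypothetical hand-written "model" with the shape LimeTextExplainer expects from its classifier_fn: n raw strings in, an (n, 2) probability array out. Nothing about it is learned; the keyword list and scoring rule are made up purely for illustration. A fitted pipeline's predict_proba satisfies the same contract, because the vectorizer inside it accepts raw strings.

```python
import numpy as np

POSITIVE_WORDS = {"god", "jesus", "church"}   # hypothetical keyword list

def keyword_classifier_fn(texts):
    """Map a list of n strings to an (n, 2) array of [P(atheism), P(christianity)]."""
    probs = np.empty((len(texts), 2))
    for i, text in enumerate(texts):
        hits = sum(word in POSITIVE_WORDS for word in text.lower().split())
        p_christian = 1.0 - 0.5 ** (hits + 1)    # 0.5 with no hits, approaching 1 with more
        probs[i] = [1.0 - p_christian, p_christian]
    return probs

probs = keyword_classifier_fn(["Praise god", "no keywords here", ""])
print(probs)
```

Such a function could be passed to explainer.explain_instance(text, keyword_classifier_fn) exactly like p1.predict_proba.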

In [6]:
mnb = MultinomialNB(alpha=0.1)
p1 = make_pipeline(vectorizer, mnb)
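The alpha of MultinomialNB is an additive smoothing pseudocount. The sketch below computes the smoothed per-class word probabilities with the formula given in the scikit-learn MultinomialNB documentation, on made-up numbers, to show what alpha does for a word never seen in a class. With tf-idf input the "counts" are summed tf-idf values rather than integers; `smoothed_log_probs` is a hypothetical helper.

```python
import numpy as np

def smoothed_log_probs(feature_counts, alpha):
    """log theta[c, i] with theta[c, i] = (N[c, i] + alpha) / (N[c] + alpha * n_features)."""
    counts = np.asarray(feature_counts, dtype=float)
    n_features = counts.shape[1]
    theta = (counts + alpha) / (counts.sum(axis=1, keepdims=True) + alpha * n_features)
    return np.log(theta)

# Made-up summed feature values for two classes over four words;
# word 3 never occurs in class 0.
counts = np.array([[5.0, 3.0, 2.0, 0.0],
                   [1.0, 1.0, 4.0, 4.0]])
for alpha in (1e-3, 1e-2, 1.0):
    theta = np.exp(smoothed_log_probs(counts, alpha))
    print(alpha, theta[0, 3])   # probability class 0 assigns to a word it never saw
```

Without smoothing, a single unseen word would make a class impossible; a small alpha keeps that probability small but non-zero.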
In [7]:
# helper functions
def test_classifier_performance(clf):
    """Fit clf on the newsgroup train data and print its accuracy on the test data.

    clf can be a sklearn pipeline.
    """
    clf.fit(newsgroups_train.data, newsgroups_train.target)
    pred = clf.predict(newsgroups_test.data)
    print('Accuracy: {:.3f}'.format(metrics.accuracy_score(newsgroups_test.target, pred)))
In [8]:
# Some basic parameter tuning
alpha_grid = np.logspace(-3, 0, 4)  # additive smoothing parameter (pseudocount) of MultinomialNB
param_grid = [{'multinomialnb__alpha': alpha_grid }]
gs = GridSearchCV(p1, param_grid=param_grid, cv=5, return_train_score=True)

gs.fit(newsgroups_train.data, newsgroups_train.target)
In [9]:
plt.plot(gs.cv_results_['param_multinomialnb__alpha'].data, gs.cv_results_['mean_test_score'], label='test')
plt.plot(gs.cv_results_['param_multinomialnb__alpha'].data, gs.cv_results_['mean_train_score'], label='train')
plt.xscale('log')  # alpha spans several orders of magnitude
plt.xlabel('alpha')
plt.ylabel('accuracy')
plt.legend()
plt.title('Accuracy TF-IDF, Multinomial NB')

An impressively high accuracy of over 97% on the cross-validation test folds. Only minimal smoothing of the counts with pseudocounts is needed: an alpha of 1E-2 will do. Now, let's use LIME to interpret the results:

In [10]:
explainer = LimeTextExplainer(class_names=class_names)
p1.fit(newsgroups_train.data, newsgroups_train.target)
idx = 35  # index of the test document to explain
# LIME only needs the raw text and a function returning class probabilities
exp = explainer.explain_instance(newsgroups_test.data[idx], p1.predict_proba)
print('Document id: %d' % idx)
print('Probability(christian) = {:.3f}'.format(p1.predict_proba([newsgroups_test.data[idx]])[0, 1]))
print('True class: %s' % class_names[newsgroups_test.target[idx]])
print('R2 score: {:.3f}'.format(exp.score))
Document id: 35
Probability(christian) = 0.950
True class: christianity
R2 score: 0.692
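As far as I can tell from the lime source, the R2 score printed here is the R² of the local ridge model on the perturbed samples, computed with the kernel weights as sample weights (scikit-learn's Ridge.score with sample_weight). If so, 0.692 means the linear surrogate explains roughly 69% of the weighted variance of the black-box probabilities in the neighbourhood; it does not say the explanation is 69% "correct". The sketch computes that quantity; `weighted_r2` and all numbers in it are made up for illustration.

```python
import numpy as np

def weighted_r2(y_true, y_pred, weights):
    """Sample-weighted R^2 = 1 - SS_res / SS_tot, with the weighted mean of y_true as baseline."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    w = np.asarray(weights, dtype=float)
    y_mean = np.average(y_true, weights=w)
    ss_res = np.sum(w * (y_true - y_pred) ** 2)   # what the surrogate fails to explain
    ss_tot = np.sum(w * (y_true - y_mean) ** 2)   # spread of the black-box probabilities
    return 1.0 - ss_res / ss_tot

# Made-up black-box probabilities, surrogate predictions and kernel weights for four samples.
y_black_box = np.array([0.95, 0.40, 0.70, 0.20])
y_surrogate = np.array([0.90, 0.50, 0.65, 0.30])
kernel_w = np.array([1.0, 0.5, 0.8, 0.3])
print(weighted_r2(y_black_box, y_surrogate, kernel_w))
```

A perfect surrogate scores 1, and a surrogate that merely predicts the weighted mean scores 0.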