Machine learning models often achieve good predictive performance, but they are hard to interpret. The iml package provides tools for analysing any black box machine learning model:
Feature importance: Which were the most important features?
Feature effects: How does a feature influence the prediction? (Accumulated local effects, partial dependence plots and individual conditional expectation curves)
Explanations for single predictions: How did the feature values of a single data point affect its prediction? (LIME and Shapley values)
Surrogate trees: Can we approximate the underlying black box model with a short decision tree?
The iml package works for any classification and regression machine learning model: random forests, linear models, neural networks, xgboost, etc.
This document shows you how to use the iml package to analyse machine learning models.
To learn more about the technical details of the methods, see the corresponding chapters of the Interpretable Machine Learning book: https://christophm.github.io/interpretable-ml-book/agnostic.html
We’ll use the MASS::Boston dataset to demonstrate the capabilities of the iml package. This dataset contains median house values from Boston neighbourhoods.
#> crim zn indus chas nox rm age dis rad tax ptratio black lstat medv
#> 1 0.00632 18 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 24.0
#> 2 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 21.6
#> 3 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7
#> 4 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 33.4
#> 5 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33 36.2
#> 6 0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21 28.7
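The rows above are what `head()` prints for this dataset. The following is a minimal way to reproduce them. It assumes the MASS package is installed, which is normally the case because MASS ships with standard R installations as a recommended package:

```r
# Load the Boston housing data shipped with the MASS package
data("Boston", package = "MASS")

# First six rows; medv (median value of owner-occupied homes, in $1000s) is the target
head(Boston)
```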
First, we train a random forest with the randomForest package to predict the median house value, medv.
We create a Predictor object that holds the model and the data. The iml package uses R6 classes: new objects are created by calling Predictor$new().
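A Predictor exists so that every downstream method has a single, uniform way to get predictions for new data. The sketch below imitates that idea in base R; it is not iml's implementation. Two assumptions need flagging. First, `make_predictor()` is a hypothetical helper. Second, a linear model with a squared term and an `lstat:rm` interaction stands in for the random forest, so the example needs nothing beyond base R and MASS.

```r
data("Boston", package = "MASS")

# Features and target kept apart, as a Predictor object holds them
X <- Boston[setdiff(names(Boston), "medv")]
y <- Boston$medv

# Stand-in black box (assumption): a linear model with a squared term and an
# lstat:rm interaction, used instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)

# Hypothetical helper: bundle model, data and one uniform prediction function
make_predictor <- function(model, data, y) {
  list(
    model = model,
    data = data,
    y = y,
    # Always return a plain numeric vector, whatever predict() returns
    predict_fun = function(newdata) unname(predict(model, newdata = newdata))
  )
}

pred <- make_predictor(fit, X, y)
head(pred$predict_fun(pred$data))
```

Every method only calls `predict_fun()`, so swapping in a different model class requires changing a single line.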
We can measure how important each feature was for the predictions with FeatureImp. The feature importance measure works by shuffling each feature and measuring how much the performance drops. For this regression task, we measure the loss in performance with the mean absolute error (‘mae’); another choice would be the mean squared error (‘mse’).
Once we create a new FeatureImp object, the importance is computed automatically. We can call the plot() method of the object or look at the results in a data.frame.
#> feature importance.05 importance importance.95 permutation.error
#> 1 lstat 3.996140 4.135837 4.317803 4.788590
#> 2 rm 2.685506 2.715282 2.889989 3.143832
#> 3 ptratio 1.666459 1.728166 1.745227 2.000920
#> 4 dis 1.612628 1.667092 1.699770 1.930207
#> 5 nox 1.504569 1.576274 1.581755 1.825055
#> 6 crim 1.509768 1.545514 1.611800 1.789441
#> 7 indus 1.297713 1.337473 1.356928 1.548565
#> 8 age 1.279583 1.302343 1.346693 1.507890
#> 9 tax 1.220118 1.226805 1.256486 1.420430
#> 10 black 1.177398 1.209823 1.232926 1.400768
#> 11 rad 1.054973 1.065001 1.070641 1.233089
#> 12 zn 1.022191 1.034042 1.039069 1.197244
#> 13 chas 1.010353 1.014358 1.018517 1.174453
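In the table above, `permutation.error` divided by `importance` is about 1.158 in every row. That pattern fits `importance` being a ratio: the error after shuffling divided by the error on the original data. Under that reading, the constant is presumably the model's MAE on unshuffled data. The base-R sketch below computes this ratio with the MAE loss. It is not iml's code. The stand-in linear model replaces the random forest (assumption), `perm_importance()` is a hypothetical helper, and iml may aggregate the repetitions differently from the mean used here.

```r
set.seed(42)
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]
y <- Boston$medv

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

mae <- function(actual, predicted) mean(abs(actual - predicted))

# Hypothetical helper: error after shuffling a feature, relative to the original error
perm_importance <- function(predict_fun, X, y, loss = mae, n_rep = 5) {
  base_error <- loss(y, predict_fun(X))
  res <- lapply(names(X), function(feature) {
    errors <- replicate(n_rep, {
      X_perm <- X
      # Shuffling breaks the link between this feature and the target
      X_perm[[feature]] <- sample(X_perm[[feature]])
      loss(y, predict_fun(X_perm))
    })
    data.frame(feature = feature,
               importance = mean(errors) / base_error,
               permutation.error = mean(errors))
  })
  res <- do.call(rbind, res)
  res[order(res$importance, decreasing = TRUE), ]
}

imp <- perm_importance(predict_fun, X, y)
imp
```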
Besides knowing which features were important, we are interested in how the features influence the predicted outcome. The FeatureEffect class implements accumulated local effect plots, partial dependence plots and individual conditional expectation curves. The following plot shows the accumulated local effects (ALE) for the feature ‘lstat’. ALE shows how the prediction changes locally when the feature is varied. The marks on the x-axis indicate the distribution of the ‘lstat’ feature and show how relevant a region is for interpretation (few or no points mean that we should not over-interpret this region).
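A first-order ALE curve can be computed in five steps. First, split the feature into quantile bins. Second, move every observation to the lower and to the upper edge of its own bin. Third, average the resulting prediction differences within each bin. Fourth, accumulate those averages. Fifth, centre the result. The sketch below follows that recipe in base R, with a stand-in linear model instead of the random forest (assumption). `ale_1d()`, the 20 bins and the centring rule are illustrative choices, not iml's implementation.

```r
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Hypothetical helper: first-order ALE with quantile-based bins
ale_1d <- function(predict_fun, X, feature, n_bins = 20) {
  x <- X[[feature]]
  # Bin edges at observed quantiles, so each bin holds a similar share of data
  edges <- unique(quantile(x, probs = seq(0, 1, length.out = n_bins + 1),
                           type = 1, names = FALSE))
  K <- length(edges) - 1
  bin <- cut(x, breaks = edges, include.lowest = TRUE, labels = FALSE)
  # Move each point to the lower and upper edge of its own bin
  X_lo <- X
  X_hi <- X
  X_lo[[feature]] <- edges[bin]
  X_hi[[feature]] <- edges[bin + 1]
  local_diff <- predict_fun(X_hi) - predict_fun(X_lo)
  # Average the local differences per bin, then accumulate across bins
  local_eff <- as.numeric(tapply(local_diff, factor(bin, levels = seq_len(K)), mean))
  local_eff[is.na(local_eff)] <- 0
  ale <- c(0, cumsum(local_eff))
  # Centre so that the data-weighted mean effect is zero
  counts <- tabulate(bin, nbins = K)
  centre <- sum(counts * (head(ale, -1) + tail(ale, -1)) / 2) / sum(counts)
  data.frame(x = edges, ale = ale - centre)
}

ale_lstat <- ale_1d(predict_fun, X, "lstat")
plot(ale_lstat$x, ale_lstat$ale, type = "l", xlab = "lstat", ylab = "ALE")
rug(X$lstat)  # tick marks show where the data actually lie
```

Only prediction differences within a bin enter the curve. ALE therefore avoids evaluating the model at unrealistic feature combinations far from the data.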
If we want to compute the effects for another feature, we can simply reset the feature.
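FeatureEffect also implements partial dependence plots and ICE curves, and both are simpler to compute than ALE. An ICE curve for one observation sets the feature to each grid value in turn. The partial dependence curve is the average of all ICE curves. Below is a base-R sketch for `rm`. It assumes a stand-in linear model instead of the random forest, and `ice_curves()` is a hypothetical helper.

```r
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Hypothetical helper: ICE curves, one row per observation, one column per grid value
ice_curves <- function(predict_fun, X, feature, grid) {
  sapply(grid, function(v) {
    X_v <- X
    X_v[[feature]] <- v  # set the feature to v for every observation
    predict_fun(X_v)
  })
}

grid <- seq(min(X$rm), max(X$rm), length.out = 30)
ice <- ice_curves(predict_fun, X, "rm", grid)
pd <- colMeans(ice)  # partial dependence = average of the ICE curves

# Grey: ICE curves of the first 50 observations; black: partial dependence
matplot(grid, t(ice[1:50, ]), type = "l", lty = 1, col = "grey",
        xlab = "rm", ylab = "predicted medv")
lines(grid, pd, lwd = 3)
```

The stand-in model contains an `lstat:rm` term, so the grey ICE curves have different slopes. Without that interaction they would all run parallel to the partial dependence curve.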
We can also measure how strongly features interact with each other. The interaction measure quantifies how much of the variance of \(f(x)\) is explained by interactions. The measure lies between 0 (no interaction) and 1 (100% of the variance of \(f(x)\) is due to interactions). For each feature, we measure how much it interacts with all other features.
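A measure of this kind is Friedman's H-statistic. It compares the centred prediction with the sum of two centred partial dependence functions: one for the feature and one for all other features. The base-R sketch below computes the squared statistic \(H_j^2\), which is the variance share described above. It assumes a stand-in linear model with an `lstat:rm` interaction instead of the random forest, and it uses a hypothetical `h2_total()` helper. Because the estimate needs a number of predictions quadratic in the number of rows, it runs on a 60-row subsample.

```r
set.seed(1)
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Hypothetical helper: share of the variance of f(x) left unexplained by
# PD_j(x_j) + PD_-j(x_-j), i.e. the share due to interactions of feature j
h2_total <- function(predict_fun, X, feature) {
  n <- nrow(X)
  f <- predict_fun(X)
  pd_j <- pd_rest <- numeric(n)
  for (i in seq_len(n)) {
    X_j <- X
    X_j[[feature]] <- X[[feature]][i]    # fix x_j at observation i
    pd_j[i] <- mean(predict_fun(X_j))
    X_rest <- X[rep(i, n), , drop = FALSE]
    X_rest[[feature]] <- X[[feature]]    # fix all other features at observation i
    pd_rest[i] <- mean(predict_fun(X_rest))
  }
  centre <- function(v) v - mean(v)
  f <- centre(f)
  pd_j <- centre(pd_j)
  pd_rest <- centre(pd_rest)
  sum((f - pd_j - pd_rest)^2) / sum(f^2)
}

X_sub <- X[sample(nrow(X), 60), ]
h2 <- sapply(c("lstat", "rm", "crim"), function(ft) h2_total(predict_fun, X_sub, ft))
h2
```

`crim` enters the stand-in model additively, so its value is zero up to rounding. `lstat` and `rm` share the interaction term and get positive values.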
We can also specify a feature and measure all of its 2-way interactions with the other features.
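For a pair of features, the analogous statistic compares the two-dimensional partial dependence with the sum of the two one-dimensional ones:
\(H_{jk}^2 = \sum_i [PD_{jk}(x_j^{(i)}, x_k^{(i)}) - PD_j(x_j^{(i)}) - PD_k(x_k^{(i)})]^2 / \sum_i PD_{jk}(x_j^{(i)}, x_k^{(i)})^2\), with every partial dependence function centred.
Below is a base-R sketch for `lstat` against each other feature, run on a 50-row subsample. It assumes a stand-in linear model with an `lstat:rm` interaction instead of the random forest, and `h2_pair()` is a hypothetical helper.

```r
set.seed(1)
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Hypothetical helper: squared pairwise H-statistic for features j and k
h2_pair <- function(predict_fun, X, j, k) {
  n <- nrow(X)
  # Partial dependence with the given features fixed at observation i
  pd_fix <- function(features, i) {
    X_i <- X
    for (ft in features) X_i[[ft]] <- X[[ft]][i]
    mean(predict_fun(X_i))
  }
  centre <- function(v) v - mean(v)
  pd_jk <- centre(sapply(seq_len(n), function(i) pd_fix(c(j, k), i)))
  pd_j <- centre(sapply(seq_len(n), function(i) pd_fix(j, i)))
  pd_k <- centre(sapply(seq_len(n), function(i) pd_fix(k, i)))
  sum((pd_jk - pd_j - pd_k)^2) / sum(pd_jk^2)
}

X_sub <- X[sample(nrow(X), 50), ]
others <- setdiff(names(X), "lstat")
h2 <- sapply(others, function(k) h2_pair(predict_fun, X_sub, "lstat", k))
sort(h2, decreasing = TRUE)
```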
We can also plot the feature effects for all features at once.
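Plotting all effects at once comes down to computing one curve per feature. The base-R sketch below computes a partial dependence curve for every feature on a quantile grid and draws the curves in a single panel grid. A stand-in linear model replaces the random forest (assumption), and `feature_grid()` and `pd_all` are hypothetical names.

```r
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Grid at observed quantiles; unique() collapses repeats for discrete features such as chas
feature_grid <- function(x, n = 20) {
  unique(quantile(x, probs = seq(0, 1, length.out = n), type = 1, names = FALSE))
}

# One partial dependence curve per feature
pd_all <- lapply(names(X), function(ft) {
  grid <- feature_grid(X[[ft]])
  pd <- vapply(grid, function(v) {
    X_v <- X
    X_v[[ft]] <- v
    mean(predict_fun(X_v))
  }, numeric(1))
  data.frame(x = grid, pd = pd)
})
names(pd_all) <- names(X)

# 13 features fit into a 4 x 4 panel layout
op <- par(mfrow = c(4, 4), mar = c(4, 4, 1, 1))
for (ft in names(pd_all)) {
  plot(pd_all[[ft]]$x, pd_all[[ft]]$pd, type = "l", xlab = ft, ylab = "PD")
}
par(op)
```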
Another way to make the models more interpretable is to replace the black box with a simpler model: a decision tree. We take the predictions of the black box model (in our case the random forest) and train a decision tree on the original features and the predicted outcome. The plot shows the terminal nodes of the fitted tree. The maxdepth parameter controls how deep the tree can grow and therefore how interpretable it is.
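A global surrogate is trained on the black box's predictions rather than on the observed target. How faithfully it mimics the black box can be checked with \(R^2\) against those predictions. The sketch below uses `rpart` with `maxdepth = 2` as the tree learner. This is a swapped-in choice, not necessarily what iml uses internally, and a stand-in linear model plays the black box (assumption).

```r
library(rpart)  # recommended package shipped with standard R installations

data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Train the surrogate on the black-box predictions, not on the observed medv
surrogate_data <- X
surrogate_data$y_hat <- predict_fun(X)
tree <- rpart(y_hat ~ ., data = surrogate_data,
              control = rpart.control(maxdepth = 2))
print(tree)  # maxdepth = 2 allows at most four terminal nodes

# Fidelity: share of the black-box prediction variance the tree reproduces
tree_pred <- unname(predict(tree, newdata = X))
r2 <- 1 - sum((surrogate_data$y_hat - tree_pred)^2) /
  sum((surrogate_data$y_hat - mean(surrogate_data$y_hat))^2)
r2

# The surrogate predicts like any other model
head(tree_pred)
```

A low \(R^2\) means the tree's splits should not be read as an explanation of the black box. Increasing `maxdepth` raises fidelity at the cost of interpretability.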
We can use the tree to make predictions:
#> Warning in self$predictor$data$match_cols(data.frame(newdata)): Dropping
#> additional columns: medv
#> .y.hat
#> 1 28.78541
#> 2 21.91311
#> 3 28.78541
#> 4 28.78541
#> 5 28.78541
#> 6 28.78541
A global surrogate model can improve the understanding of the global model behaviour. We can also fit a model locally to understand an individual prediction better. The local model fitted by LocalModel is a linear regression model, and the data points are weighted by how close they are to the data point for which we want to explain the prediction.
#> beta x.recoded effect x.original feature feature.value
#> rm 4.5149417 6.575 29.685741 6.575 rm rm=6.575
#> ptratio -0.5696891 15.300 -8.716243 15.3 ptratio ptratio=15.3
#> lstat -0.4592951 4.980 -2.287290 4.98 lstat lstat=4.98
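In the table above, the `effect` column equals `beta` times the feature value: for `rm`, 4.5149417 × 6.575 ≈ 29.686. The sketch below produces the same kind of output with a weighted linear regression in base R. Several parts are assumptions:

- A stand-in linear model replaces the random forest.
- Distances are Gower distances, computed here for all-numeric data.
- The Gaussian-shaped kernel and its `width` are illustrative choices.
- `gower_dist()` and `local_surrogate()` are hypothetical helpers.
- The three features are fixed by hand to match the table rather than selected automatically.

```r
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Hypothetical helper: Gower distance for numeric data (mean range-scaled absolute difference)
gower_dist <- function(X, x_interest) {
  ranges <- vapply(X, function(col) diff(range(col)), numeric(1))
  abs_diff <- abs(sweep(as.matrix(X), 2, unlist(x_interest)))
  rowMeans(sweep(abs_diff, 2, ranges, "/"))
}

# Hypothetical helper: weighted linear fit of the black-box predictions near x_interest
local_surrogate <- function(predict_fun, X, x_interest, features, width = 0.2) {
  w <- exp(-(gower_dist(X, x_interest) / width)^2)  # closer points weigh more (assumed kernel)
  dat <- X[features]
  dat$y_hat <- predict_fun(X)
  lm_fit <- lm(y_hat ~ ., data = dat, weights = w)
  beta <- coef(lm_fit)[features]
  x_val <- unlist(x_interest[features])
  data.frame(beta = beta, x = x_val, effect = beta * x_val)
}

local_surrogate(predict_fun, X, X[1, ], c("rm", "ptratio", "lstat"))
```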
An alternative for explaining individual predictions is a method from coalitional game theory called the Shapley value. Assume that for one data point, the feature values play a game together, in which they get the prediction as a payout. The Shapley value tells us how to fairly distribute the payout among the feature values.
We can reuse the object to explain other data points:
The results in data.frame form can be extracted like this:
#> feature phi phi.var feature.value
#> 1 crim 0.02236619 1.91975348 crim=0.02731
#> 2 zn -0.01757000 0.17093081 zn=0
#> 3 indus -0.54826667 2.34293197 indus=7.07
#> 4 chas -0.02603333 0.01361733 chas=0
#> 5 nox 0.44874333 1.35829579 nox=0.469
#> 6 rm -1.54282444 11.60060925 rm=6.421
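Exact Shapley values require evaluating all coalitions of features, so they are usually approximated by sampling. The sketch below uses Monte Carlo permutation sampling. Each sample draws a random background row and a random feature order. Every feature that comes before the feature of interest in that order takes its value from \(x\). The sample then compares two predictions: one where the feature of interest also takes its value from \(x\), and one where it keeps its background value. The `phi.var` column above suggests that iml also works with sampled estimates, but this is not its code. Assumptions: a stand-in linear model replaces the random forest, `shapley_mc()` is a hypothetical helper, and `m = 100` samples per feature is an arbitrary choice.

```r
set.seed(42)
data("Boston", package = "MASS")
X <- Boston[setdiff(names(Boston), "medv")]

# Stand-in black box (assumption) instead of the random forest
fit <- lm(medv ~ . + I(lstat^2) + lstat:rm, data = Boston)
predict_fun <- function(newdata) unname(predict(fit, newdata = newdata))

# Hypothetical helper: Monte Carlo estimate of one feature's Shapley value
shapley_mc <- function(predict_fun, X, x_interest, feature, m = 100) {
  M <- as.matrix(X)
  x_star <- unlist(x_interest[1, names(X)])
  j <- match(feature, colnames(M))
  Z <- M[sample.int(nrow(M), m, replace = TRUE), , drop = FALSE]  # background rows
  with_j <- Z
  without_j <- Z
  for (s in seq_len(m)) {
    perm <- sample.int(ncol(M))                  # random feature order
    before <- perm[seq_len(match(j, perm) - 1)]  # features that join before j
    with_j[s, c(before, j)] <- x_star[c(before, j)]
    without_j[s, before] <- x_star[before]
  }
  # Average marginal contribution of feature j's value from x_interest
  mean(predict_fun(as.data.frame(with_j)) - predict_fun(as.data.frame(without_j)))
}

# Explain the second data point, in the same shape as the results table
x2 <- X[2, ]
phi <- vapply(names(X), function(ft) shapley_mc(predict_fun, X, x2, ft), numeric(1))
results <- data.frame(feature = names(X), phi = unname(phi),
                      feature.value = paste0(names(X), "=", unlist(x2)))
head(results)
```

Because of the sampling, the `phi` values sum only approximately to the prediction for the second data point minus the average prediction.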