Explain that tidymodels blackbox!

Let’s explain a {tidymodels} random forest with classic explainability methods (permutation importance, partial dependence plots (PDP), and Friedman’s H statistics), and also with fancy SHAP.

Disclaimer: {hstats}, {kernelshap} and {shapviz} are three of my own packages.

Diabetes data

We will use the diabetes prediction dataset of Kaggle to model diabetes (yes/no) as a function of six demographic features (age, gender, BMI, hypertension, heart disease, and smoking history). It has 100k rows.

Note: The data additionally contains the typical diabetes indicators HbA1c level and blood glucose level, but we won’t use them, both to avoid potential causality issues and to gain insights for people who do not know these values.

# https://www.kaggle.com/datasets/iammustafatz/diabetes-prediction-dataset

library(tidyverse)
library(tidymodels)
library(hstats)
library(kernelshap)
library(shapviz)
library(patchwork)

df0 <- read.csv("diabetes_prediction_dataset.csv")  # from above Kaggle link
dim(df0)  # 100000 9
head(df0)
# gender age hypertension heart_disease smoking_history   bmi HbA1c_level blood_glucose_level diabetes
# Female  80            0             1           never 25.19         6.6                 140        0
# Female  54            0             0         No Info 27.32         6.6                  80        0
#   Male  28            0             0           never 27.32         5.7                 158        0
# Female  36            0             0         current 23.45         5.0                 155        0
#   Male  76            1             1         current 20.14         4.8                 155        0
# Female  20            0             0           never 27.32         6.6                  85        0

summary(df0)
anyNA(df0)  # FALSE
table(df0$smoking_history, useNA = "ifany")

# DATA PREPARATION

# Note: tidymodels needs a factor response for classification
df1 <- df0 |>
  transform(
    y = factor(diabetes, levels = 0:1, labels = c("No", "Yes")),
    female = (gender == "Female") * 1,
    smoking_history = factor(
      smoking_history, 
      levels = c("No Info", "never", "former", "not current", "current", "ever")
    ),
    bmi = pmin(bmi, 50)
  )

# UNIVARIATE ANALYSIS

ggplot(df1, aes(diabetes)) +
  geom_bar(fill = "chartreuse4")

df1  |>  
  select(age, bmi, HbA1c_level, blood_glucose_level) |> 
  pivot_longer(everything()) |> 
  ggplot(aes(value)) +
  geom_histogram(fill = "chartreuse4", bins = 19) +
  facet_wrap(~ name, scales = "free_x")

ggplot(df1, aes(smoking_history)) +
  geom_bar(fill = "chartreuse4")

df1 |> 
  select(heart_disease, hypertension, female) |>
  pivot_longer(everything()) |> 
  ggplot(aes(name, value)) +
  stat_summary(fun = mean, geom = "bar", fill = "chartreuse4") +
  xlab(element_blank())

“Yes” proportion of binary variables (including the response)
Distribution of numeric variables
Distribution of smoking_history

Modeling

Let’s fit a random forest via tidymodels with {ranger} backend.

We add a predict function pf() that outputs only the probability of the “Yes” class.

set.seed(1)
ix <- initial_split(df1, strata = diabetes, prop = 0.8)
train <- training(ix)
test <- testing(ix)

xvars <- c("age", "bmi", "smoking_history", "heart_disease", "hypertension", "female")

rf_spec <- rand_forest(trees = 500) |> 
  set_mode("classification") |> 
  set_engine("ranger", num.threads = NULL, seed = 49)

rf_wf <- workflow() |> 
  add_model(rf_spec) |>
  add_formula(reformulate(xvars, "y"))

model <- rf_wf |> 
  fit(train)

# predict() gives No/Yes columns
predict(model, head(test), type = "prob")
# .pred_No .pred_Yes
#    0.981    0.0185

# We need to extract only the "Yes" probabilities
pf <- function(m, X) {
  predict(m, X, type = "prob")$.pred_Yes
}
pf(model, head(test))  # 0.01854290 ...

Classic explanation methods

# 4 times repeated permutation importance wrt test logloss
imp <- perm_importance(
  model, X = test, y = "diabetes", v = xvars, pred_fun = pf, loss = "logloss"
)
plot(imp) +
  xlab("Increase in test logloss")

# Partial dependence of age
partial_dep(model, v = "age", train, pred_fun = pf) |> 
  plot()

# All PDP in one patchwork
p <- lapply(xvars, function(x) plot(partial_dep(model, v = x, X = train, pred_fun = pf)))
wrap_plots(p) &
  ylim(0, 0.23) &
  ylab("Probability")

# Friedman's H stats
system.time( # 20 s
  H <- hstats(model, train[xvars], approx = TRUE, pred_fun = pf)
)
H  # 15% of prediction variability comes from interactions
plot(H)

# Stratified PDP of strongest interaction
partial_dep(model, "age", BY = "bmi", X = train, pred_fun = pf) |> 
  plot(show_points = FALSE)

Feature importance

Permutation importance measures by how much the average test loss (in our case log loss) increases when a feature is shuffled before calculating the losses. We repeat the process four times and also show standard errors.

Permutation importance: Age and BMI are the two main risk factors.
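
To make the idea concrete, here is a minimal base R sketch of permutation importance on simulated data. It is not how {hstats} implements perm_importance(); the helper names (pred_fun, logloss, perm_imp) and the toy logistic model are assumptions made for illustration, and the loss is evaluated on the training data only to keep the example short.

```r
# Sketch: permutation importance by hand (simulated data, base R only)
set.seed(1)
n <- 1000
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
# x1 has a strong effect, x2 a weak one, x3 none
y <- rbinom(n, 1, plogis(2 * X$x1 + 0.5 * X$x2))
fit <- glm(y ~ x1 + x2 + x3, data = cbind(X, y = y), family = binomial())

# Hypothetical prediction function returning P(y = 1), similar in spirit to pf()
pred_fun <- function(m, X) unname(predict(m, X, type = "response"))

# Average log loss for a 0/1 response
logloss <- function(y, p) -mean(y * log(p) + (1 - y) * log(1 - p))

# Average increase in loss after shuffling one feature, repeated m_rep times
perm_imp <- function(m, X, y, v, m_rep = 4) {
  base_loss <- logloss(y, pred_fun(m, X))
  sapply(v, function(z) {
    increases <- replicate(m_rep, {
      X_perm <- X
      X_perm[[z]] <- sample(X_perm[[z]])  # break the link between z and y
      logloss(y, pred_fun(m, X_perm)) - base_loss
    })
    mean(increases)
  })
}

imp <- perm_imp(fit, X, y, v = c("x1", "x2", "x3"))
imp  # x1 should dominate; x3 should be close to zero
```

Shuffling a feature that the model relies on makes the predictions worse, so the loss goes up. A feature without effect leaves the loss essentially unchanged.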

Main effects

Main effects are estimated by PDP. They show how the average prediction changes with a feature, keeping every other feature fixed. Using a fixed vertical axis helps to grasp the strength of the effect.

PDPs: The diabetes risk tends to increase with age, high (and very low) BMI, and presence of heart disease/hypertension, and it is a bit lower for females and non-smokers.
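
The calculation behind a PDP is short enough to write by hand. The following base R sketch uses simulated data and a linear model; the function name partial_dep_manual and the grid are assumptions for illustration and not part of {hstats}.

```r
# Sketch: partial dependence by hand (simulated data, base R only)
set.seed(1)
n <- 500
X <- data.frame(x1 = runif(n), x2 = runif(n))
y <- 2 * X$x1 + X$x1 * X$x2 + rnorm(n, sd = 0.1)
fit <- lm(y ~ x1 * x2, data = cbind(X, y = y))

# For each grid value: set feature v to that value in all rows,
# keep the other features as observed, and average the predictions
partial_dep_manual <- function(m, X, v, grid) {
  sapply(grid, function(g) {
    X_mod <- X
    X_mod[[v]] <- g
    mean(predict(m, X_mod))
  })
}

grid <- seq(0, 1, by = 0.25)
pd <- partial_dep_manual(fit, X, v = "x1", grid = grid)
data.frame(x1 = grid, pd = pd)
```

Each grid point requires predictions on the full dataset, which is why slow predict functions make PDPs (and everything built on them) expensive.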

Interaction strength

Interaction strength can be measured by Friedman’s H statistics, see the earlier blog post. A specific interaction can then be visualized by a stratified PDP.

Friedman’s H statistics: Left: BMI and age are the two features with clearly the strongest interactions. Right: Their pairwise interaction explains about 10% of their joint effect variability.
Stratified PDP: The strong interaction between age and BMI is clearly visible. A high BMI makes the age effect on diabetes stronger.
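
For intuition, the pairwise statistic of Friedman and Popescu compares the centered two-dimensional partial dependence with the sum of the two centered one-dimensional ones. Below is a brute-force base R sketch on simulated data. It follows the textbook definition rather than the (much faster) internals of hstats(); the helper names pd_at_rows and h2_pairwise are made up for this example.

```r
# Sketch: Friedman's pairwise H-squared by brute force (simulated data, base R only)
set.seed(1)
n <- 300
X <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
y <- X$x1 + X$x2 + 2 * X$x1 * X$x2 + X$x3 + rnorm(n, sd = 0.1)
fit <- lm(y ~ x1 * x2 + x3, data = cbind(X, y = y))

# Centered partial dependence of feature set v, evaluated at every observed row
pd_at_rows <- function(m, X, v) {
  out <- sapply(seq_len(nrow(X)), function(i) {
    X_mod <- X
    for (z in v) X_mod[[z]] <- X[[z]][i]  # fix v at the values of row i
    mean(predict(m, X_mod))
  })
  out - mean(out)
}

# Share of the joint effect variability of (a, b) that is not additive
h2_pairwise <- function(m, X, a, b) {
  pd_ab <- pd_at_rows(m, X, c(a, b))
  pd_a <- pd_at_rows(m, X, a)
  pd_b <- pd_at_rows(m, X, b)
  sum((pd_ab - pd_a - pd_b)^2) / sum(pd_ab^2)
}

h2_12 <- h2_pairwise(fit, X, "x1", "x2")  # the model contains an x1:x2 term
h2_13 <- h2_pairwise(fit, X, "x1", "x3")  # purely additive pair
c(h2_12 = h2_12, h2_13 = h2_13)
```

The additive pair yields a value of essentially zero, while the interacting pair yields a clearly positive value. A value of 0.1 would mean that 10% of the joint effect variability of the two features comes from their interaction.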

SHAP

What insights does a SHAP analysis bring?

We will crunch slow exact permutation SHAP values via kernelshap::permshap(). If we had more features, we could switch to

  • kernelshap::kernelshap(),
  • Brandon Greenwell’s {fastshap}, or
  • the {treeshap} package of my colleagues from TU Warsaw.

set.seed(1)
X_explain <- train[sample(1:nrow(train), 1000), xvars]
X_background <- train[sample(1:nrow(train), 200), ]

system.time(  # 10 minutes
  shap_values <- permshap(model, X = X_explain, bg_X = X_background, pred_fun = pf)
)
shap_values <- shapviz(shap_values)
shap_values  # 'shapviz' object representing 1000 x 6 SHAP matrix
saveRDS(shap_values, file = "shap_values.rds")
# shap_values <- readRDS("shap_values.rds")

sv_importance(shap_values, show_numbers = TRUE)
sv_importance(shap_values, kind = "bee")
sv_dependence(shap_values, v = xvars) &
  ylim(-0.14, 0.24) &
  ylab("Probability")

SHAP importance

SHAP importance: On average, age increases or decreases the diabetes probability by 4.7 percentage points, etc. In this case, the top three features are the same as in permutation importance.
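
To see where such numbers come from, here is a small base R sketch that computes exact SHAP values for a toy model by enumerating all feature subsets, and then averages their absolute values per feature. This only illustrates the definition and is not the algorithm inside permshap(); the names value and shap_one, the simulated data, and the linear model are assumptions for illustration.

```r
# Sketch: exact SHAP values and SHAP importance by hand (simulated data, base R only)
set.seed(1)
n <- 200
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
y <- X$x1 + 2 * X$x2 + X$x1 * X$x3 + rnorm(n, sd = 0.1)
fit <- lm(y ~ x1 * x3 + x2, data = cbind(X, y = y))
v <- c("x1", "x2", "x3")
bg <- X[1:50, ]  # background data

# Average prediction when the features in S are fixed at the values of x_row
# and the remaining features are taken from the background data
value <- function(m, x_row, bg, S) {
  X_mod <- bg
  for (z in S) X_mod[[z]] <- x_row[[z]]
  mean(predict(m, X_mod))
}

# Shapley formula: weighted marginal contributions over all subsets S without j
shap_one <- function(m, x_row, bg, v) {
  p <- length(v)
  sapply(v, function(j) {
    others <- setdiff(v, j)
    phi <- 0
    for (k in 0:(2^(p - 1) - 1)) {
      in_S <- (k %/% 2^(seq_along(others) - 1)) %% 2 == 1  # binary digits of k
      S <- others[in_S]
      w <- factorial(length(S)) * factorial(p - length(S) - 1) / factorial(p)
      phi <- phi + w * (value(m, x_row, bg, c(S, j)) - value(m, x_row, bg, S))
    }
    phi
  })
}

# SHAP values of one row add up to its prediction minus the average background prediction
x <- X[101, ]
phi <- shap_one(fit, x, bg, v)
sum(phi) - (predict(fit, x) - mean(predict(fit, bg)))  # numerically zero

# SHAP importance: mean absolute SHAP value per feature over a few rows
X_explain <- X[101:110, ]
shap_mat <- t(sapply(seq_len(nrow(X_explain)), function(i) shap_one(fit, X_explain[i, ], bg, v)))
colMeans(abs(shap_mat))
```

The number of subsets doubles with every additional feature, which is why exact calculations are only feasible for a handful of features.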

SHAP “summary” plot

SHAP “summary” plot: In addition to the information in the bar plot, we see that higher age, higher BMI, hypertension, smoking, being male, and having a heart disease are associated with higher diabetes risk.

SHAP dependence plots

SHAP dependence plots: We see similar shapes as in the PDPs. Thanks to the vertical scatter, we can, e.g., spot that the BMI effect strongly depends on age. As in the PDPs, we have selected a common vertical scale to also see the effect strength.

Final words

  • {hstats}, {kernelshap} and {shapviz} can explain any model with XAI methods like permutation importance, PDPs, Friedman’s H, and SHAP. This, obviously, also includes models developed with {tidymodels}.
  • They would actually even work for multi-output models, e.g., classification with more than two categories.
  • Studying a blackbox with XAI methods is always worth the effort, even if the methods have their issues. I.e., an imperfect explanation is still better than no explanation.
  • Model-agnostic SHAP takes a little bit of time, but it is usually worth the effort.

Comments

13 responses to “Explain that tidymodels blackbox!”

  1. Rakesh Poduval

    Very insightful.

  2. Milan

    I apologise – I ask without simply first trying myself – is this also fully applicable to regression models?

    1. Michael Mayer

      Yes it is. The code will even simplify (no fight with factor response, no need to specify a custom predict, etc)

  3. Hannes

    Great post! I wasn’t aware of hstats but I’ll definitely use it for partial dependency plots in the future!

  4. Carlos Ortega

    Thanks for the very clear explanations.

    Regarding the speed of calculations when dealing with different sizes of data, for the different packages, what is your experience?

    1. Michael Mayer

      Also good question. The bottleneck is the speed of the “predict()” function: Random forests are among the slowest, linear models and boosted trees among the fastest: E.g., hstats() on a comparable XGBoost model would take ~2 instead of 20 seconds. With XGBoost, you would also be able to switch to TreeSHAP to run the SHAP calculation within 1 second. The size of the data is less relevant as most methods use subsampling for the expensive steps (defaults can be changed). For SHAP, you need to select both the explanation data (and for non-TreeSHAP also the background data) yourself.

      1. Carlos Ortega

        Thanks Michael!

  5. […] last post was using {hstats}, {kernelshap} and {shapviz} to explain a binary classification random forest. […]

  6. Dadong Li

    Thanks for the great post, Michael. One quick question: in the current random forest model, are interactions of variables automatically included in the step of “rf_spec”, or do we have to manually indicate them like in other models?

    1. Michael Mayer

      A tree-based model adds interaction effects implicitly, so you don’t need to specify this yourself. In {ranger} you can, additionally, specify explicit interactions. You can do this via the “*” in the formula, but this would be rather exotic.

  7. Elle

    Thanks for a great post! When creating the SHAP beeswarm plot, I am getting datapoints that are grey. Would this be due to missing data in my training dataset (which I go on to apply imputation, one-hot encoding and normalisation using a recipe). If so, when creating X_explain and X_background, should I be using the pre-processed training data?

    1. Michael Mayer

      Yeah, that might be missing feature values. Depending on the model and situation, using either raw or preprocessed data is better. Actually, I often use two preprocessing steps: the first step produces a dataset ideal for interpretation, e.g., using logarithms and imputation. And then the model-specific transformation (e.g., numeric encodings).
