SHAP + XGBoost + Tidymodels = LOVE

In a recent post, we explained how to use Kernel SHAP to interpret complex linear models. As the plotting backend, we used our new CRAN package “shapviz”.

“shapviz” has direct connectors to several packages, such as XGBoost, LightGBM, H2O, and kernelshap. People have asked me several times how to use shapviz when the XGBoost model was fitted with Tidymodels. The workflow was not 100% clear to me either, but the answer is actually very simple, thanks to Julia’s post, where the plots were made with SHAPforxgboost, another cool package for visualizing SHAP values.

Example with shiny diamonds

Step 1: Preprocessing

We first write the data preprocessing recipe and apply it to the data rows that we want to explain. In our case, these are 1000 randomly sampled diamonds.

library(tidyverse)
library(tidymodels)
library(shapviz)

# Integer encode factors
dia_recipe <- diamonds %>%
  recipe(price ~ carat + cut + clarity + color) %>% 
  step_integer(all_nominal())

# Will explain THIS dataset later
set.seed(2)
dia_small <- diamonds[sample(nrow(diamonds), 1000), ]
dia_small_prep <- bake(
  prep(dia_recipe), 
  has_role("predictor"),
  new_data = dia_small, 
  composition = "matrix"
)
head(dia_small_prep)

#     carat cut clarity color
#[1,]  0.57   5       4     4
#[2,]  1.01   5       2     1
#[3,]  0.45   1       4     3
#[4,]  1.04   4       6     5
#[5,]  0.90   3       6     4
#[6,]  1.20   3       4     6
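
To see what the integer codes stand for, the following sketch rebuilds the recipe and compares the baked values of cut with the positions of the original factor levels. It assumes that step_integer() numbers factor levels in their stored order. The names rec, baked, cut_key, and codes_match are only for illustration.

```r
library(recipes)
library(ggplot2)  # provides the diamonds data

# Same encoding as in the recipe above, rebuilt so the snippet runs on its own
rec <- recipe(price ~ carat + cut + clarity + color, data = diamonds) %>%
  step_integer(all_nominal())
baked <- bake(prep(rec), new_data = diamonds)

# Lookup table: integer code -> original factor level
cut_key <- data.frame(
  code = seq_along(levels(diamonds$cut)),
  level = levels(diamonds$cut)
)
print(cut_key)

# Check the assumption: baked codes equal the level positions
codes_match <- all(as.integer(baked$cut) == as.integer(diamonds$cut))
codes_match
```

If codes_match is TRUE, the lookup table can be used to translate the integers in the baked matrix back to labels.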

Step 2: Fit Model

The next step is to tune and build the model. For simplicity, we skipped the tuning part. Bad, bad 🙂

# Just for illustration - in practice needs tuning!
xgboost_model <- boost_tree(
  mode = "regression",
  trees = 200,
  tree_depth = 5,
  learn_rate = 0.05,
  engine = "xgboost"
)

dia_wf <- workflow() %>%
  add_recipe(dia_recipe) %>%
  add_model(xgboost_model)

fit <- dia_wf %>%
  fit(diamonds)
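
For completeness, here is a minimal sketch of how the skipped tuning step might look. It is not the setup used in this post: the subsample size, the number of folds, the grid size, and all object names (dia_tune, xgboost_tune_spec, tune_wf, folds, tune_res, best_params, fit_tuned) are assumptions chosen to keep the runtime short.

```r
library(tidymodels)  # attaches ggplot2, which provides diamonds

dia_recipe <- diamonds %>%
  recipe(price ~ carat + cut + clarity + color) %>%
  step_integer(all_nominal())

# Tune on a subsample to keep the runtime short (illustrative choice)
set.seed(1)
dia_tune <- diamonds[sample(nrow(diamonds), 2000), ]

# Mark the parameters to be tuned with tune()
xgboost_tune_spec <- boost_tree(
  mode = "regression",
  trees = 200,
  tree_depth = tune(),
  learn_rate = tune(),
  engine = "xgboost"
)

tune_wf <- workflow() %>%
  add_recipe(dia_recipe) %>%
  add_model(xgboost_tune_spec)

# 3-fold cross-validation over 5 candidate parameter combinations
folds <- vfold_cv(dia_tune, v = 3)
tune_res <- tune_grid(tune_wf, resamples = folds, grid = 5)

# Plug the best combination (by RMSE) into the workflow and refit on all data
best_params <- select_best(tune_res, metric = "rmse")
fit_tuned <- tune_wf %>%
  finalize_workflow(best_params) %>%
  fit(diamonds)
```

The resulting fit_tuned could then take the place of fit in Step 3.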

Step 3: SHAP Analysis

We now need to call shapviz() on the fitted model. In order to have neat interpretations with the original factor labels, we not only pass the prediction data prepared in Step 1 via bake(), but also the original data structure.

shap <- shapviz(extract_fit_engine(fit), X_pred = dia_small_prep, X = dia_small)

sv_importance(shap, kind = "both", show_numbers = TRUE)
sv_dependence(shap, "carat", color_var = "auto")
sv_dependence(shap, "clarity", color_var = "auto")
sv_force(shap, row_id = 1)
sv_waterfall(shap, row_id = 1)

[Figure: Variable importance plot overlaid with SHAP summary beeswarms]
[Figure: Dependence plot for carat. Note that clarity is shown with original labels, not only integers.]
[Figure: Dependence plot for clarity. Note again that the x-scale uses the original factor levels, not the integer encoded values.]
[Figure: Force plot of the first observation]
[Figure: Waterfall plot for the first observation]
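
As a sanity check, one can verify that the SHAP values add up to the model predictions: for each row, the baseline plus the sum of the SHAP values should equal the XGBoost prediction, up to numerical precision. The sketch below repeats Steps 1 to 3 so that it runs on its own. The names xgb_fit, S, b, pred_from_shap, and pred_direct are introduced here for illustration.

```r
library(tidymodels)  # attaches ggplot2, which provides diamonds
library(shapviz)

# Steps 1 and 2 as above
dia_recipe <- diamonds %>%
  recipe(price ~ carat + cut + clarity + color) %>%
  step_integer(all_nominal())

set.seed(2)
dia_small <- diamonds[sample(nrow(diamonds), 1000), ]
dia_small_prep <- bake(
  prep(dia_recipe),
  has_role("predictor"),
  new_data = dia_small,
  composition = "matrix"
)

xgboost_model <- boost_tree(
  mode = "regression",
  trees = 200,
  tree_depth = 5,
  learn_rate = 0.05,
  engine = "xgboost"
)

fit <- workflow() %>%
  add_recipe(dia_recipe) %>%
  add_model(xgboost_model) %>%
  fit(diamonds)

# Step 3: SHAP values from the underlying xgboost booster
xgb_fit <- extract_fit_engine(fit)
shap <- shapviz(xgb_fit, X_pred = dia_small_prep, X = dia_small)

# Additivity check: baseline + row sums of SHAP values vs. direct predictions
S <- get_shap_values(shap)  # matrix, one column per feature
b <- get_baseline(shap)     # single number
pred_from_shap <- as.numeric(b + rowSums(S))
pred_direct <- as.numeric(predict(xgb_fit, dia_small_prep))

max(abs(pred_from_shap - pred_direct))
```

A maximum difference that is tiny relative to the diamond prices indicates that the SHAP values were computed for the same data the model is predicting on.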

Summary

Making SHAP analyses with XGBoost and Tidymodels is super easy.

The complete R script can be found here.



Comments

3 responses to “SHAP + XGBoost + Tidymodels = LOVE”

  2. Antros Louca

    Thanks for a very nice post.
    I was wondering how this solution might be used after one-hot encoding categorical variables in recipes?

    1. Michael Mayer

      While one-hot encoding does not make sense here, it is certainly possible. If you want to interpret the results on the original factor levels of color (instead of 7 dummies), you can use the `collapse` logic as follows:

      # Inspired by https://juliasilge.com/blog/board-games/

      # The recipe
      dia_recipe <- diamonds %>%
        recipe(price ~ carat + cut + clarity + color) %>%
        step_dummy(color, one_hot = TRUE) %>%
        step_integer(clarity, cut)

      # Rest identical up to call to shapviz()

      # Sum the SHAP values of the color dummies into a single "color" column
      shap <- shapviz(
        extract_fit_engine(fit),
        X_pred = dia_small_prep,
        X = dia_small,
        collapse = list(color = paste0("color_", 1:nlevels(diamonds$color)))
      )
      sv_dependence(shap, "color", color_var = "auto")
