Author: Michael Mayer

  • Effect Plots in Python and R

    Christian and I did some code magic: highly effective plots that help to build and inspect any model.

    The functionality is best described by its output:

    [Example plots: Python and R versions]

    The plots show different types of feature effects relevant in modeling (a minimal computational sketch follows the list):

    • Average observed: Descriptive effect (also interesting without model).
    • Average predicted: Combined effect of all features. Also called “M Plot” (Apley 2020).
    • Partial dependence: Effect of one feature, keeping other feature values constant (Friedman 2001).
    • Number of observations or sum of case weights: Feature value distribution.
    • R only: Accumulated local effects, an alternative to partial dependence (Apley 2020).
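
    To make these statistics concrete, here is a minimal illustrative R sketch for a single feature. This is not the packages' actual code (roughly speaking, they additionally bin continuous features and support case weights); it assumes a data frame data, a numeric response vector y_obs, and a model with a generic predict() method:

    v <- "driver_age"
    grid <- sort(unique(data[[v]]))
    pred <- predict(model, data)
    
    # Average observed and average predicted, grouped by feature value
    avg_observed  <- tapply(y_obs, data[[v]], mean)
    avg_predicted <- tapply(pred, data[[v]], mean)
    
    # Partial dependence: average prediction with the feature fixed at each grid value
    partial_dependence <- sapply(grid, function(g) {
      data_mod <- data
      data_mod[[v]] <- g
      mean(predict(model, data_mod))
    })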

    Both implementations…

    • are highly efficient thanks to {Polars} in Python and {collapse} in R, and work on datasets with millions of observations,
    • support case weights with all their statistics, ideal in insurance applications,
    • calculate average residuals (not shown in the plots above),
    • provide standard deviations/errors of average observed and bias,
    • allow switching to Plotly for interactive plots, and
    • are highly customizable (the R package, e.g., allows collapsing rare levels after calculating statistics via the update() method, or sorting the features by main effect importance).

    In the spirit of our “Lost In Translation” series, we provide both high-quality Python and R code. We will use the same data and models as in one of our latest posts on how to build strong GLMs via ML + XAI.

    Example

    Let’s build a Poisson LightGBM model to explain the claim frequency given six traditional features in a pricing dataset on motor liability claims. 80% of the one million rows are used for training, the other 20% for evaluation. Hyper-parameters have been slightly tuned (not shown).

    library(OpenML)
    library(lightgbm)
    
    dim(df <- getOMLDataSet(data.id = 45106L)$data)  # 1000000 7
    head(df)
    
    #   year town driver_age car_weight car_power car_age claim_nb
    # 0 2018    1         51       1760       173       3        0
    # 1 2019    1         41       1760       248       2        0
    # 2 2018    1         25       1240       111       2        0
    # 3 2019    0         40       1010        83       9        0
    # 4 2018    0         43       2180       169       5        0
    # 5 2018    1         45       1170       149       1        1
    
    yvar <- "claim_nb"
    xvars <- setdiff(colnames(df), yvar)
    
    ix <- 1:800000
    train <- df[ix, ]
    test <- df[-ix, ]
    X_train <- data.matrix(train[xvars])
    X_test <- data.matrix(test[xvars])
    
    # Training, using slightly optimized parameters found via cross-validation
    params <- list(
      learning_rate = 0.05,
      objective = "poisson",
      num_leaves = 7,
      min_data_in_leaf = 50,
      min_sum_hessian_in_leaf = 0.001,
      colsample_bynode = 0.8,
      bagging_fraction = 0.8,
      lambda_l1 = 3,
      lambda_l2 = 5,
      num_threads = 7
    )
    
    set.seed(1)
    
    fit <- lgb.train(
      params = params,
      data = lgb.Dataset(X_train, label = train$claim_nb),
      nrounds = 300
    )
    import matplotlib.pyplot as plt
    from lightgbm import LGBMRegressor
    from sklearn.datasets import fetch_openml
    
    df = fetch_openml(data_id=45106, parser="pandas").frame
    df.head()
    
    #   year town driver_age car_weight car_power car_age claim_nb
    # 0 2018    1         51       1760       173       3        0
    # 1 2019    1         41       1760       248       2        0
    # 2 2018    1         25       1240       111       2        0
    # 3 2019    0         40       1010        83       9        0
    # 4 2018    0         43       2180       169       5        0
    # 5 2018    1         45       1170       149       1        1
    
    # Train model on 80% of the data
    y = df.pop("claim_nb")
    n_train = 800_000
    X_train, y_train = df.iloc[:n_train], y.iloc[:n_train]
    X_test, y_test = df.iloc[n_train:], y.iloc[n_train:]
    
    params = {
        "learning_rate": 0.05,
        "objective": "poisson",
        "num_leaves": 7,
        "min_child_samples": 50,
        "min_child_weight": 0.001,
        "colsample_bynode": 0.8,
        "subsample": 0.8,
        "reg_alpha": 3,
        "reg_lambda": 5,
        "verbose": -1,
    }
    
    model = LGBMRegressor(n_estimators=300, **params, random_state=1)
    model.fit(X_train, y_train)

    Let’s inspect the main effects of the model on the test data.

    library(effectplots)
    
    # 0.3 s
    feature_effects(fit, v = xvars, data = X_test, y = test$claim_nb) |>
      plot(share_y = "all")
    from model_diagnostics.calibration import plot_marginal
    
    fig, axes = plt.subplots(3, 2, figsize=(8, 8), sharey=True, layout="tight")
    
    # 2.3 s
    for i, (feat, ax) in enumerate(zip(X_test.columns, axes.flatten())):
        plot_marginal(
            y_obs=y_test,
            y_pred=model.predict(X_test),
            X=X_test,
            feature_name=feat,
            predict_function=model.predict,
            ax=ax,
        )
        ax.set_title(feat)
        if i != 1:
            ax.legend().remove()

    The output can be seen at the beginning of this blog post.

    Here are some model insights:

    • Average predictions closely match observed frequencies. No clear bias is visible.
    • Partial dependence shows that year and car_weight have almost no impact (regarding their main effects), while the driver_age and car_power effects seem strongest. The shared y axes help to assess this.
    • Except for car_weight, the partial dependence curve closely follows the average predictions. This means that the model effect really seems to come from the feature on the x axis, and not from some other, correlated feature (as, e.g., with car_weight, which is actually strongly correlated with car_power).

    Final words

    • Inspecting models has become much more relaxed with the above functions.
    • The packages used offer much more functionality. Try them out! Or we will show them in later posts ;).

    References

    1. Apley, Daniel W., and Jingyu Zhu. 2020. Visualizing the Effects of Predictor Variables in Black Box Supervised Learning Models. Journal of the Royal Statistical Society Series B: Statistical Methodology, 82 (4): 1059–1086. doi:10.1111/rssb.12377.
    2. Friedman, Jerome H. 2001. Greedy Function Approximation: A Gradient Boosting Machine. Annals of Statistics 29 (5): 1189–1232. doi:10.1214/aos/1013203451.

    R script, Python notebook

  • Explaining a Causal Forest

    We use a causal forest [1] to model the treatment effect in a randomized controlled clinical trial. Then, we explain this black-box model with usual explainability tools. These will reveal segments where the treatment works better or worse, just like a forest plot, but multivariately.

    Data

    For illustration, we use patient-level data of a 2-arm trial of rectal indomethacin against placebo to prevent post-ERCP pancreatitis (602 patients) [2]. The dataset is available in the package {medicaldata}.

    The data is in fantastic shape, so we don’t need to spend a lot of time with data preparation.

    1. We integer encode factors.
    2. We select meaningful features, basically those shown in the forest plot of [2] (Figure 4) without low-information features and without hospital.

    The marginal estimate of the treatment effect is -0.078, i.e., indomethacin reduces the probability of post-ERCP pancreatitis by 7.8 percentage points. Our aim is to develop and interpret a model to see if this value is associated with certain covariates.

    library(medicaldata)
    suppressPackageStartupMessages(library(dplyr))
    library(grf)          #  causal_forest()
    library(ggplot2)
    library(patchwork)    #  Combine ggplots
    library(hstats)       #  Friedman's H, PDP
    library(kernelshap)   #  General SHAP
    library(shapviz)      #  SHAP plots
    
    W <- as.integer(indo_rct$rx) - 1L      # 0=placebo, 1=treatment
    table(W)
    #   0   1 
    # 307 295
    
    Y <- as.numeric(indo_rct$outcome) - 1  # Y=1: post-ERCP pancreatitis (bad)
    mean(Y)  # 0.1312292
    
    mean(Y[W == 1]) - mean(Y[W == 0])      # -0.07785568
    
    xvars <- c(
      "age",         # Age in years
      "male",        # Male (1=yes)
      "pep",         # Previous post-ERCP pancreatitis (1=yes)
      "recpanc",     # History of recurrent Pancreatitis (1=yes)
      "type",        # Sphincter of oddi dysfunction type/level (0=no, to 3=type 3)
      "difcan",      # Cannulation of the papilla was difficult (1=yes)
      "psphinc",     # Pancreatic sphincterotomy performed (1=yes)
      "bsphinc",     # Biliary sphincterotomy performed (1=yes)
      "pdstent",     # Pancreatic stent (1=yes)
      "train"        # Trainee involved in stenting (1=yes)
    )
    
    X <- indo_rct |>
      mutate_if(is.factor, function(v) as.integer(v) - 1L) |> 
      rename(male = gender) |> 
      select_at(xvars)
    
    head(X)
                
    # age  male   pep recpanc  type difcan psphinc bsphinc pdstent train
    #  26     0     0       1     1      0       0       0       0     1
    #  24     1     1       0     0      0       0       1       0     0
    #  57     0     0       0     2      0       0       0       0     0
    #  29     0     0       0     1      0       0       1       1     1
    #  38     0     1       0     1      0       1       1       1     1
    #  59     0     0       0     1      1       0       1       1     0
                
    summary(X)
                
    #     age             male             pep            recpanc     
    # Min.   :19.00   Min.   :0.0000   Min.   :0.0000   Min.   :0.000  
    # 1st Qu.:35.00   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.000  
    # Median :45.00   Median :0.0000   Median :0.0000   Median :0.000  
    # Mean   :45.27   Mean   :0.2093   Mean   :0.1595   Mean   :0.299  
    # 3rd Qu.:54.00   3rd Qu.:0.0000   3rd Qu.:0.0000   3rd Qu.:1.000  
    # Max.   :90.00   Max.   :1.0000   Max.   :1.0000   Max.   :1.000  
    #      type           difcan          psphinc          bsphinc      
    # Min.   :0.000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
    # 1st Qu.:1.000   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.0000  
    # Median :2.000   Median :0.0000   Median :1.0000   Median :1.0000  
    # Mean   :1.743   Mean   :0.2608   Mean   :0.5698   Mean   :0.5714  
    # 3rd Qu.:2.000   3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:1.0000  
    # Max.   :3.000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
    #    pdstent           train       
    # Min.   :0.0000   Min.   :0.0000  
    # 1st Qu.:1.0000   1st Qu.:0.0000  
    # Median :1.0000   Median :0.0000  
    # Mean   :0.8239   Mean   :0.4701  
    # 3rd Qu.:1.0000   3rd Qu.:1.0000  
    # Max.   :1.0000   Max.   :1.0000  

    The model

    We use the {grf} package to fit a causal forest [1], a tree-ensemble trying to estimate conditional average treatment effects (CATE) E[Y(1) – Y(0) | X = x]. As such, it can be used to study treatment effect inhomogeneity.

    In contrast to a typical random forest:

    • Honest trees are grown: Within trees, part of the data is used for splitting, and the other part for calculating the node values. This anti-overfitting is implemented for all random forests in {grf}.
    • Splits are selected to produce child nodes with maximally different treatment effects (under some additional constraints).

    Note: With about 13%, the complication rate is relatively low. Thus, the treatment effect (measured on absolute scale) can become small for certain segments simply because the complication rate is close to 0. Ideally, we could model relative treatment effects or odds ratios, but I have not found this option in {grf} so far.

    fit <- causal_forest(
      X = X,
      Y = Y,
      W = W,
      num.trees = 1000,
      mtry = 4,
      sample.fraction = 0.7,
      seed = 1,
      ci.group.size = 1
    )

    Explain the model with “classic” techniques

    After looking at tree split importance, we study the effects via partial dependence plots and Friedman’s H. These only require a predict() function and a reference dataset.

    imp <- sort(setNames(variable_importance(fit), xvars))
    par(mai = c(0.7, 2, 0.2, 0.2))
    barplot(imp, horiz = TRUE, las = 1, col = "orange")
    
    pred_fun <- function(object, newdata, ...) {
      predict(object, newdata, ...)$predictions
    }
    
    pdps <- lapply(xvars, function(v) plot(partial_dep(fit, v, X = X, pred_fun = pred_fun)))
    wrap_plots(pdps, guides = "collect", ncol = 3) &
      ylim(c(-0.11, -0.06)) &
      ylab("Treatment effect")
                   
    H <- hstats(fit, X = X, pred_fun = pred_fun, verbose = FALSE)
    plot(H)
    
    partial_dep(fit, v = "age", X = X, BY = "bsphinc", pred_fun = pred_fun) |> 
      plot()

    Variable importance

    Variable importance of the causal forest can be measured by the relative count of how often each feature was used for splitting (in the first four tree levels). The most important variable is age.

    Main effects

    To study the main effects on the CATE, we consider partial dependence plots (PDP). Such a plot shows how the average prediction depends on the values of a feature, keeping all other feature values constant (which can be unnatural).

    We can see that the treatment effect is strongest for persons up to age 35, then weakens until about age 45. For older patients, the effect increases again.

    Remember: Negative values mean a stronger (positive) treatment effect.

    Interaction strength

    Between what covariates are there strong interactions?

    A model agnostic way to assess pairwise interaction strength is Friedman’s H statistic [3]. It measures the error when approximating the two-dimensional partial dependence function of the two features by their univariate partial dependence functions. A value of zero means there is no interaction. A value of α means that about 100α% of the joint effect (variability) comes from the interaction.
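
    For reference, the pairwise statistic of Friedman and Popescu [3] can be written (in LaTeX notation, with centered partial dependence functions PD evaluated at the observed feature values; notation slightly simplified) as

    H^2_{jk} = \frac{\sum_{i=1}^n \left[ PD_{jk}(x_{ij}, x_{ik}) - PD_j(x_{ij}) - PD_k(x_{ik}) \right]^2}{\sum_{i=1}^n PD_{jk}(x_{ij}, x_{ik})^2}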

    This measure is shown on the right hand side of the plot. More than 15% of the joint effect variability of age and biliary sphincterotomy (bsphinc) comes from their interaction.

    Typically, pairwise H-statistics are calculated only for the most important variables or those with high overall interaction strength. Overall interaction strength (left hand side of the plot) can be measured by a version of Friedman’s H. It shows how much of the prediction variability comes from interactions with that feature.

    Visualize strong interaction

    Interactions can be visualized, e.g., by a stratified PDP. We can see that the treatment effect is associated with age mainly for persons with biliary sphincterotomy.

    SHAP Analysis

    A “modern” way to explain the model is based on SHAP [4]. It decomposes the (centered) predictions into additive contributions of the covariates.
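
    In formulas (notation mine): for a single prediction f(x) with p features, the SHAP values \phi_1(x), \dots, \phi_p(x) satisfy

    f(x) = \phi_0 + \sum_{j=1}^{p} \phi_j(x),

    where \phi_0 is the baseline, i.e., the average prediction over the background data.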

    Because there is no TreeSHAP shipped with {grf}, we use the much slower Kernel SHAP algorithm implemented in {kernelshap} that works for any model.

    First, we explain the prediction of a single data row, then we decompose many predictions. These decompositions can be analysed by simple descriptive plots to gain insights about the model as a whole.

    # Explaining one CATE
    kernelshap(fit, X = X[1, ], bg_X = X, pred_fun = pred_fun) |> 
      shapviz() |> 
      sv_waterfall() +
      xlab("Prediction")
    
    # Explaining all CATEs globally
    system.time(  # 13 min
      ks <- kernelshap(fit, X = X, pred_fun = pred_fun)  
    )
    shap_values <- shapviz(ks)
    
    sv_importance(shap_values)
    sv_importance(shap_values, kind = "bee")
    sv_dependence(shap_values, v = xvars) +
      plot_layout(ncol = 3) &
      ylim(c(-0.04, 0.03))

    Explain one CATE

    Explaining the CATE corresponding to the feature values of the first patient via waterfall plot.

    SHAP importance plot

    The bars show average absolute SHAP values. For instance, we can say that biliary sphincterotomy impacts the treatment effect on average by more than ±0.01 (but we don’t see how).

    SHAP summary plot

    One-dimensional plot of SHAP values with scaled feature values on the color scale, sorted in the same order as the SHAP importance plot. Compared to the SHAP importance barplot, for instance, we can additionally see that biliary sphincterotomy weakens the treatment effect (positive SHAP value).

    SHAP dependence plots

    Scatterplots of SHAP values against corresponding feature values. Vertical scatter (at given x value) indicates presence of interactions. A candidate of an interacting feature is selected on the color scale. For instance, we see a similar pattern in the age effect on the treatment effect as in the partial dependence plot. Thanks to the color scale, we also see that the age effect depends on biliary sphincterotomy.

    Remember that SHAP values are on centered prediction scale. Still, a positive value means a weaker treatment effect.

    Wrap-up

    • {grf} is a fantastic package. You can expect more on it here.
    • Causal forests are an interesting way to directly model treatment effects.
    • Standard explainability methods can be used to explain the black-box.

    References

    1. Athey, Susan, Julie Tibshirani, and Stefan Wager. “Generalized Random Forests”. Annals of Statistics, 47(2), 2019.
    2. Elmunzer BJ et al. A randomized trial of rectal indomethacin to prevent post-ERCP pancreatitis. N Engl J Med. 2012 Apr 12;366(15):1414-22. doi: 10.1056/NEJMoa1111103.
    3. Friedman, Jerome H., and Bogdan E. Popescu. Predictive Learning via Rule Ensembles. The Annals of Applied Statistics 2, no. 3 (2008): 916-54.
    4. Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30 (2017).

    The full R notebook

  • Out-of-sample Imputation with {missRanger}

    {missRanger} is a multivariate imputation algorithm based on random forests, and a fast version of the original missForest algorithm of Stekhoven and Buehlmann (2012). Surprise, surprise: it uses {ranger} to fit random forests. Especially combined with predictive mean matching (PMM), the imputations are often quite realistic.

    Out-of-sample application

    The newest CRAN release 2.6.0 offers out-of-sample application. This is useful for removing any leakage between train/test data or during cross-validation. Furthermore, it allows filling missing values in user-provided data. By default, it uses the same number of PMM donors as during training, but you can change this via the pmm.k argument.

    We distinguish two types of observations to be imputed:

    1. Easy case: Only a single value is missing. Here, we simply apply the corresponding random forest to fill the one missing value.
    2. Hard case: Multiple values are missing. Here, we first fill the values univariately, and then repeatedly apply the corresponding random forests, with the hope that the effect of the univariate imputation vanishes. If values of two highly correlated features are missing, the imputations can be nonsensical. There is no way to mend this. A rough conceptual sketch of this iterative step follows below.
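
    Here is that sketch in pseudo-R (forests, variables_with_missings, and row are hypothetical placeholders for illustration, not {missRanger} objects):

    # Start from a simple univariate imputation of the row, then sweep repeatedly
    for (iter in 1:4) {
      for (v in variables_with_missings) {
        # Re-impute v from the other, currently filled values of the row
        # (optionally followed by predictive mean matching)
        row[[v]] <- predict(forests[[v]], data = row)$predictions
      }
    }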

    Example

    To illustrate the technique with a simple example, we use the iris data.

    1. First, we randomly add 10% missing values.
    2. Then, we make a train/test split.
    3. Next, we “fit” missRanger() to the training data.
    4. Finally, we use its new predict() method to fill the test data.

    library(missRanger)
    
    # 10% missings
    ir <- iris |> 
      generateNA(p = 0.1, seed = 11)
    
    # Train/test split stratified by Species
    oos <- c(1:10, 51:60, 101:110)
    train <- ir[-oos, ]
    test <- ir[oos, ]
    
    head(test)
    
    #   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
    # 1          5.1         3.5          1.4         0.2  setosa
    # 2          4.9         3.0          1.4         0.2  setosa
    # 3          4.7         3.2          1.3          NA  setosa
    # 4          4.6         3.1          1.5         0.2  setosa
    # 5          5.0         3.6          1.4         0.2  setosa
    # 6          5.4          NA          1.7          NA  setosa
    
    mr <- missRanger(train, pmm.k = 5, keep_forests = TRUE, seed = 1)
    test_filled <- predict(mr, test, seed = 1)
    head(test_filled)
    
    #   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
    # 1          5.1         3.5          1.4         0.2  setosa
    # 2          4.9         3.0          1.4         0.2  setosa
    # 3          4.7         3.2          1.3         0.2  setosa
    # 4          4.6         3.1          1.5         0.2  setosa
    # 5          5.0         3.6          1.4         0.2  setosa
    # 6          5.4         4.0          1.7         0.4  setosa
    
    # Original
    head(iris)
    
    #   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
    # 1          5.1         3.5          1.4         0.2  setosa
    # 2          4.9         3.0          1.4         0.2  setosa
    # 3          4.7         3.2          1.3         0.2  setosa
    # 4          4.6         3.1          1.5         0.2  setosa
    # 5          5.0         3.6          1.4         0.2  setosa
    # 6          5.4         3.9          1.7         0.4  setosa
    

    The results look reasonable, in this case even for the “hard case” row 6 with missing values in two variables. Here, it is probably the strong association with Species that helped to create good values.

    The new predict() also works with single row input.
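
    For instance (using the objects from above):

    predict(mr, test[1, ], seed = 1)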

    Learn more about {missRanger}

    The full R script

  • SHAP Values of Additive Models

    Within only a few years, SHAP (Shapley additive explanations) has emerged as the number one way to investigate black-box models. The basic idea is to decompose model predictions into additive contributions of the features in a fair way. Studying decompositions of many predictions allows deriving global properties of the model.

    What happens if we apply SHAP algorithms to additive models? Why would this ever make sense?

    In the spirit of our “Lost In Translation” series, we provide both high-quality Python and R code.

    The models

    Let’s build the models using a dataset with three highly correlated covariates and a (deterministic) response.

    library(lightgbm)
    library(kernelshap)
    library(shapviz)
    
    #===================================================================
    # Make small data
    #===================================================================
    
    make_data <- function(n = 100) {
      x1 <- seq(0.01, 1, length = n)
      data.frame(
        x1 = x1,
        x2 = log(x1),
        x3 = x1 > 0.7
      ) |>
        transform(y = 1 + 0.2 * x1 + 0.5 * x2 + x3 + sin(2 * pi * x1))
    }
    df <- make_data()
    head(df)
    cor(df) |>
      round(2)
    
    #      x1   x2   x3    y
    # x1 1.00 0.90 0.80 0.46
    # x2 0.90 1.00 0.58 0.58
    # x3 0.80 0.58 1.00 0.51
    # y  0.46 0.58 0.51 1.00
    
    #===================================================================
    # Additive linear model and additive boosted trees
    #===================================================================
    
    # Linear regression
    fit_lm <- lm(y ~ poly(x1, 3) + poly(x2, 3) + x3, data = df)
    summary(fit_lm)
    
    # Boosted trees
    xvars <- setdiff(colnames(df), "y")
    X <- data.matrix(df[xvars])
    
    params <- list(
      learning_rate = 0.05,
      objective = "mse",
      max_depth = 1,
      colsample_bynode = 0.7
    )
    
    fit_lgb <- lgb.train(
      params = params,
      data = lgb.Dataset(X, label = df$y),
      nrounds = 300
    )
    import numpy as np
    import lightgbm as lgb
    import shap
    from sklearn.preprocessing import PolynomialFeatures
    from sklearn.compose import ColumnTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.linear_model import LinearRegression
    
    #===================================================================
    # Make small data
    #===================================================================
    
    def make_data(n=100):
        x1 = np.linspace(0.01, 1, n)
        x2 = np.log(x1)
        x3 = x1 > 0.7
        X = np.column_stack((x1, x2, x3))
    
        y = 1 + 0.2 * x1 + 0.5 * x2 + x3 + np.sin(2 * np.pi * x1)
        
        return X, y
    
    X, y = make_data()
    
    #===================================================================
    # Additive linear model and additive boosted trees
    #===================================================================
    
    # Linear model with polynomial terms
    poly = PolynomialFeatures(degree=3, include_bias=False)
    
    preprocessor =  ColumnTransformer(
        transformers=[
            ("poly0", poly, [0]),
            ("poly1", poly, [1]),
            ("other", "passthrough", [2]),
        ]
    )
    
    model_lm = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("lm", LinearRegression()),
        ]
    )
    _ = model_lm.fit(X, y)
    
    # Boosted trees with single-split trees
    params = dict(
        learning_rate=0.05,
        objective="mse",
        max_depth=1,
        colsample_bynode=0.7,
    )
    
    model_lgb = lgb.train(
        params=params,
        train_set=lgb.Dataset(X, label=y),
        num_boost_round=300,
    )

    SHAP

    For both models, we use exact permutation SHAP and exact Kernel SHAP. Furthermore, the linear model is analyzed with “additive SHAP”, and the tree-based model with TreeSHAP.

    Do the algorithms provide the same?

    system.time({  # 1s
      shap_lm <- list(
        add = shapviz(additive_shap(fit_lm, df)),
        kern = kernelshap(fit_lm, X = df[xvars], bg_X = df),
        perm = permshap(fit_lm, X = df[xvars], bg_X = df)
      )
    
      shap_lgb <- list(
        tree = shapviz(fit_lgb, X),
        kern = kernelshap(fit_lgb, X = X, bg_X = X),
        perm = permshap(fit_lgb, X = X, bg_X = X)
      )
    })
    
    # Consistent SHAP values for linear regression
    all.equal(shap_lm$add$S, shap_lm$perm$S)
    all.equal(shap_lm$kern$S, shap_lm$perm$S)
    
    # Consistent SHAP values for boosted trees
    all.equal(shap_lgb$tree$S, shap_lgb$perm$S)
    all.equal(shap_lgb$kern$S, shap_lgb$perm$S)
    
    # Linear coefficient of x3 equals slope of SHAP values
    tail(coef(fit_lm), 1)                # 1.112096
    diff(range(shap_lm$kern$S[, "x3"]))  # 1.112096
    
    sv_dependence(shap_lm$add, xvars)
    sv_dependence(shap_lm$add, xvars, color_var = NULL)
    shap_lm = {
        "add": shap.Explainer(model_lm.predict, masker=X, algorithm="additive")(X),
        "perm": shap.Explainer(model_lm.predict, masker=X, algorithm="exact")(X),
        "kern": shap.KernelExplainer(model_lm.predict, data=X).shap_values(X),
    }
    
    shap_lgb = {
        "tree": shap.Explainer(model_lgb)(X),
        "perm": shap.Explainer(model_lgb.predict, masker=X, algorithm="exact")(X),
        "kern": shap.KernelExplainer(model_lgb.predict, data=X).shap_values(X),
    }
    
    # Consistency for additive linear regression
    eps = 1e-12
    assert np.abs(shap_lm["add"].values - shap_lm["perm"].values).max() < eps
    assert np.abs(shap_lm["perm"].values - shap_lm["kern"]).max() < eps
    
    # Consistency for additive boosted trees
    assert np.abs(shap_lgb["tree"].values - shap_lgb["perm"].values).max() < eps
    assert np.abs(shap_lgb["perm"].values - shap_lgb["kern"]).max() < eps
    
    # Linear effect of last feature in the fitted model
    model_lm.named_steps["lm"].coef_[-1]  # 1.112096
    
    # Linear effect of last feature derived from SHAP values (ignore the sign)
    shap_lm["perm"][:, 2].values.ptp()    # 1.112096
    
    shap.plots.scatter(shap_lm["add"])
    SHAP dependence plot of the additive linear model and the additive explainer (Python).

    Yes – within each model, the three algorithms provide the same SHAP values. Furthermore, the SHAP values reconstruct the additive components of the features.

    Didactically, this is very helpful when introducing SHAP as a method: Pick a white-box and a black-box model and compare their SHAP dependence plots. For the white-box model, you simply see the additive components, while the dependence plots of the black-box model show scatter due to interactions.

    Remark: The exact equivalence between the algorithms is lost when

    • there are too many features for exact procedures (~10+ features), and/or when
    • the background data of Kernel/Permutation SHAP does not agree with the training data. This leads to slightly different estimates of the baseline value, which itself influences the calculation of SHAP values.

    Final words

    • SHAP algorithms applied to additive models typically give identical results. Slight differences might occur because sampling versions of the algos are used, or a different baseline value is estimated.
    • The resulting SHAP values describe the additive components.
    • Didactically, it helps to see SHAP analyses of white-box and black-box models side by side.

    R script, Python notebook

  • Building Strong GLMs in Python via ML + XAI

    In our latest post, we explained how to use ML + XAI to build strong generalized linear models with R. Let’s do the same with Python.

    Insurance pricing data

    We will again use a synthetic dataset with one million insurance policies, with reference:

    Mayer, M., Meier, D. and Wuthrich, M.V. (2023),
    SHAP for Actuaries: Explain any Model.
    https://doi.org/10.2139/ssrn.4389797

    Let’s start by loading and describing the data:

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    import shap
    from sklearn.datasets import fetch_openml
    from sklearn.inspection import PartialDependenceDisplay
    from sklearn.metrics import mean_poisson_deviance
    from sklearn.dummy import DummyRegressor
    from lightgbm import LGBMRegressor
    # We need preview version of glum that adds formulaic API
    # !pip install git+https://github.com/Quantco/glum@glum-v3#egg=glum
    from glum import GeneralizedLinearRegressor
    
    # Load data
    
    df = fetch_openml(data_id=45106, parser="pandas").frame
    df.head()
    
    # Continuous features
    df.hist(["driver_age", "car_weight", "car_power", "car_age"])
    _ = plt.suptitle("Histograms of continuous features", fontsize=15)
    
    # Response and discrete features
    fig, axes = plt.subplots(figsize=(8, 3), ncols=3)
    
    for v, ax in zip(["claim_nb", "year", "town"], axes):
        df[v].value_counts(sort=False).sort_index().plot(kind="bar", ax=ax, rot=0, title=v)
    plt.suptitle("Barplots of response and discrete features", fontsize=15)
    plt.tight_layout()
    plt.show()
    
    # Rank correlations
    corr = df.corr("spearman")
    mask = np.triu(np.ones_like(corr, dtype=bool))
    plt.suptitle("Rank-correlogram", fontsize=15)
    _ = sns.heatmap(
        corr, mask=mask, vmin=-0.7, vmax=0.7, center=0, cmap="vlag", square=True
    )
    
    

    Modeling

    1. We fit a tuned Boosted Trees model to model log(E(claim count)) via Poisson deviance loss.
    2. And perform a SHAP analysis to derive insights.
    from sklearn.model_selection import train_test_split
    
    X_train, X_test, y_train, y_test = train_test_split(
        df.drop("claim_nb", axis=1), df["claim_nb"], test_size=0.1, random_state=30
    )
    
    # Tuning step not shown. Number of boosting rounds found via early stopping on CV performance
    params = dict(
        learning_rate=0.05,
        objective="poisson",
        num_leaves=7,
        min_child_samples=50,
        min_child_weight=0.001,
        colsample_bynode=0.8,
        subsample=0.8,
        reg_alpha=3,
        reg_lambda=5,
        verbose=-1,
    )
    
    model_lgb = LGBMRegressor(n_estimators=360, **params)
    model_lgb.fit(X_train, y_train)
    
    # SHAP analysis
    X_explain = X_train.sample(n=2000, random_state=937)
    explainer = shap.Explainer(model_lgb)
    shap_val = explainer(X_explain)
    
    plt.suptitle("SHAP importance", fontsize=15)
    shap.plots.bar(shap_val)
    
    for s in [shap_val[:, 0:3], shap_val[:, 3:]]:
        shap.plots.scatter(s, color=shap_val, ymin=-0.5, ymax=1)

    Here, we would come to the conclusions:

    1. car_weight and year might be dropped, depending on the specific aim of the model.
    2. Add a regression spline for driver_age.
    3. Add an interaction between car_power and town.

    Build strong GLM

    Let’s build a GLM with these insights. Two important things:

    1. Glum is an extremely powerful GLM implementation that was inspired by a pull request of our Christian Lorentzen.
    2. In the upcoming version 3.0, it adds a formula API based on formulaic, a very performant formula parser. This gives a very easy way to add interaction effects, regression splines, dummy encodings etc.
    model_glm = GeneralizedLinearRegressor(
        family="poisson",
        l1_ratio=1.0,
        alpha=1e-10,
        formula="car_power * C(town) + bs(driver_age, 7) + car_age",
    )
    model_glm.fit(X_train, y=y_train)  # 1 second on old laptop
    
    # PDPs of both models
    fig, ax = plt.subplots(2, 2, figsize=(7, 5))
    cols = ("tab:blue", "tab:orange")
    for color, name, model in zip(cols, ("GLM", "LGB"), (model_glm, model_lgb)):
        disp = PartialDependenceDisplay.from_estimator(
            model,
            features=["driver_age", "car_age", "car_power", "town"],
            X=X_explain,
            ax=ax if name == "GLM" else disp.axes_,
            line_kw={"label": name, "color": color},
        )
    fig.suptitle("PDPs of both models", fontsize=15)
    fig.tight_layout()
    
    # Stratified PDP of car_power
    for color, town in zip(("tab:blue", "tab:orange"), (0, 1)):
        mask = X_explain.town == town
        disp = PartialDependenceDisplay.from_estimator(
            model_glm,
            features=["car_power"],
            X=X_explain[mask],
            ax=None if town == 0 else disp.axes_,
            line_kw={"label": town, "color": color},
        )
    plt.suptitle("PDP of car_power stratified by town (0 vs 1)", fontsize=15)
    _ = plt.ylim(0, 0.2)

    In this relatively simple situation, the mean Poisson deviances of our models are now very similar:

    model_dummy = DummyRegressor().fit(X_train, y=y_train)
    deviance_null = mean_poisson_deviance(y_test, model_dummy.predict(X_test)) 
    
    dev_imp = []
    for name, model in zip(("GLM", "LGB", "Null"), (model_glm, model_lgb, model_dummy)):
        dev_imp.append((name, mean_poisson_deviance(y_test, model.predict(X_test))))
    pd.DataFrame(dev_imp, columns=["Model", "Mean_Poisson_Deviance"])

    Final words

    • Glum is an extremely powerful GLM implementation – we have only scratched its surface. You can expect more blogposts on Glum…
    • Having a formula interface is especially useful for adding interactions. Fingers crossed that the upcoming version 3.0 will soon be released.
    • Building GLMs via ML + XAI is so smooth, especially when you work with large data. For small data, you need to be careful to not add hidden overfitting to the model.

    Click here for the full Python notebook

  • ML + XAI -> Strong GLM

    My last post was using {hstats}, {kernelshap} and {shapviz} to explain a binary classification random forest. Here, we use the same package combo to improve a Poisson GLM with insights from a boosted trees model.

    Insurance pricing data

    This time, we work with a synthetic, but quite realistic dataset. It describes one million insurance policies and their corresponding claim counts. A reference for the data is:

    Mayer, M., Meier, D. and Wuthrich, M.V. (2023),
    SHAP for Actuaries: Explain any Model.
    http://dx.doi.org/10.2139/ssrn.4389797

    library(OpenML)
    library(lightgbm)
    library(splines)
    library(ggplot2)
    library(patchwork)
    library(hstats)
    library(kernelshap)
    library(shapviz)
    
    #===================================================================
    # Load and describe data
    #===================================================================
    
    df <- getOMLDataSet(data.id = 45106)$data
    
    dim(df)  # 1000000       7
    head(df)
    
    # year town driver_age car_weight car_power car_age claim_nb
    # 2018    1         51       1760       173       3        0
    # 2019    1         41       1760       248       2        0
    # 2018    1         25       1240       111       2        0
    # 2019    0         40       1010        83       9        0
    # 2018    0         43       2180       169       5        0
    # 2018    1         45       1170       149       1        1
    
    summary(df)
    
    # Response
    ggplot(df, aes(claim_nb)) +
      geom_bar(fill = "chartreuse4") +
      ggtitle("Distribution of the response")
    
    # Features
    xvars <- c("year", "town", "driver_age", "car_weight", "car_power", "car_age")
    
    df[xvars] |> 
      stack() |> 
    ggplot(aes(values)) +
      geom_histogram(fill = "chartreuse4", bins = 19) +
      facet_wrap(~ind, scales = "free", ncol = 2) +
      ggtitle("Distribution of the features")
    
    # car_power and car_weight are correlated 0.68, car_age and driver_age 0.28
    df[xvars] |> 
      cor() |> 
      round(2)
    #            year  town driver_age car_weight car_power car_age
    # year          1  0.00       0.00       0.00      0.00    0.00
    # town          0  1.00      -0.16       0.00      0.00    0.00
    # driver_age    0 -0.16       1.00       0.09      0.10    0.28
    # car_weight    0  0.00       0.09       1.00      0.68    0.00
    # car_power     0  0.00       0.10       0.68      1.00    0.09
    # car_age       0  0.00       0.28       0.00      0.09    1.00
    
    

    Modeling

    1. We fit a naive additive linear GLM and a tuned Boosted Trees model.
    2. We combine the models and specify their predict function.
    # Train/test split
    set.seed(8300)
    ix <- sample(nrow(df), 0.9 * nrow(df))
    train <- df[ix, ]
    valid <- df[-ix, ]
    
    # Naive additive linear Poisson regression model
    (fit_glm <- glm(claim_nb ~ ., data = train, family = poisson()))
    
    # Boosted trees with LightGBM. The parameters (incl. number of rounds) have been
    # tuned by combining early stopping with random search CV (not shown here)
    
    dtrain <- lgb.Dataset(data.matrix(train[xvars]), label = train$claim_nb)
    
    params <- list(
      learning_rate = 0.05, 
      objective = "poisson", 
      num_leaves = 7, 
      min_data_in_leaf = 50, 
      min_sum_hessian_in_leaf = 0.001, 
      colsample_bynode = 0.8, 
      bagging_fraction = 0.8, 
      lambda_l1 = 3, 
      lambda_l2 = 5
    )
    
    fit_lgb <- lgb.train(params = params, data = dtrain, nrounds = 300)  
    
    # {hstats} works for multi-output predictions,
    # so we can combine all models to a list, which simplifies the XAI part.
    models <- list(GLM = fit_glm, LGB = fit_lgb)
    
    # Custom predictions on response scale
    pf <- function(m, X) {
      cbind(
        GLM = predict(m$GLM, X, type = "response"),
        LGB = predict(m$LGB, data.matrix(X[xvars]))
      )
    }
    pf(models, head(valid, 2))
    #       GLM        LGB
    # 0.1082285 0.08580529
    # 0.1071895 0.09181466
    
    # And on log scale
    pf_log <- function(m, X) {
      log(pf(m = m, X = X))
    }
    pf_log(models, head(valid, 2))
    #       GLM       LGB
    # -2.223510 -2.455675
    # -2.233157 -2.387983

    Traditional XAI

    Performance

    Comparing average Poisson deviance on the validation data shows that the LGB model is clearly better than the naively built GLM, so there is room for improvement!

    perf <- average_loss(
      models, X = valid, y = "claim_nb", loss = "poisson", pred_fun = pf
    )
    perf
    #       GLM       LGB 
    # 0.4362407 0.4331857
    

    Feature importance

    Next, we calculate permutation importance on the validation data with respect to mean Poisson deviance loss. The results make sense, and we note that year and car_weight seem to be negligible.

    imp <- perm_importance(
      models, v = xvars, X = valid, y = "claim_nb", loss = "poisson", pred_fun = pf
    )
    plot(imp)

    Main effects

    Next, we visualize estimated main effects by partial dependence plots on the log link scale. The differences between the models are quite small, with one big exception: Investing more parameters into driver_age via a spline will greatly improve the performance and usefulness of the GLM.

    partial_dep(models, v = "driver_age", train, pred_fun = pf_log) |> 
      plot(show_points = FALSE)
    
    pdp <- function(v) {
      partial_dep(models, v = v, X = train, pred_fun = pf_log) |> 
        plot(show_points = FALSE)
    }
    wrap_plots(lapply(xvars, pdp), guides = "collect") &
      ylim(-2.8, -1.7)

    Interaction effects

    Friedman’s H-squared (per feature and feature pair), calculated on the log link scale, shows that – unsurprisingly – our GLM does not contain interactions, and that the strongest relative interaction happens between town and car_power. The stratified PDP visualizes this interaction. Let’s add a corresponding interaction effect to our GLM later.

    system.time(  # 5 sec
      H <- hstats(models, v = xvars, X = train, pred_fun = pf_log)
    )
    H
    plot(H)
    
    # Visualize strongest interaction by stratified PDP
    partial_dep(models, v = "car_power", X = train, pred_fun = pf_log, BY = "town") |> 
      plot(show_points = FALSE)

    SHAP

    As an elegant alternative to studying feature importance, PDPs and Friedman’s H, we can simply run a SHAP analysis on the LGB model.

    set.seed(22)
    X_explain <- train[sample(nrow(train), 1000), xvars]
     
    shap_values_lgb <- shapviz(fit_lgb, data.matrix(X_explain))
    sv_importance(shap_values_lgb)
    sv_dependence(shap_values_lgb, v = xvars) &
      ylim(-0.35, 0.8)

    Here, we would come to the same conclusions:

    1. car_weight and year might be dropped.
    2. Add a regression spline for driver_age.
    3. Add an interaction between car_power and town.

    Pimp the GLM

    In the final section, we apply the three insights from above with very good results.

    fit_glm2 <- glm(
      claim_nb ~ car_power * town + ns(driver_age, df = 7) + car_age, 
      data = train, 
      family = poisson()
    )
    
    # Performance now as good as LGB
    perf_glm2 <- average_loss(
      fit_glm2, X = valid, y = "claim_nb", loss = "poisson", type = "response"
    )
    perf_glm2  # 0.432962
    
    # Effects similar as LGB, and smooth
    partial_dep(fit_glm2, v = "driver_age", X = train) |> 
      plot(show_points = FALSE)
    
    partial_dep(fit_glm2, v = "car_power", X = train, BY = "town") |> 
      plot(show_points = FALSE)

    Or even via permutation or kernel SHAP:

    set.seed(1)
    bg <- train[sample(nrow(train), 200), ]
    xvars2 <- setdiff(xvars, c("year", "car_weight"))
    
    system.time(  # 4 sec
      ks_glm2 <- permshap(fit_glm2, X = X_explain[xvars2], bg_X = bg)
    )
    shap_values_glm2 <- shapviz(ks_glm2)
    sv_dependence(shap_values_glm2, v = xvars2) &
      ylim(-0.3, 0.8)

    Final words

    • Improving naive GLMs with insights from ML + XAI is fun.
    • In practice, the gap between a GLM and a boosted trees model can’t be closed that easily. (The true model behind our synthetic dataset contains a single interaction, unlike real data/models that typically have many more interactions.)
    • {hstats} can work with multiple regression models in parallel. This helps to keep the workflow smooth. Similar for {kernelshap}.
    • A SHAP analysis often brings the same qualitative insights as multiple other XAI tools together.

    The full R script

  • Explain that tidymodels blackbox!

    Let’s explain a {tidymodels} random forest by classic explainability methods (permutation importance, partial dependence plots (PDP), Friedman’s H statistics), and also fancy SHAP.

    Disclaimer: {hstats}, {kernelshap} and {shapviz} are three of my own packages.

    Diabetes data

    We will use the diabetes prediction dataset of Kaggle to model diabetes (yes/no) as a function of six demographic features (age, gender, BMI, hypertension, heart disease, and smoking history). It has 100k rows.

    Note: The data additionally contains the typical diabetes indicators HbA1c level and blood glucose level, but we won’t use them to avoid potential causality issues, and to gain insights also for people who do not know these values.

    # https://www.kaggle.com/datasets/iammustafatz/diabetes-prediction-dataset
    
    library(tidyverse)
    library(tidymodels)
    library(hstats)
    library(kernelshap)
    library(shapviz)
    library(patchwork)
    
    df0 <- read.csv("diabetes_prediction_dataset.csv")  # from above Kaggle link
    dim(df0)  # 100000 9
    head(df0)
    # gender age hypertension heart_disease smoking_history   bmi HbA1c_level blood_glucose_level diabetes
    # Female  80            0             1           never 25.19         6.6                 140        0
    # Female  54            0             0         No Info 27.32         6.6                  80        0
    #   Male  28            0             0           never 27.32         5.7                 158        0
    # Female  36            0             0         current 23.45         5.0                 155        0
    #   Male  76            1             1         current 20.14         4.8                 155        0
    # Female  20            0             0           never 27.32         6.6                  85        0
    
    summary(df0)
    anyNA(df0)  # FALSE
    table(df0$smoking_history, useNA = "ifany")
    
    # DATA PREPARATION
    
    # Note: tidymodels needs a factor response for classification
    df1 <- df0 |>
      transform(
        y = factor(diabetes, levels = 0:1, labels = c("No", "Yes")),
        female = (gender == "Female") * 1,
        smoking_history = factor(
          smoking_history, 
          levels = c("No Info", "never", "former", "not current", "current", "ever")
        ),
        bmi = pmin(bmi, 50)
      )
    
    # UNIVARIATE ANALYSIS
    
    ggplot(df1, aes(diabetes)) +
      geom_bar(fill = "chartreuse4")
    
    df1  |>  
      select(age, bmi, HbA1c_level, blood_glucose_level) |> 
      pivot_longer(everything()) |> 
      ggplot(aes(value)) +
      geom_histogram(fill = "chartreuse4", bins = 19) +
      facet_wrap(~ name, scale = "free_x")
    
    ggplot(df1, aes(smoking_history)) +
      geom_bar(fill = "chartreuse4")
    
    df1 |> 
      select(heart_disease, hypertension, female) |>
      pivot_longer(everything()) |> 
      ggplot(aes(name, value)) +
      stat_summary(fun = mean, geom = "bar", fill = "chartreuse4") +
      xlab(element_blank())
    
    “yes” proportion of binary variables (including the response)
    Distribution of numeric variables
    Distribution of smoking_history

    Modeling

    Let’s fit a random forest via tidymodels with {ranger} backend.

    We add a predict function pf() that outputs only the probability of the “Yes” class.

    set.seed(1)
    ix <- initial_split(df1, strata = diabetes, prop = 0.8)
    train <- training(ix)
    test <- testing(ix)
    
    xvars <- c("age", "bmi", "smoking_history", "heart_disease", "hypertension", "female")
    
    rf_spec <- rand_forest(trees = 500) |> 
      set_mode("classification") |> 
      set_engine("ranger", num.threads = NULL, seed = 49)
    
    rf_wf <- workflow() |> 
      add_model(rf_spec) |>
      add_formula(reformulate(xvars, "y"))
    
    model <- rf_wf |> 
        fit(train)
    
    # predict() gives No/Yes columns
    predict(model, head(test), type = "prob")
    # .pred_No .pred_Yes
    #    0.981    0.0185
    
    # We need to extract only the "Yes" probabilities
    pf <- function(m, X) {
      predict(m, X, type = "prob")$.pred_Yes
    }
    pf(model, head(test))  # 0.01854290 ...
    

    Classic explanation methods

    # 4 times repeated permutation importance wrt test logloss
    imp <- perm_importance(
      model, X = test, y = "diabetes", v = xvars, pred_fun = pf, loss = "logloss"
    )
    plot(imp) +
      xlab("Increase in test logloss")
    
    # Partial dependence of age
    partial_dep(model, v = "age", train, pred_fun = pf) |> 
      plot()
    
    # All PDP in one patchwork
    p <- lapply(xvars, function(x) plot(partial_dep(model, v = x, X = train, pred_fun = pf)))
    wrap_plots(p) &
      ylim(0, 0.23) &
      ylab("Probability")
    
    # Friedman's H stats
    system.time( # 20 s
      H <- hstats(model, train[xvars], approx = TRUE, pred_fun = pf)
    )
    H  # 15% of prediction variability comes from interactions
    plot(H)
    
    # Stratified PDP of strongest interaction
    partial_dep(model, "age", BY = "bmi", X = train, pred_fun = pf) |> 
      plot(show_points = FALSE)

    Feature importance

    Permutation importance measures by how much the average test loss (in our case log loss) increases when a feature is shuffled before calculating the losses. We repeat the process four times and also show standard errors.
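
    As a minimal sketch of what one such repetition does for a single feature (illustrative only, with the log loss written out by hand; perm_importance() does this more carefully and for all features at once):

    # Assumes model, test, and pf from above
    logloss <- function(y, p) -mean(y * log(p) + (1 - y) * log(1 - p))
    
    base_loss <- logloss(test$diabetes, pf(model, test))
    shuffled <- test
    shuffled$age <- sample(shuffled$age)                     # permute one feature
    logloss(test$diabetes, pf(model, shuffled)) - base_loss  # importance of "age"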

    Permutation importance: Age and BMI are the two main risk factors.

    Main effects

    Main effects are estimated by PDPs. They show how the average prediction changes with a feature, keeping every other feature fixed. Using a fixed vertical axis helps to grasp the strength of the effect.

    PDPs: The diabetes risk tends to increase with age, high (and very low) BMI, presence of heart disease/hypertension, and it is a bit lower for females and non-smokers.

    Interaction strength

    Interaction strength can be measured by Friedman’s H statistics, see the earlier blog post. A specific interaction can then be visualized by a stratified PDP.

    Friedman’s H statistics: Left: BMI and age are the two features with clearly strongest interactions. Right: Their pairwise interaction explains about 10% of their joint effect variability.
    Stratified PDP: The strong interaction between age and BMI is clearly visible. A high BMI makes the age effect on diabetes stronger.

    SHAP

    What insights does a SHAP analysis bring?

    We will crunch slow exact permutation SHAP values via kernelshap::permshap(). If we had more features, we could switch to

    • kernelshap::kernelshap()
    • Brandon Greenwell’s {fastshap}, or to the
    • {treeshap} package of my colleagues from TU Warsaw.
    set.seed(1)
    X_explain <- train[sample(1:nrow(train), 1000), xvars]
    X_background <- train[sample(1:nrow(train), 200), ]
    
    system.time(  # 10 minutes
      shap_values <- permshap(model, X = X_explain, bg_X = X_background, pred_fun = pf)
    )
    shap_values <- shapviz(shap_values)
    shap_values  # 'shapviz' object representing 1000 x 6 SHAP matrix
    saveRDS(shap_values, file = "shap_values.rds")
    # shap_values <- readRDS("shap_values.rds")
    
    sv_importance(shap_values, show_numbers = TRUE)
    sv_importance(shap_values, kind = "bee")
    sv_dependence(shap_values, v = xvars) &
      ylim(-0.14, 0.24) &
      ylab("Probability")

    SHAP importance

    SHAP importance: On average, the age increases or decreases the diabetes probability by 4.7% etc. In this case, the top three features are the same as in permutation importance.

    SHAP “summary” plot

    SHAP “summary” plot: Additionally to the bar plot, we see that higher age, higher BMI, hypertension, smoking, males, and having a heart disease are associated with higher diabetes risk.

    SHAP dependence plots

    SHAP dependence plots: We see similar shapes as in the PDPs. Thanks to the vertical scatter, we can, e.g., spot that the BMI effect strongly depends on the age. As in the PDPs, we have selected a common vertical scale to also see the effect strength.

    Final words

    • {hstats}, {kernelshap} and {shapviz} can explain any model with XAI methods like permutation importance, PDPs, Friedman’s H, and SHAP. This, obviously, also includes models developed with {tidymodels}.
    • They would actually even work for multi-output models, e.g., classification with more than two categories.
    • Studying a blackbox with XAI methods is always worth the effort, even if the methods have their issues. I.e., an imperfect explanation is still better than no explanation.
    • Model-agnostic SHAP takes a little bit of time, but it is usually worth the effort.

    The full R script

  • Permutation SHAP versus Kernel SHAP

    SHAP is the predominant way to interpret black-box ML models, especially for tree-based models with the blazingly fast TreeSHAP algorithm.

    For general models, two slower SHAP algorithms exist:

    1. Permutation SHAP (Štrumbelj and Kononenko, 2010)
    2. Kernel SHAP (Lundberg and Lee, 2017)

    Kernel SHAP was introduced in [2] as an approximation to permutation SHAP.

    The 0.4.0 CRAN release of our {kernelshap} package now contains an exact permutation SHAP algorithm for up to 14 features, and thus it becomes easy to make experiments between the two approaches.

    Some initial statements about permutation SHAP and Kernel SHAP

    1. Exact permutation SHAP and exact Kernel SHAP have the same computational complexity.
    2. Technically, exact Kernel SHAP is still an approximation of exact permutation SHAP, so you should prefer the latter.
    3. Kernel SHAP assumes feature independence. Since features are never independent in practice: does this mean we should never use Kernel SHAP?
    4. Kernel SHAP can be calculated almost exactly for any number of features, while permutation SHAP approximations get more and more imprecise when the number of features gets too large.

    Simulation 1

    We will first work with the iris data because it has extremely strong correlations between features. To see the impact of having models with and without interactions, we work with a random forest model of increasing tree depth. Depth 1 means no interactions, depth 2 means pairwise interactions etc.

    library(kernelshap)
    library(ranger)
    
    differences <- numeric(4)
    
    set.seed(1)
    
    for (depth in 1:4) {
      fit <- ranger(
        Sepal.Length ~ ., 
        mtry = 3,
        data = iris, 
        max.depth = depth
      )
      ps <- permshap(fit, iris[2:5], bg_X = iris)
      ks <- kernelshap(fit, iris[2:5], bg_X = iris)
      differences[depth] <- mean(abs(ks$S - ps$S))
    }
    
    differences  # for tree depth 1, 2, 3, 4
    # 5.053249e-17 9.046443e-17 2.387905e-04 4.403375e-04
    
    # SHAP values of first two rows with tree depth 4
    ps
    #      Sepal.Width Petal.Length Petal.Width      Species
    # [1,]  0.11377616   -0.7130647  -0.1956012 -0.004437022
    # [2,] -0.06852539   -0.7596562  -0.2259017 -0.006575266
      
    ks
    #      Sepal.Width Petal.Length Petal.Width      Species
    # [1,]  0.11463191   -0.7125194  -0.1951810 -0.006258208
    # [2,] -0.06828866   -0.7597391  -0.2259833 -0.006647530
    • Up to pairwise interactions (tree depth 2), the mean absolute difference between the two (150 x 4) SHAP matrices is 0.
    • Even for interactions of order three or higher, the differences are small. This is unexpected – after all, the iris features are strongly correlated!

    Simulation 2

    Let’s now use a different dataset with more features: the Miami house price data. As modeling technique, we use XGBoost, where we would normally use TreeSHAP. Also here, we increase the tree depth from 1 to 3 for increasing interaction depth.

    library(xgboost)
    library(shapviz)
    
    colnames(miami) <- tolower(colnames(miami))
    miami$log_ocean <- log(miami$ocean_dist)
    x <- c("log_ocean", "tot_lvg_area", "lnd_sqfoot", "structure_quality", "age", "month_sold")
    
    # Train/valid split
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    
    y_train <- log(miami$sale_prc[ix])
    y_valid <- log(miami$sale_prc[-ix])
    X_train <- data.matrix(miami[ix, x])
    X_valid <- data.matrix(miami[-ix, x])
    
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    # Fit via early stopping (depth 1 to 3)
    differences <- numeric(3)
    
    for (i in 1:3) {
      fit <- xgb.train(
        params = list(learning_rate = 0.15, objective = "reg:squarederror", max_depth = i),
        data = dtrain,
        watchlist = list(valid = dvalid),
        early_stopping_rounds = 20,
        nrounds = 1000,
        callbacks = list(cb.print.evaluation(period = 100))
      )
      ps <- permshap(fit, X = head(X_valid, 500), bg_X = head(X_valid, 500))
      ks <- kernelshap(fit, X = head(X_valid, 500), bg_X = head(X_valid, 500))
      differences[i] <- mean(abs(ks$S - ps$S))
    }
    differences # for tree depth 1, 2, 3
    # 2.904010e-09 5.158383e-09 6.586577e-04
    
    # SHAP values of top two rows for tree depth 3
    ps
    # log_ocean tot_lvg_area lnd_sqfoot structure_quality        age  month_sold
    # 0.2224359   0.04941044  0.1266136         0.1360166 0.01036866 0.005557032
    # 0.3674484   0.01045079  0.1192187         0.1180312 0.01426247 0.005465283
    
    ks
    # log_ocean tot_lvg_area lnd_sqfoot structure_quality        age  month_sold
    # 0.2245202  0.049520308  0.1266020         0.1349770 0.01142703 0.003355770
    # 0.3697167  0.009575195  0.1198201         0.1168738 0.01544061 0.003450425

    Again the same picture as with iris: Essentially no differences for interactions up to order two, and only small differences with interactions of higher order.

    Wrap-Up

    1. Use kernelshap::permshap() to crunch exact permutation SHAP values for models with not too many features.
    2. In real-world applications, exact Kernel SHAP and exact permutation SHAP start to differ (slightly) with models containing interactions of order three or higher.
    3. Since Kernel SHAP can be calculated almost exactly also for many features, it remains an excellent way to crunch SHAP values for arbitrary models.

    What is your experience?

    The R code is here.

  • Interactions – where are you?

    This question sends shivers down the poor modeler’s spine…

    The {hstats} R package introduced in our last post measures their strength using Friedman’s H-statistics, a collection of statistics based on partial dependence functions.

    On GitHub, the preview version of {hstats} 1.0.0 is out – I will try to bring it to CRAN in about one week (October 2023). Until then, try it via devtools::install_github("mayer79/hstats").

    The current version offers:

    • H statistics per feature, feature pair, and feature triple
    • Multivariate predictions at no additional cost
    • A convenient API
    • Other important tools from explainable ML:
      • performance calculations
      • permutation importance (e.g., to select features for calculating H-statistics)
      • partial dependence plots (including grouping, multivariate, multivariable)
      • individual conditional expectations (ICE)
    • Case-weights are available for all methods, which is important, e.g., in insurance applications
    • The option for fast quantile approximation of H-statistics

    This post has two parts:

    1. Example with house-prices and XGBoost
    2. Naive benchmark against {iml}, {DALEX}, and my old {flashlight}.

    1. Example

    Let’s model logarithmic sales prices of houses sold in Miami Dade County, a dataset prepared by Prof. Dr. Steven Bourassa, and available in {shapviz}. We use XGBoost with interaction constraints to provide a model additive in all structure information, but allowing for interactions between latitude/longitude for a flexible representation of geographic effects.

    The following code prepares the data, splits the data into train and validation, and then fits an XGBoost model.

    library(hstats)
    library(shapviz)
    library(xgboost)
    library(ggplot2)
    
    # Data preparation
    colnames(miami) <- tolower(colnames(miami))
    miami <- transform(miami, log_price = log(sale_prc))
    x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
           "structure_quality", "age", "month_sold")
    coord <- c("longitude", "latitude")
    
    # Modeling
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    train <- data.frame(miami[ix, ])
    valid <- data.frame(miami[-ix, ])
    y_train <- train$log_price
    y_valid <- valid$log_price
    X_train <- data.matrix(train[x])
    X_valid <- data.matrix(valid[x])
    
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    # Interaction constraints (0-based column indices):
    # one group containing the two coordinates, every other feature on its own
    ic <- c(
      list(which(x %in% coord) - 1),
      as.list(which(!x %in% coord) - 1)
    )
    
    # Fit via early stopping
    fit <- xgb.train(
      params = list(
        learning_rate = 0.15, 
        objective = "reg:squarederror", 
        max_depth = 5,
        interaction_constraints = ic
      ),
      data = dtrain,
      watchlist = list(valid = dvalid),
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )

    Now it is time for a compact analysis with {hstats} to interpret the model:

    average_loss(fit, X = X_valid, y = y_valid)  # 0.0247 MSE -> 0.157 RMSE
    
    perm_importance(fit, X = X_valid, y = y_valid) |> 
      plot()
    
    # Or combining some features
    v_groups <- list(
      coord = c("longitude", "latitude"),
      size = c("lnd_sqfoot", "tot_lvg_area"),
      condition = c("age", "structure_quality")
    )
    perm_importance(fit, v = v_groups, X = X_valid, y = y_valid) |> 
      plot()
    
    # Friedman's H-statistics: overall and pairwise interaction strength
    H <- hstats(fit, v = x, X = X_valid)
    H
    plot(H)
    plot(H, zero = FALSE)
    h2_pairwise(H, zero = FALSE, squared = FALSE, normalize = FALSE)
    partial_dep(fit, v = "tot_lvg_area", X = X_valid) |> 
      plot()
    partial_dep(fit, v = "tot_lvg_area", X = X_valid, BY = "structure_quality") |> 
      plot(show_points = FALSE)
    plot(ii <- ice(fit, v = "tot_lvg_area", X = X_valid))
    plot(ii, center = TRUE)
    
    # Spatial plots
    g <- unique(X_valid[, coord])
    pp <- partial_dep(fit, v = coord, X = X_valid, grid = g)
    plot(pp, d2_geom = "point", alpha = 0.5, size = 1) + 
      coord_equal()
    
    # Takes some seconds because it generates the last plot per structure quality
    partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
      plot(d2_geom = "point", alpha = 0.5) +
      coord_equal()

    Results summarized by plots

    Permutation importance

    Figure 1: Permutation importance (4 repetitions) on the validation data. Error bars show standard errors of the estimated increase in MSE from shuffling feature values.
    Figure 2: Feature groups can be shuffled together – accounting for issues of permutation importance with highly correlated features

    H-Statistics

    Let’s now move on to interaction statistics.

    Figure 3: Overall and pairwise H-statistics. Overall H^2 gives the proportion of prediction variability explained by all interactions of the feature. By default, {hstats} picks the five features with largest H^2 and calculates their pairwise H^2. This explains why not all 21 feature pairs appear in the figure on the right-hand side. Pairwise H^2 is differently scaled than overall H^2: It gives the proportion of joint effect variability of the two features explained by their interaction.
    Figure 4: Use “zero = FALSE” to drop variable (pairs) with value 0.

    PDPs and ICEs

    Figure 5: A partial dependence plot of living area.
    Figure 6: Stratification indeed shows no interactions between structure quality and living area.
    Figure 7: ICE plots also show no interactions with any other feature. The interaction constraints of XGBoost did a good job.
    Figure 8: This two-dimensional PDP evaluated over all unique coordinates shows a realistic profile of house prices in Miami Dade County (mind the log scale).
    Figure 9: Same, but grouped by structure quality (5 is best). Since there is no interaction between location and structure quality, the plots are just shifted versions of each other. (You can’t really see it on the plots.)

    Naive Benchmark

    All methods in {hstats} are optimized for speed. But how fast are they compared to other implementations? Note that this is just a simple benchmark run on a Windows notebook with an Intel i7-8650U CPU.

    Note that {iml} offers a parallel backend, but we could not make it run with XGBoost and Windows. Let me know how fast it is using parallelism and Linux!

    Setup + benchmark on permutation importance

    Always using the full validation dataset and 10 repetitions.

    library(iml)  # Might benefit from multiprocessing, but on Windows with XGB models, this is not easy
    library(DALEX)
    library(ingredients)
    library(flashlight)
    library(bench)
    
    set.seed(1)
    
    # iml
    predf <- function(object, newdata) predict(object, data.matrix(newdata[x]))
    mod <- Predictor$new(fit, data = as.data.frame(X_valid), y = y_valid, 
                         predict.function = predf)
    
    # DALEX
    ex <- DALEX::explain(fit, data = X_valid, y = y_valid)
    
    # flashlight (my slightly old fashioned package)
    fl <- flashlight(
      model = fit, data = valid, y = "log_price", predict_function = predf, label = "lm"
    )
      
    # Permutation importance: 10 repeats over full validation data (~2700 rows)
    bench::mark(
      iml = FeatureImp$new(mod, n.repetitions = 10, loss = "mse", compare = "difference"),
      dalex = feature_importance(ex, B = 10, type = "difference", n_sample = Inf),
      flashlight = light_importance(fl, v = x, n_max = Inf, m_repetitions = 10),
      hstats = perm_importance(fit, X = X_valid, y = y_valid, m_rep = 10, verbose = FALSE),
      check = FALSE,
      min_iterations = 3
    )
    
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml           1.58s    1.58s     0.631   209.4MB    2.73      3    13      4.76s
    # dalex      566.21ms 586.91ms     1.72     34.6MB    0.572     3     1      1.75s
    # flashlight 587.03ms 613.15ms     1.63     27.1MB    1.63      3     3      1.84s
    # hstats     353.78ms 360.57ms     2.79     27.2MB    0         3     0      1.08s

    {hstats} is about 30% faster than the runner-up, {DALEX}.

    Partial dependence

    Here, we study the time for crunching partial dependence of a continuous feature and a discrete feature.

    # Partial dependence (cont)
    v <- "tot_lvg_area"
    bench::mark(
      iml = FeatureEffect$new(mod, feature = v, grid.size = 50, method = "pdp"),
      dalex = partial_dependence(ex, variables = v, N = Inf, grid_points = 50),
      flashlight = light_profile(fl, v = v, pd_n_max = Inf, n_bins = 50),
      hstats = partial_dep(fit, v = v, X = X_valid, grid_size = 50, n_max = Inf),
      check = FALSE,
      min_iterations = 3
    )
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml           1.11s    1.13s     0.887   376.3MB     3.84     3    13      3.38s
    # dalex      782.13ms 783.08ms     1.24    192.8MB     2.90     3     7      2.41s
    # flashlight 367.73ms  372.5ms     2.68     67.9MB     2.68     3     3      1.12s
    # hstats     220.88ms  222.5ms     4.50     14.2MB     0        3     0   666.33ms
     
    # Partial dependence (discrete)
    v <- "structure_quality"
    bench::mark(
      iml = FeatureEffect$new(mod, feature = v, method = "pdp", grid.points = 1:5),
      dalex = partial_dependence(ex, variables = v, N = Inf, variable_type = "categorical", grid_points = 5),
      flashlight = light_profile(fl, v = v, pd_n_max = Inf),
      hstats = partial_dep(fit, v = v, X = X_valid, n_max = Inf),
      check = FALSE,
      min_iterations = 3
    )
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml            90ms     96ms     10.6    13.29MB     7.06     3     2      283ms
    # dalex       170.6ms  174.4ms      5.73   20.55MB     2.87     2     1      349ms
    # flashlight   40.8ms   43.8ms     23.1     6.36MB     2.10    11     1      476ms
    # hstats       23.5ms   24.4ms     40.6     1.53MB     2.14    19     1      468ms

    {hstats} is 1.5 to 2 times faster than {flashlight}, and about four times as fast as the other packages. Its memory footprint is much lower.

    H-statistics

    How fast can overall and pairwise H-statistics be computed?

    {DALEX} does not offer these statistics yet. {iml} was the first model-agnostic implementation of H-statistics I am aware of. It uses quantile approximation by default, but we purposely force it to calculate exact values in order to compare the numbers. This makes it slower than it actually is.

    # H-Stats -> we use a subset of 500 rows
    X_v500 <- X_valid[1:500, ]
    mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf)
    fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))
    
    # iml  # 225s total, using slow exact calculations
    system.time(  # 90s
      iml_overall <- Interaction$new(mod500, grid.size = 500)
    )
    system.time(  # 135s for all combinations of latitude
      iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude")
    )
    
    # flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
    system.time(  # 12s
      fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
    )
    system.time(  # 2s
      fl_pairwise <- light_interaction(
        fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
      )
    )
    
    # hstats: 3s total
    system.time({
      H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
      hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
      hstats_pairwise <- h2_pairwise(H, squared = FALSE, zero = FALSE)
    })
    
    # Overall statistics correspond exactly
    iml_overall$results |> filter(.interaction > 1e-6)
    #     .feature .interaction
    # 1:  latitude    0.2458269
    # 2: longitude    0.2458269
    
    fl_overall$data |> subset(value > 0, select = c(variable, value))
    #   variable  value
    # 1 latitude  0.246
    # 2 longitude 0.246
    
    hstats_overall
    # longitude  latitude 
    # 0.2458269 0.2458269 
    
    # Pairwise results match as well
    iml_pairwise$results |> filter(.interaction > 1e-6)
    #              .feature .interaction
    # 1: longitude:latitude    0.3942526
    
    fl_pairwise$data |> subset(value > 0, select = c(variable, value))
    # latitude:longitude 0.394
    
    hstats_pairwise
    # latitude:longitude 
    # 0.3942526 
    • {hstats} is about four times as fast as {flashlight}.
    • Since one often wants to study both relative and absolute H-statistics, the speed-up in practice would be about a factor of eight.
    • In multi-classification/multi-output settings with m categories, the speed-up would be m times larger still.
    • The fast approximation via quantile binning is again a factor of four faster. The difference would diminish if we calculated many pairwise or three-way H-statistics.
    • Forcing all three packages to calculate exact statistics, all results match.

    Wrap-Up

    • {hstats} is much faster than other XAI packages, at least in our use-case. This includes H-statistics, permutation importance, and partial dependence. Note that making good benchmarks is not my strength, so forgive any bias in the results.
    • The memory footprint is lower as well.
    • With multivariate output, the potential is even larger.
    • H-Statistics match other implementations.

    Try it out!

    The full R code in one piece is here.

  • It’s the interactions

    What makes a ML model a black-box? It is the interactions. Without any interactions, the ML model is additive and can be exactly described.

    Studying interaction effects of ML models is challenging. The main XAI approaches are:

    1. Looking at ICE plots, stratified PDP, and/or 2D PDP.
    2. Study vertical scatter in SHAP dependence plots, or even consider SHAP interaction values.
    3. Check partial-dependence based H-statistics introduced in Friedman and Popescu (2008), or related statistics.

    This post is mainly about the third approach. Its beauty is that we get information about all interactions. The downside: it is only as good (or bad) as the underlying partial dependence functions. And the statistics are very expensive to compute (of order n^2).
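    For orientation, here is a sketch of the pairwise statistic in my own notation (see Friedman and Popescu, 2008, for the precise definitions). With \hat F_{jk}, \hat F_j, \hat F_k denoting mean-centered two- and one-dimensional partial dependence functions of features j and k, evaluated at the observed values,

    H^2_{jk} = \frac{\sum_{i=1}^{n} \left[ \hat F_{jk}(x_{ij}, x_{ik}) - \hat F_j(x_{ij}) - \hat F_k(x_{ik}) \right]^2}{\sum_{i=1}^{n} \hat F_{jk}(x_{ij}, x_{ik})^2},

    i.e., the proportion of joint effect variability of the two features not captured by their additive combination. The overall statistic per feature is built analogously, comparing the centered prediction function with the sum of the partial dependence of that feature and of all remaining features.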

    Different R packages offer some of these H-statistics, including {iml}, {gbm}, {flashlight}, and {vivid}. They all have their limitations. This is why I wrote the new R package {hstats}:

    • It is very efficient.
    • Has a clean API. DALEX explainers and meta-learners (mlr3, Tidymodels, caret) work out-of-the-box.
    • Supports multivariate predictions, including classification models.
    • Allows calculating unnormalized H-statistics, which help to compare pairwise and three-way statistics.
    • Contains fast multivariate ICE/PDPs with optional grouping variable.

    In Python, there is the very interesting project artemis. I will write a post on it later.

    Statistics supported by {hstats}

    Besides overall, pairwise, and three-way H-statistics, a global measure of non-additivity (the proportion of prediction variability unexplained by main effects) and a measure of feature importance are available. For technical details and references, check the package pdf or GitHub.

    Classification example

    Let’s fit a probability random forest on iris species.

    library(ranger)
    library(ggplot2)
    library(hstats)
    
    v <- setdiff(colnames(iris), "Species")
    fit <- ranger(Species ~ ., data = iris, probability = TRUE, seed = 1)
    s <- hstats(fit, v = v, X = iris)  # 8 seconds run-time
    s
    # Proportion of prediction variability unexplained by main effects of v:
    #      setosa  versicolor   virginica 
    # 0.002705945 0.065629375 0.046742035
    
    plot(s, normalize = FALSE, squared = FALSE) +
      ggtitle("Unnormalized statistics") +
      scale_fill_viridis_d(begin = 0.1, end = 0.9)
    
    ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> 
      plot(center = TRUE) +
      ggtitle("Centered ICE plots")
    
    Unnormalized H-statistics, i.e., values are roughly on the scale of the predictions (here: probabilities).
    Centered ICE plots per class.

    Interpretation:

    • The features with strongest interactions are Petal Length and Petal Width. These interactions mainly affect species “virginica” and “versicolor”. The effect for “setosa” is almost additive.
    • Unnormalized pairwise statistics show that the strongest absolute interaction happens indeed between Petal Length and Petal Width.
    • The centered ICE plots show how the interaction manifests: The effect of Petal Length heavily depends on Petal Width, except for species “setosa”. Would a SHAP analysis show the same?

    DALEX example

    Here, we consider a random forest regression on “Sepal.Length”.

    library(DALEX)
    library(ranger)
    library(hstats)
    
    set.seed(1)
    
    fit <- ranger(Sepal.Length ~ ., data = iris)
    ex <- explain(fit, data = iris[-1], y = iris[, 1])
    
    s <- hstats(ex)  # 2 seconds
    s  # Non-additivity index 0.054
    plot(s)
    plot(ice(ex, v = "Sepal.Width", BY = "Petal.Width"), center = TRUE)
    
    H-statistics
    Centered ICE plot of strongest relative interactions.

    Interpretation

    • Petal Length and Width show the strongest overall associations. Since we are considering normalized statistics, we can say: “About 3.5% of prediction variability comes from interactions with Petal Length”.
    • The strongest relative pairwise interaction happens between Sepal Width and Petal Width: Again, because we study normalized H-statistics, we can say: “About 4% of total prediction variability of the two features Sepal Width and Petal Width can be attributed to their interactions.”
    • Overall, all interactions explain only about 5% of prediction variability (see text output).

    Try it out!

    The complete R script can be found here. More examples and background can be found on the Github page of the project.

  • Geographic SHAP

    Lost in Translation between R and Python 10

    This is the next article in our series “Lost in Translation between R and Python”. The aim of this series is to provide high-quality R and Python code to achieve some non-trivial tasks. If you want to learn R, check out the R tab below. Similarly, if you want to learn Python, the Python tab will be your friend.

    This post is heavily based on the new {shapviz} vignette.

    Setting

    Besides other features, a model with geographic components contains features like

    • latitude and longitude,
    • postal code, and/or
    • other features that depend on location, e.g., distance to next restaurant.

    Like any feature, the effect of a single geographic feature can be described using SHAP dependence plots. However, studying the effect of latitude (or any other location dependent feature) alone is often not very illuminating – simply due to strong interaction effects and correlations with other geographic features.

    That’s where the additivity of SHAP values comes into play: The sum of the SHAP values of all geographic components represents the total geographic effect, and this sum can be visualized as a heatmap or 3D scatterplot against latitude/longitude (or any other geographic representation).
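    As a minimal sketch (assuming a {shapviz} object sv as created further below, with SHAP matrix sv$S and feature data sv$X; the column names are taken from the Miami example):

    # Total geographic effect per observation = row sum over the geographic SHAP columns
    geo <- c("LATITUDE", "LONGITUDE")  # extend by other geographic features if present
    total_geo_effect <- rowSums(sv$S[, geo, drop = FALSE])
    
    # Shown on the color scale against the coordinates
    plot(
      sv$X[, "LONGITUDE"], sv$X[, "LATITUDE"],
      col = hcl.colors(100)[as.integer(cut(total_geo_effect, 100))],
      pch = 19, cex = 0.4, asp = 1, xlab = "Longitude", ylab = "Latitude"
    )

    In the post itself, sv_dependence2D() (R) and a matplotlib scatterplot (Python) take care of exactly this.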

    A first example

    For illustration, we will use a beautiful house price dataset containing information on about 14’000 houses sold in 2016 in Miami-Dade County. Some of the columns are as follows:

    • SALE_PRC: Sale price in USD: Its logarithm will be our model response.
    • LATITUDE, LONGITUDE: Coordinates
    • CNTR_DIST: Distance to central business district
    • OCEAN_DIST: Distance (ft) to the ocean
    • RAIL_DIST: Distance (ft) to the next railway track
    • HWY_DIST: Distance (ft) to next highway
    • TOT_LVG_AREA: Living area in square feet
    • LND_SQFOOT: Land area in square feet
    • structure_quality: Measure of building quality (1: worst to 5: best)
    • age: Age of the building in years

    (The coordinate and distance features are geographic components.) For more background on this dataset, see Mayer et al [2].

    We will fit an XGBoost model to explain log(price) as a function of lat/long, size, and quality/age.

    devtools::install_github("ModelOriented/shapviz", dependencies = TRUE)
    library(xgboost)
    library(ggplot2)
    library(shapviz)  # Needs development version 0.9.0 from github
    
    head(miami)
    
    x_coord <- c("LATITUDE", "LONGITUDE")
    x_nongeo <- c("TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age")
    x <- c(x_coord, x_nongeo)
    
    # Train/valid split
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    X_train <- data.matrix(miami[ix, x])
    X_valid <- data.matrix(miami[-ix, x])
    y_train <- log(miami$SALE_PRC[ix])
    y_valid <- log(miami$SALE_PRC[-ix])
    
    # Fit XGBoost model with early stopping
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    params <- list(learning_rate = 0.2, objective = "reg:squarederror", max_depth = 5)
    
    fit <- xgb.train(
      params = params, 
      data = dtrain, 
      watchlist = list(valid = dvalid), 
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )
    %load_ext lab_black
    
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.datasets import fetch_openml
    
    df = fetch_openml(data_id=43093, as_frame=True)
    X, y = df.data, np.log(df.target)
    X.head()
    
    # Data split and model
    from sklearn.model_selection import train_test_split
    import xgboost as xgb
    
    x_coord = ["LONGITUDE", "LATITUDE"]
    x_nongeo = ["TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age"]
    x = x_coord + x_nongeo
    
    X_train, X_valid, y_train, y_valid = train_test_split(
        X[x], y, test_size=0.2, random_state=30
    )
    
    # Fit XGBoost model with early stopping
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_valid, label=y_valid)
    
    params = dict(learning_rate=0.2, objective="reg:squarederror", max_depth=5)
    
    fit = xgb.train(
        params=params,
        dtrain=dtrain,
        evals=[(dvalid, "valid")],
        verbose_eval=100,
        early_stopping_rounds=20,
        num_boost_round=1000,
    )
    

    SHAP dependence plots

    Let’s first study selected SHAP dependence plots, evaluated on the validation dataset with around 2800 observations. Note that we could as well use the training data for this purpose, but it is a bit large.

    sv <- shapviz(fit, X_pred = X_valid)
    sv_dependence(
      sv, 
      v = c("TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"), 
      alpha = 0.2
    )
    import shap
    
    xgb_explainer = shap.Explainer(fit)
    shap_values = xgb_explainer(X_valid)
    
    v = ["TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"]
    shap.plots.scatter(shap_values[:, v], color=shap_values[:, v])
    SHAP dependence plots of selected features (Python output).

    Total coordinate effect

    And now the two-dimensional plot of the sum of SHAP values:

    sv_dependence2D(sv, x = "LONGITUDE", y = "LATITUDE") +
      coord_equal()
    shap_coord = shap_values[:, x_coord]
    plt.scatter(*list(shap_coord.data.T), c=shap_coord.values.sum(axis=1), s=4)
    ax = plt.gca()
    ax.set_aspect("equal", adjustable="box")
    plt.colorbar()
    plt.title("Total location effect")
    plt.show()
    Sum of SHAP values on color scale against coordinates (Python output).

    The last plot gives a good impression of price levels, but note:

    1. Since we have modeled logarithmic prices, the effects are on relative scale (0.1 means about 10% above average).
    2. Due to interaction effects with non-geographic components, the location effects might depend on features like living area. This is not visible in the plot above. We will modify the model now to improve this aspect.

    Two modifications

    We will now change the above model in two ways, not unlike the model in Mayer et al [2].

    1. We will use additional geographic features like distance to railway track or to the ocean.
    2. We will use interaction constraints to allow only interactions between geographic features.

    The second step leads to a model that is additive in each non-geographic component and also additive in the combined location effect. According to the technical report of Mayer [1], SHAP dependence plots of additive components in a boosted trees model are shifted versions of corresponding partial dependence plots (evaluated at observed values). This allows a “Ceteris Paribus” interpretation of SHAP dependence plots of corresponding components.

    # Extend the feature set
    more_geo <- c("CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST")
    x2 <- c(x, more_geo)
    
    X_train2 <- data.matrix(miami[ix, x2])
    X_valid2 <- data.matrix(miami[-ix, x2])
    
    dtrain2 <- xgb.DMatrix(X_train2, label = y_train)
    dvalid2 <- xgb.DMatrix(X_valid2, label = y_valid)
    
    # Build interaction constraint vector
    ic <- c(
      list(which(x2 %in% c(x_coord, more_geo)) - 1),
      as.list(which(x2 %in% x_nongeo) - 1)
    )
    
    # Modify parameters
    params$interaction_constraints <- ic
    
    fit2 <- xgb.train(
      params = params, 
      data = dtrain2, 
      watchlist = list(valid = dvalid2), 
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )
    
    # SHAP analysis
    sv2 <- shapviz(fit2, X_pred = X_valid2)
    
    # Two selected features: Thanks to additivity, structure_quality can be read as 
    # Ceteris Paribus
    sv_dependence(sv2, v = c("structure_quality", "LONGITUDE"), alpha = 0.2)
    
    # Total geographic effect (Ceteris Paribus thanks to additivity)
    sv_dependence2D(sv2, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
      coord_equal()
    # Extend the feature set
    more_geo = ["CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST"]
    x2 = x + more_geo
    
    X_train2, X_valid2 = train_test_split(X[x2], test_size=0.2, random_state=30)
    
    dtrain2 = xgb.DMatrix(X_train2, label=y_train)
    dvalid2 = xgb.DMatrix(X_valid2, label=y_valid)
    
    # Build interaction constraint vector
    ic = [x_coord + more_geo, *[[z] for z in x_nongeo]]
    
    # Modify parameters
    params["interaction_constraints"] = ic
    
    fit2 = xgb.train(
        params=params,
        dtrain=dtrain2,
        evals=[(dvalid2, "valid")],
        verbose_eval=100,
        early_stopping_rounds=20,
        num_boost_round=1000,
    )
    
    # SHAP analysis
    xgb_explainer2 = shap.Explainer(fit2)
    shap_values2 = xgb_explainer2(X_valid2)
    
    v = ["structure_quality", "LONGITUDE"]
    shap.plots.scatter(shap_values2[:, v], color=shap_values2[:, v])
    
    # Total location effect
    shap_coord2 = shap_values2[:, x_coord]
    c = shap_values2[:, x_coord + more_geo].values.sum(axis=1)
    plt.scatter(*list(shap_coord2.data.T), c=c, s=4)
    ax = plt.gca()
    ax.set_aspect("equal", adjustable="box")
    plt.colorbar()
    plt.title("Total location effect")
    plt.show()
    SHAP dependence plots of an additive feature (structure quality, no vertical scatter per unique feature value) and one of the geographic features (Python output).
    Sum of all geographic features (color) against coordinates. There are no interactions to non-geographic features, so the effect can be read Ceteris Paribus (Python output).

    Again, the resulting total geographic effect looks reasonable.

    Wrap-Up

    • SHAP values of all geographic components in a model can be summed up and plotted on the color scale against coordinates (or some other geographic representation). This gives a lightning fast impression of the location effects.
    • Interaction constraints between geographic and non-geographic features lead to Ceteris Paribus interpretation of total geographic effects.

    The Python and R notebooks can be found here:

    References

    1. Mayer, Michael. 2022. “SHAP for Additively Modeled Features in a Boosted Trees Model.” https://arxiv.org/abs/2207.14490.
    2. Mayer, Michael, Steven C. Bourassa, Martin Hoesli, and Donato Flavio Scognamiglio. 2022. “Machine Learning Applications to Land and Structure Valuation.” Journal of Risk and Financial Management.

  • SHAP + XGBoost + Tidymodels = LOVE

    In this recent post, we have explained how to use Kernel SHAP for interpreting complex linear models. As plotting backend, we used our fresh CRAN package “shapviz”.

    “shapviz” has direct connectors to a couple of packages such as XGBoost, LightGBM, H2O, kernelshap, and more. People have asked me multiple times how to use shapviz when the XGBoost model was fitted with Tidymodels. The workflow was not 100% clear to me either, but the answer is actually very simple, thanks to Julia’s post where the plots were made with SHAPforxgboost, another cool package for visualization of SHAP values.

    Example with shiny diamonds

    Step 1: Preprocessing

    We first write the data preprocessing recipe and apply it to the data rows that we want to explain. In our case, it’s 1000 randomly sampled diamonds.

    library(tidyverse)
    library(tidymodels)
    library(shapviz)
    
    # Integer encode factors
    dia_recipe <- diamonds %>%
      recipe(price ~ carat + cut + clarity + color) %>% 
      step_integer(all_nominal())
    
    # Will explain THIS dataset later
    set.seed(2)
    dia_small <- diamonds[sample(nrow(diamonds), 1000), ]
    dia_small_prep <- bake(
      prep(dia_recipe), 
      has_role("predictor"),
      new_data = dia_small, 
      composition = "matrix"
    )
    head(dia_small_prep)
    
    #     carat cut clarity color
    #[1,]  0.57   5       4     4
    #[2,]  1.01   5       2     1
    #[3,]  0.45   1       4     3
    #[4,]  1.04   4       6     5
    #[5,]  0.90   3       6     4
    #[6,]  1.20   3       4     6
    

    Step 2: Fit Model

    The next step is to tune and build the model. For simplicity, we skipped the tuning part. Bad, bad 🙂

    # Just for illustration - in practice needs tuning!
    xgboost_model <- boost_tree(
      mode = "regression",
      trees = 200,
      tree_depth = 5,
      learn_rate = 0.05,
      engine = "xgboost"
    )
    
    dia_wf <- workflow() %>%
      add_recipe(dia_recipe) %>%
      add_model(xgboost_model)
    
    fit <- dia_wf %>%
      fit(diamonds)
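
    As an aside, here is a hedged sketch of what the skipped tuning step could look like with {tune}; the grid size, resampling scheme, and tuned parameters are arbitrary illustrations, not taken from the original analysis:

    # Hypothetical tuning sketch (not run in the original post)
    xgb_spec <- boost_tree(
      mode = "regression",
      trees = tune(),
      tree_depth = tune(),
      learn_rate = tune(),
      engine = "xgboost"
    )
    
    tune_wf <- workflow() %>%
      add_recipe(dia_recipe) %>%
      add_model(xgb_spec)
    
    set.seed(1)
    tune_res <- tune_grid(
      tune_wf,
      resamples = vfold_cv(diamonds, v = 5),
      grid = 20
    )
    show_best(tune_res, metric = "rmse")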

    Step 3: SHAP Analysis

    We now need to call shapviz() on the fitted model. In order to have neat interpretations with the original factor labels, we not only pass the prediction data prepared in Step 1 via bake(), but also the original data structure.

    shap <- shapviz(extract_fit_engine(fit), X_pred = dia_small_prep, X = dia_small)
    
    sv_importance(shap, kind = "both", show_numbers = TRUE)
    sv_dependence(shap, "carat", color_var = "auto")
    sv_dependence(shap, "clarity", color_var = "auto")
    sv_force(shap, row_id = 1)
    sv_waterfall(shap, row_id = 1)
    
    Variable importance plot overlaid with SHAP summary beeswarms
    Dependence plot for carat. Note that clarity is shown with original labels, not only integers.
    Dependence plot for clarity. Note again that the x-scale uses the original factor levels, not the integer encoded values.
    Force plot of the first observation
    Waterfall plot for the first observation

    Summary

    Making SHAP analyses with XGBoost and Tidymodels is super easy.

    The complete R script can be found here.

  • Dplyr-style without dplyr

    One of the reasons why we love the “dplyr” package: it plays so well together with the forward pipe operator `%>%` from the “magrittr” package. Actually, it is no coincidence that both packages were released at about the same time, in 2014.

    What does the pipe do? It puts the object on its left as the first argument into the function on its right: iris %>% head() is a funny way of writing head(iris). It helps to avoid long function chains like f(g(h(x))), or repeated assignments.

    In 2021 and version 4.1, R has received its native forward pipe operator |> so that we can write nice code like this:
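    (The original post shows a screenshot at this point; the chain below is a small stand-in sketch of mine.)

    # A chain like this would be awkward to write as nested function calls
    iris |>
      transform(ratio = Sepal.Length / Sepal.Width) |>
      subset(Species != "setosa", select = c(Species, ratio)) |>
      head()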

    Imagine this without pipe…

    Since version 4.2, the piped object can be referenced by the underscore _, but just once for now (see an example below).

    To use the native pipe via CTRL-SHIFT-M in Posit/RStudio, tick this:

    Combined with the many great functions from the standard distribution of R, we can get a real “dplyr” feeling without even loading dplyr. Don’t get me wrong: I am a huge fan of the whole Tidyverse! But working without it once in a while is a great way to learn “Standard R”.

    Data chains

    Here is a small selection of standard functions that play well together with the pipe: They take a data frame and return a modified data frame:

    • subset(): Select rows and columns of data frame
    • transform(): Add or overwrite columns in data frame
    • aggregate(): Grouped calculations
    • rbind(), cbind(): Bind rows/columns of data frame/matrix
    • merge(): Join data frames by key
    • head(), tail(): First/last few elements of object
    • reshape(): Transposition/Reshaping of data frame (no, I don’t understand the interface)
    library(ggplot2)  # Need diamonds
    
    # What does the native pipe do?
    quote(diamonds |> head())
    
    # OUTPUT
    # head(diamonds)
    
    # Grouped statistics
    diamonds |> 
      aggregate(cbind(price, carat) ~ color, FUN = mean)
    
    # OUTPUT
    #   color    price     carat
    # 1     D 3169.954 0.6577948
    # 2     E 3076.752 0.6578667
    # 3     F 3724.886 0.7365385
    # 4     G 3999.136 0.7711902
    # 5     H 4486.669 0.9117991
    # 6     I 5091.875 1.0269273
    # 7     J 5323.818 1.1621368
    
    # Join back grouped stats to relevant columns
    diamonds |> 
      subset(select = c(price, color, carat)) |> 
      transform(price_per_color = ave(price, color)) |> 
      head()
    
    # OUTPUT
    #   price color carat price_per_color
    # 1   326     E  0.23        3076.752
    # 2   326     E  0.21        3076.752
    # 3   327     E  0.23        3076.752
    # 4   334     I  0.29        5091.875
    # 5   335     J  0.31        5323.818
    # 6   336     J  0.24        5323.818
    
    # Plot transformed values
    diamonds |> 
      transform(
        log_price = log(price),
        log_carat = log(carat)
      ) |> 
      plot(log_price ~ log_carat, col = "chartreuse4", pch = ".", data = _)
    A simple scatterplot

    The plot does not look quite as sexy as “ggplot2”, but it’s a start.
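
    The merge() function from the list above also chains nicely. A minimal sketch, joining a small made-up lookup table of cut ranks to the diamonds data:

    # Hypothetical lookup table, joined to diamonds via merge()
    data.frame(cut = levels(diamonds$cut), cut_rank = 1:5) |>
      merge(x = diamonds, by = "cut") |>
      head()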

    Other chains

    The pipe not only works perfectly with functions that modify a data frame. It also shines with many other functions often applied in a nested way. Here are two examples:

    # Distribution of color within clarity
    diamonds |> 
      subset(select = c(color, clarity)) |> 
      table() |> 
      prop.table(margin = 2) |> 
      addmargins(margin = 1) |> 
      round(3)
    
    # OUTPUT
    # clarity
    # color      I1   SI2   SI1   VS2   VS1  VVS2  VVS1    IF
    #     D   0.057 0.149 0.159 0.138 0.086 0.109 0.069 0.041
    #     E   0.138 0.186 0.186 0.202 0.157 0.196 0.179 0.088
    #     F   0.193 0.175 0.163 0.180 0.167 0.192 0.201 0.215
    #     G   0.202 0.168 0.151 0.191 0.263 0.285 0.273 0.380
    #     H   0.219 0.170 0.174 0.134 0.143 0.120 0.160 0.167
    #     I   0.124 0.099 0.109 0.095 0.118 0.072 0.097 0.080
    #     J   0.067 0.052 0.057 0.060 0.066 0.026 0.020 0.028
    #     Sum 1.000 1.000 1.000 1.000 1.000 1.000 1.000 1.000
    
    # Barplot from discrete column
    diamonds$color |> 
      table() |> 
      prop.table() |> 
      barplot(col = "chartreuse4", main = "Color")

    Wrap up

    • Piping is fun with and without dplyr.
    • It is a great motivation to learn standard R.

    The complete R script can be found here.

  • Interpret Complex Linear Models with SHAP within Seconds

    A linear model with complex interaction effects can be almost as opaque as a typical black-box like XGBoost.

    XGBoost models are often interpreted with SHAP (Shapley Additive eXplanations): Each of e.g. 1000 randomly selected predictions is fairly decomposed into contributions of the features using the extremely fast TreeSHAP algorithm, providing a rich interpretation of the model as a whole. TreeSHAP was introduced in the Nature Machine Intelligence publication by Lundberg et al. (2020).

    Can we do the same for non-tree-based models like a complex GLM or a neural network? Yes, but we have to resort to slower model-agnostic SHAP algorithms: Monte-Carlo sampling (SHAP sampling values) or Kernel SHAP.

    In the limit, the two algorithms provide the same SHAP values.

    House prices

    We will use a great dataset with 14’000 house prices sold in Miami in 2016. The dataset was kindly provided by Prof. Steven Bourassa for research purposes and can be found on OpenML.

    The model

    We will model house prices by a Gamma regression with log-link. The model includes factors, linear components and natural cubic splines. The relationship of living area and distance to central district is modeled by letting the spline bases of the two features interact.

    library(OpenML)
    library(tidyverse)
    library(splines)
    library(doFuture)
    library(kernelshap)
    library(shapviz)
    
    raw <- OpenML::getOMLDataSet(43093)$data
    
    # Lump rare level 3 and log transform the land size
    prep <- raw %>%
      mutate(
        structure_quality = factor(structure_quality, labels = c(1, 2, 4, 4, 5)),
        log_landsize = log(LND_SQFOOT)
      )
    
    # 1) Build model
    xvars <- c("TOT_LVG_AREA", "log_landsize", "structure_quality",
               "CNTR_DIST", "age", "month_sold")
    
    fit <- glm(
      SALE_PRC ~ ns(log(CNTR_DIST), df = 4) * ns(log(TOT_LVG_AREA), df = 4) +
        log_landsize + structure_quality + ns(age, df = 4) + ns(month_sold, df = 4),
      family = Gamma("log"),
      data = prep
    )
    summary(fit)
    
    # Selected coefficients:
    # log_landsize: 0.22559  
    # structure_quality4: 0.63517305 
    # structure_quality5: 0.85360956   
    

    The model has 37 parameters. Some of the estimates are shown.

    Interpretation

    The workflow of a SHAP analysis is as follows:

    1. Sample 1000 rows to explain
    2. Sample 100 rows as background data used to estimate marginal expectations
    3. Calculate SHAP values. This can be done fully in parallel by looping over the rows selected in Step 1
    4. Analyze the SHAP values

    Step 2 is the only additional step compared with TreeSHAP. It is required both for SHAP sampling values and Kernel SHAP.

    # 1) Select rows to explain
    set.seed(1)
    X <- prep[sample(nrow(prep), 1000), xvars]
    
    # 2) Select small representative background data
    bg_X <- prep[sample(nrow(prep), 100), ]
    
    # 3) Calculate SHAP values in fully parallel mode
    registerDoFuture()
    plan(multisession, workers = 6)  # Windows
    # plan(multicore, workers = 6)   # Linux, macOS, Solaris
    
    system.time( # <10 seconds
      shap_values <- kernelshap(
        fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "splines")
      )
    )

    Thanks to parallel processing and some implementation tricks, we were able to decompose 1000 predictions within 10 seconds! By default, kernelshap() uses exact calculations up to eight features (exact regarding the background data), which would otherwise require an infinite number of Monte-Carlo sampling steps.

    Note that glm() has a very efficient predict() function. GAMs, neural networks, random forests etc. usually take more time, e.g. 5 minutes to do the crunching.

    Analyze the SHAP values

    # 4) Analyze them
    sv <- shapviz(shap_values)
    
    sv_importance(sv, show_numbers = TRUE) +
      ggtitle("SHAP Feature Importance")
    
    sv_dependence(sv, "log_landsize")
    sv_dependence(sv, "structure_quality")
    sv_dependence(sv, "age")
    sv_dependence(sv, "month_sold")
    sv_dependence(sv, "TOT_LVG_AREA", color_var = "auto")
    sv_dependence(sv, "CNTR_DIST", color_var = "auto")
    
    # Slope of log_landsize: 0.2255946
    diff(sv$S[1:2, "log_landsize"]) / diff(sv$X[1:2, "log_landsize"])
    
    # Difference between structure quality 4 and 5: 0.2184365
    diff(sv$S[2:3, "structure_quality"])
    SHAP Importance: Living area and the distance to the central district are the two most important predictors. The month (within 2016) impacts the predicted prices by +-1.3% on average.
    SHAP dependence plot of “log_landsize”. The effect is linear. The slope 0.22559 agrees with the model coefficient.
    Dependence plot for “structure_quality”: The difference between structure quality 4 and 5 is 0.2184365. This equals the difference in regression coefficients.
    Dependence plot of “living_area”: The effect is very steep. The more central, the steeper. We cannot easily compare these numbers with the output of the linear regression.

    Summary

    • Interpreting complex linear models with SHAP is an option. There seems to be a correspondence between regression coefficients and SHAP dependence, at least for additive components.
    • Kernel SHAP in R is fast. For models with slower predict() functions (e.g. GAMs, random forests, or neural nets), we often need to wait a couple of minutes.

    The complete R script can be found here.

  • Kernel SHAP in R and Python

    Lost in Translation between R and Python 9

    This is the next article in our series “Lost in Translation between R and Python”. The aim of this series is to provide high-quality R and Python code to achieve some non-trivial tasks. If you want to learn R, check out the R tab below. Similarly, if you want to learn Python, the Python tab will be your friend.

    Kernel SHAP

    SHAP is one of the most widely used model interpretation techniques in Machine Learning. It decomposes predictions into additive contributions of the features in a fair way. For tree-based methods, the fast TreeSHAP algorithm exists. For general models, one has to resort to computationally expensive Monte-Carlo sampling or the faster Kernel SHAP algorithm. Kernel SHAP uses a regression trick to get the SHAP values of an observation with a comparably small number of calls to the predict function of the model. Still, it is much slower than TreeSHAP.

    Two good references for Kernel SHAP:

    1. Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30, 2017.
    2. Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.
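
    For orientation, the Shapley kernel weights used in this regression trick are, following [1] (notation slightly adapted; z' is a binary on-off vector of length M, the number of features, and |z'| its number of ones):

    \pi_x(z') = \frac{M - 1}{\binom{M}{|z'|} \; |z'| \; (M - |z'|)}

    The all-zero and all-one vectors receive infinite weight and are therefore typically handled as constraints in the weighted least-squares problem.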

    In our last post, we introduced our new “kernelshap” package in R. Since then, the package has been substantially improved, not least thanks to the great help of David Watson:

    1. The package now supports multi-dimensional predictions.
    2. It received a massive speed-up.
    3. Additionally, parallel computing can be activated for even faster calculations.
    4. The interface has become more intuitive.
    5. If the number of features is small (up to ten or eleven), it can provide exact Kernel SHAP values just like the reference Python implementation.
    6. For a larger number of features, it now uses partly-exact (“hybrid”) calculations, very similar to the logic in the Python implementation.

    With these changes, the R implementation is closing in on the Python version.

    Example with four features

    In the following, we use the diamonds data to fit a linear regression with

    • log(price) as response
    • log(carat) as numeric feature
    • clarity, color and cut as categorical features (internally dummy encoded)
    • interactions between log(carat) and the other three “C” variables. Note that the interactions are very weak.

    Then, we calculate SHAP decompositions for about 1000 diamonds (every 53rd diamond), using 120 diamonds as background dataset. In this case, both R and Python will use exact calculations based on m = 2^4 − 2 = 14 possible binary on-off vectors (a value of 1 representing a feature value picked from the original observation, a value of 0 a value picked from the background data).
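    To make the counting concrete, here is a small sketch (not from the original post) enumerating these on-off vectors for p = 4 features:

    # All binary vectors of length 4, without the all-zero and the all-one vector
    p <- 4
    Z <- as.matrix(expand.grid(rep(list(0:1), p)))
    Z <- Z[rowSums(Z) %in% 1:(p - 1), ]
    nrow(Z)  # 14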

    library(ggplot2)
    library(kernelshap)
    
    # Turn ordinal factors into unordered
    ord <- c("clarity", "color", "cut")
    diamonds[, ord] <- lapply(diamonds[ord], factor, ordered = FALSE)
    
    # Fit model
    fit <- lm(log(price) ~ log(carat) * (clarity + color + cut), data = diamonds)
    
    # Subset of 120 diamonds used as background data
    bg_X <- diamonds[seq(1, nrow(diamonds), 450), ]
    
    # Subset of 1018 diamonds to explain
    X_small <- diamonds[seq(1, nrow(diamonds), 53), c("carat", ord)]
    
    # Exact KernelSHAP (5 seconds)
    system.time(
      ks <- kernelshap(fit, X_small, bg_X = bg_X)  
    )
    ks
    
    # SHAP values of first 2 observations:
    #          carat     clarity     color        cut
    # [1,] -2.050074 -0.28048747 0.1281222 0.01587382
    # [2,] -2.085838  0.04050415 0.1283010 0.03731644
    
    # Using parallel backend
    library("doFuture")
    
    registerDoFuture()
    plan(multisession, workers = 2)  # Windows
    # plan(multicore, workers = 2)   # Linux, macOS, Solaris
    
    # 3 seconds on second call
    system.time(
      ks3 <- kernelshap(fit, X_small, bg_X = bg_X, parallel = TRUE)  
    )
    
    # Visualization
    library(shapviz)
    
    sv <- shapviz(ks)
    sv_importance(sv, "bee")
    import numpy as np
    import pandas as pd
    from plotnine.data import diamonds
    from statsmodels.formula.api import ols
    from shap import KernelExplainer
    
    # Turn categoricals into integers because, inconveniently, kernel SHAP
    # requires numpy array as input
    ord = ["clarity", "color", "cut"]
    x = ["carat"] + ord
    diamonds[ord] = diamonds[ord].apply(lambda x: x.cat.codes)
    X = diamonds[x].to_numpy()
    
    # Fit model with interactions and dummy variables
    fit = ols(
      "np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", 
      data=diamonds
    ).fit()
    
    # Background data (120 rows)
    bg_X = X[0:len(X):450]
    
    # Define subset of 1018 diamonds to explain
    X_small = X[0:len(X):53]
    
    # Calculate KernelSHAP values
    ks = KernelExplainer(
      model=lambda X: fit.predict(pd.DataFrame(X, columns=x)), 
      data = bg_X
    )
    sv = ks.shap_values(X_small)  # 74 seconds
    sv[0:2]
    
    # array([[-2.05007406, -0.28048747,  0.12812216,  0.01587382],
    #        [-2.0858379 ,  0.04050415,  0.12830103,  0.03731644]])
    SHAP summary plot (R model)

    The results match, hurray!

    Example with nine features

    The computation effort of running exact Kernel SHAP explodes with the number of features. For nine features, the number of relevant on-off vectors is 2^9 – 2 = 510, i.e. about 36 times larger than with four features.

    We now modify the above example, adding five additional features to the model. Note that the model structure is completely nonsensical. We just use it to get a feeling for the impact of a 36 times larger workload.

    Besides exact calculations, we also use an almost exact hybrid approach in both R and Python, based on 126 on-off vectors: p(p+1) = 90 for the exact part and 4p = 36 for the sampling part, where p = 9 is the number of features. This results in a significant speed-up.

    fit <- lm(
      log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth, 
      data = diamonds
    )
    
    # Subset of 1018 diamonds to explain
    X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")]
    
    # Exact Kernel SHAP: 61 seconds
    system.time(
      ks <- kernelshap(fit, X_small, bg_X = bg_X, exact = TRUE)  
    )
    ks
    #          carat        cut     color     clarity         depth         table          x           y            z
    # [1,] -1.842799 0.01424231 0.1266108 -0.27033874 -0.0007084443  0.0017787647 -0.1720782 0.001330275 -0.006445693
    # [2,] -1.876709 0.03856957 0.1266546  0.03932912 -0.0004202636 -0.0004871776 -0.1739880 0.001397792 -0.006560624
    
    # Default, using an almost exact hybrid algorithm: 17 seconds
    system.time(
      ks <- kernelshap(fit, X_small, bg_X = bg_X, parallel = TRUE)  
    )
    #          carat        cut     color     clarity         depth         table          x           y            z
    # [1,] -1.842799 0.01424231 0.1266108 -0.27033874 -0.0007084443  0.0017787647 -0.1720782 0.001330275 -0.006445693
    # [2,] -1.876709 0.03856957 0.1266546  0.03932912 -0.0004202636 -0.0004871776 -0.1739880 0.001397792 -0.006560624
    x = ["carat"] + ord + ["table", "depth", "x", "y", "z"]
    X = diamonds[x].to_numpy()
    
    # Fit model with interactions and dummy variables
    fit = ols(
      "np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color)) + table + depth + x + y + z", 
      data=diamonds
    ).fit()
    
    # Background data (120 rows)
    bg_X = X[0:len(X):450]
    
    # Define subset of 1018 diamonds to explain
    X_small = X[0:len(X):53]
    
    # Calculate KernelSHAP values: 12 minutes
    ks = KernelExplainer(
      model=lambda X: fit.predict(pd.DataFrame(X, columns=x)), 
      data = bg_X
    )
    sv = ks.shap_values(X_small)
    sv[0:2]
    # array([[-1.84279897e+00, -2.70338744e-01,  1.26610769e-01,
    #          1.42423108e-02,  1.77876470e-03, -7.08444295e-04,
    #         -1.72078182e-01,  1.33027467e-03, -6.44569296e-03],
    #        [-1.87670887e+00,  3.93291219e-02,  1.26654599e-01,
    #          3.85695742e-02, -4.87177593e-04, -4.20263565e-04,
    #         -1.73988040e-01,  1.39779179e-03, -6.56062359e-03]])
    
    # Now, using a hybrid between exact and sampling: 5 minutes
    sv = ks.shap_values(X_small, nsamples=126)
    sv[0:2]
    # array([[-1.84279897e+00, -2.70338744e-01,  1.26610769e-01,
    #          1.42423108e-02,  1.77876470e-03, -7.08444295e-04,
    #         -1.72078182e-01,  1.33027467e-03, -6.44569296e-03],
    #        [-1.87670887e+00,  3.93291219e-02,  1.26654599e-01,
    #          3.85695742e-02, -4.87177593e-04, -4.20263565e-04,
    #         -1.73988040e-01,  1.39779179e-03, -6.56062359e-03]])

    Again, the results are essentially the same between R and Python, but also between the hybrid algorithm and the exact algorithm. This is interesting, because the hybrid algorithm is significantly faster than the exact one.

    Wrap-Up

    • R is catching up with Python’s superb “shap” package.
    • For two non-trivial linear regressions with interactions, the “kernelshap” package in R provides the same output as Python.
    • The hybrid between exact and sampling KernelSHAP (as implemented in Python and R) offers a very good trade-off between speed and accuracy.
    • kernelshap() in R is fast!

    The Python and R codes can be found here:

    The examples were run on a Windows notebook with an Intel i7-8650U 4 core CPU.