Category: Programming

  • Effect Plots in Python and R

Christian and I did some code magic: highly effective plots that help to build and inspect any model.

    The functionality is best described by its output:

Example output of the Python and R implementations.

The plots show different types of feature effects relevant in modeling (a conceptual sketch of how such statistics can be computed follows the list):

    • Average observed: Descriptive effect (also interesting without model).
    • Average predicted: Combined effect of all features. Also called “M Plot” (Apley 2020).
    • Partial dependence: Effect of one feature, keeping other feature values constant (Friedman 2001).
    • Number of observations or sum of case weights: Feature value distribution.
    • R only: Accumulated local effects, an alternative to partial dependence (Apley 2020).
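For intuition only (this is not the packages’ actual implementation; the function below is a made-up sketch using pandas), the first three statistics plus the bin counts can be computed by hand for one feature:

import pandas as pd

def feature_effects_1d(model, X, y, feature, breaks=10):
    """Average observed, average predicted, counts, and partial dependence per feature bin."""
    bins = pd.cut(X[feature], bins=breaks)
    out = (
        pd.DataFrame({"bin": bins, "y": y, "pred": model.predict(X)})
        .groupby("bin", observed=True)
        .agg(avg_observed=("y", "mean"), avg_predicted=("pred", "mean"), n=("y", "size"))
    )
    # Partial dependence: fix the feature at each bin midpoint for all rows, then average
    pd_values = []
    for interval in out.index:
        X_mod = X.copy()
        X_mod[feature] = interval.mid
        pd_values.append(model.predict(X_mod).mean())
    out["partial_dependence"] = pd_values
    return out

The real packages additionally handle case weights, standard errors, and discrete features, and do all of this much more efficiently.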

    Both implementations…

    • are highly efficient thanks to {Polars} in Python and {collapse} in R, and work on datasets with millions of observations,
    • support case weights with all their statistics, ideal in insurance applications,
    • calculate average residuals (not shown in the plots above),
    • provide standard deviations/errors of average observed and bias,
• allow switching to Plotly for interactive plots, and
• are highly customizable (the R package, e.g., allows collapsing rare levels after calculating statistics via the update() method, or sorting the features by main effect importance).

    In the spirit of our “Lost In Translation” series, we provide both high-quality Python and R code. We will use the same data and models as in one of our latest posts on how to build strong GLMs via ML + XAI.

    Example

Let’s build a Poisson LightGBM model to explain the claim frequency given six traditional features in a pricing dataset on motor liability claims. 80% of the 1 million rows are used for training, the other 20% for evaluation. Hyperparameters have been slightly tuned (not shown).

    library(OpenML)
    library(lightgbm)
    
    dim(df <- getOMLDataSet(data.id = 45106L)$data)  # 1000000 7
    head(df)
    
    #   year town driver_age car_weight car_power car_age claim_nb
    # 0 2018    1         51       1760       173       3        0
    # 1 2019    1         41       1760       248       2        0
    # 2 2018    1         25       1240       111       2        0
    # 3 2019    0         40       1010        83       9        0
    # 4 2018    0         43       2180       169       5        0
    # 5 2018    1         45       1170       149       1        1
    
    yvar <- "claim_nb"
    xvars <- setdiff(colnames(df), yvar)
    
    ix <- 1:800000
    train <- df[ix, ]
    test <- df[-ix, ]
    X_train <- data.matrix(train[xvars])
    X_test <- data.matrix(test[xvars])
    
    # Training, using slightly optimized parameters found via cross-validation
    params <- list(
      learning_rate = 0.05,
      objective = "poisson",
      num_leaves = 7,
      min_data_in_leaf = 50,
      min_sum_hessian_in_leaf = 0.001,
      colsample_bynode = 0.8,
      bagging_fraction = 0.8,
      lambda_l1 = 3,
      lambda_l2 = 5,
      num_threads = 7
    )
    
    set.seed(1)
    
    fit <- lgb.train(
      params = params,
      data = lgb.Dataset(X_train, label = train$claim_nb),
      nrounds = 300
    )
    import matplotlib.pyplot as plt
    from lightgbm import LGBMRegressor
    from sklearn.datasets import fetch_openml
    
    df = fetch_openml(data_id=45106, parser="pandas").frame
    df.head()
    
    #   year town driver_age car_weight car_power car_age claim_nb
    # 0 2018    1         51       1760       173       3        0
    # 1 2019    1         41       1760       248       2        0
    # 2 2018    1         25       1240       111       2        0
    # 3 2019    0         40       1010        83       9        0
    # 4 2018    0         43       2180       169       5        0
    # 5 2018    1         45       1170       149       1        1
    
    # Train model on 80% of the data
    y = df.pop("claim_nb")
    n_train = 800_000
    X_train, y_train = df.iloc[:n_train], y.iloc[:n_train]
    X_test, y_test = df.iloc[n_train:], y.iloc[n_train:]
    
    params = {
        "learning_rate": 0.05,
        "objective": "poisson",
        "num_leaves": 7,
        "min_child_samples": 50,
        "min_child_weight": 0.001,
        "colsample_bynode": 0.8,
        "subsample": 0.8,
        "reg_alpha": 3,
        "reg_lambda": 5,
        "verbose": -1,
    }
    
    model = LGBMRegressor(n_estimators=300, **params, random_state=1)
    model.fit(X_train, y_train)

Let’s inspect the main effects of the model on the test data.

    library(effectplots)
    
    # 0.3 s
    feature_effects(fit, v = xvars, data = X_test, y = test$claim_nb) |>
      plot(share_y = "all")
    from model_diagnostics.calibration import plot_marginal
    
    fig, axes = plt.subplots(3, 2, figsize=(8, 8), sharey=True, layout="tight")
    
    # 2.3 s
    for i, (feat, ax) in enumerate(zip(X_test.columns, axes.flatten())):
        plot_marginal(
            y_obs=y_test,
            y_pred=model.predict(X_test),
            X=X_test,
            feature_name=feat,
            predict_function=model.predict,
            ax=ax,
        )
        ax.set_title(feat)
        if i != 1:
            ax.legend().remove()

    The output can be seen at the beginning of this blog post.

Here are some model insights:

    • Average predictions closely match observed frequencies. No clear bias is visible.
• Partial dependence shows that year and car_weight have almost no impact (regarding their main effects), while the driver_age and car_power effects seem strongest. The shared y axes help to assess this.
• Except for car_weight, the partial dependence curves closely follow the average predictions. This means that the model effect really seems to come from the feature on the x axis, and not from some other, correlated feature (as, e.g., with car_weight, which is actually strongly correlated with car_power).

    Final words

• Inspecting models has become much more relaxed with the above functions.
    • The packages used offer much more functionality. Try them out! Or we will show them in later posts ;).

    References

    1. Apley, Daniel W., and Jingyu Zhu. 2020. Visualizing the Effects of Predictor Variables in Black Box Supervised Learning Models. Journal of the Royal Statistical Society Series B: Statistical Methodology, 82 (4): 1059–1086. doi:10.1111/rssb.12377.
    2. Friedman, Jerome H. 2001. Greedy Function Approximation: A Gradient Boosting Machine. Annals of Statistics 29 (5): 1189–1232. doi:10.1214/aos/1013203451.

R script, Python notebook

  • Explaining a Causal Forest

We use a causal forest [1] to model the treatment effect in a randomized controlled clinical trial. Then, we explain this black-box model with the usual explainability tools. These will reveal segments where the treatment works better or worse, just like a forest plot, but multivariately.

    Data

    For illustration, we use patient-level data of a 2-arm trial of rectal indomethacin against placebo to prevent post-ERCP pancreatitis (602 patients) [2]. The dataset is available in the package {medicaldata}.

    The data is in fantastic shape, so we don’t need to spend a lot of time with data preparation.

    1. We integer encode factors.
    2. We select meaningful features, basically those shown in the forest plot of [2] (Figure 4) without low-information features and without hospital.

    The marginal estimate of the treatment effect is -0.078, i.e., indomethacin reduces the probability of post-ERCP pancreatitis by 7.8 percentage points. Our aim is to develop and interpret a model to see if this value is associated with certain covariates.

    library(medicaldata)
    suppressPackageStartupMessages(library(dplyr))
    library(grf)          #  causal_forest()
    library(ggplot2)
    library(patchwork)    #  Combine ggplots
    library(hstats)       #  Friedman's H, PDP
    library(kernelshap)   #  General SHAP
    library(shapviz)      #  SHAP plots
    
    W <- as.integer(indo_rct$rx) - 1L      # 0=placebo, 1=treatment
    table(W)
    #   0   1 
    # 307 295
    
    Y <- as.numeric(indo_rct$outcome) - 1  # Y=1: post-ERCP pancreatitis (bad)
    mean(Y)  # 0.1312292
    
    mean(Y[W == 1]) - mean(Y[W == 0])      # -0.07785568
    
    xvars <- c(
      "age",         # Age in years
      "male",        # Male (1=yes)
      "pep",         # Previous post-ERCP pancreatitis (1=yes)
      "recpanc",     # History of recurrent Pancreatitis (1=yes)
      "type",        # Sphincter of oddi dysfunction type/level (0=no, to 3=type 3)
      "difcan",      # Cannulation of the papilla was difficult (1=yes)
      "psphinc",     # Pancreatic sphincterotomy performed (1=yes)
      "bsphinc",     # Biliary sphincterotomy performed (1=yes)
      "pdstent",     # Pancreatic stent (1=yes)
      "train"        # Trainee involved in stenting (1=yes)
    )
    
    X <- indo_rct |>
      mutate_if(is.factor, function(v) as.integer(v) - 1L) |> 
      rename(male = gender) |> 
      select_at(xvars)
    
    head(X)
                
    # age  male   pep recpanc  type difcan psphinc bsphinc pdstent train
    #  26     0     0       1     1      0       0       0       0     1
    #  24     1     1       0     0      0       0       1       0     0
    #  57     0     0       0     2      0       0       0       0     0
    #  29     0     0       0     1      0       0       1       1     1
    #  38     0     1       0     1      0       1       1       1     1
    #  59     0     0       0     1      1       0       1       1     0
                
    summary(X)
                
    #     age             male             pep            recpanc     
    # Min.   :19.00   Min.   :0.0000   Min.   :0.0000   Min.   :0.000  
    # 1st Qu.:35.00   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.000  
    # Median :45.00   Median :0.0000   Median :0.0000   Median :0.000  
    # Mean   :45.27   Mean   :0.2093   Mean   :0.1595   Mean   :0.299  
    # 3rd Qu.:54.00   3rd Qu.:0.0000   3rd Qu.:0.0000   3rd Qu.:1.000  
    # Max.   :90.00   Max.   :1.0000   Max.   :1.0000   Max.   :1.000  
    #      type           difcan          psphinc          bsphinc      
    # Min.   :0.000   Min.   :0.0000   Min.   :0.0000   Min.   :0.0000  
    # 1st Qu.:1.000   1st Qu.:0.0000   1st Qu.:0.0000   1st Qu.:0.0000  
    # Median :2.000   Median :0.0000   Median :1.0000   Median :1.0000  
    # Mean   :1.743   Mean   :0.2608   Mean   :0.5698   Mean   :0.5714  
    # 3rd Qu.:2.000   3rd Qu.:1.0000   3rd Qu.:1.0000   3rd Qu.:1.0000  
    # Max.   :3.000   Max.   :1.0000   Max.   :1.0000   Max.   :1.0000  
    #    pdstent           train       
    # Min.   :0.0000   Min.   :0.0000  
    # 1st Qu.:1.0000   1st Qu.:0.0000  
    # Median :1.0000   Median :0.0000  
    # Mean   :0.8239   Mean   :0.4701  
    # 3rd Qu.:1.0000   3rd Qu.:1.0000  
    # Max.   :1.0000   Max.   :1.0000  

    The model

    We use the {grf} package to fit a causal forest [1], a tree-ensemble trying to estimate conditional average treatment effects (CATE) E[Y(1) – Y(0) | X = x]. As such, it can be used to study treatment effect inhomogeneity.

    In contrast to a typical random forest:

• Honest trees are grown: within each tree, part of the data is used to find the splits, and the other part to calculate the node values. This anti-overfitting measure is implemented for all forests in {grf}.
    • Splits are selected to produce child nodes with maximally different treatment effects (under some additional constraints).

Note: With about 13%, the complication rate is relatively low. Thus, the treatment effect (measured on the absolute scale) can become small for certain segments simply because the complication rate is close to 0. Ideally, we could model relative treatment effects or odds ratios, but I have not found this option in {grf} so far.

    fit <- causal_forest(
      X = X,
      Y = Y,
      W = W,
      num.trees = 1000,
      mtry = 4,
      sample.fraction = 0.7,
      seed = 1,
  ci.group.size = 1
)

    Explain the model with “classic” techniques

    After looking at tree split importance, we study the effects via partial dependence plots and Friedman’s H. These only require a predict() function and a reference dataset.

    imp <- sort(setNames(variable_importance(fit), xvars))
    par(mai = c(0.7, 2, 0.2, 0.2))
    barplot(imp, horiz = TRUE, las = 1, col = "orange")
    
    pred_fun <- function(object, newdata, ...) {
      predict(object, newdata, ...)$predictions
    }
    
    pdps <- lapply(xvars, function(v) plot(partial_dep(fit, v, X = X, pred_fun = pred_fun)))
    wrap_plots(pdps, guides = "collect", ncol = 3) &
      ylim(c(-0.11, -0.06)) &
      ylab("Treatment effect")
                   
    H <- hstats(fit, X = X, pred_fun = pred_fun, verbose = FALSE)
    plot(H)
    
    partial_dep(fit, v = "age", X = X, BY = "bsphinc", pred_fun = pred_fun) |> 
      plot()

    Variable importance

Variable importance of the causal forest can be measured by the relative number of times each feature was used for splitting (within the first 4 tree levels). The most important variable is age.

    Main effects

To study the main effects on the CATE, we consider partial dependence plots (PDP). Such a plot shows how the average prediction depends on the values of one feature, keeping all other feature values constant (which can be unnatural).
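In formula form (Friedman 2001), the partial dependence of feature j at value v is the average prediction with that feature fixed at v,

\begin{equation*}
\mathrm{PD}_j(v) = \frac{1}{n} \sum_{i=1}^{n} \hat{f}(v, x_{i, \setminus j}),
\end{equation*}

where x_{i, \setminus j} denotes the values of all other features of observation i, and \hat{f} is here the CATE prediction of the causal forest.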

We can see that the treatment effect is strongest for persons up to age 35 and then weakens until about age 45. For older patients, the effect increases again.

    Remember: Negative values mean a stronger (positive) treatment effect.

    Interaction strength

    Between what covariates are there strong interactions?

    A model agnostic way to assess pairwise interaction strength is Friedman’s H statistic [3]. It measures the error when approximating the two-dimensional partial dependence function of the two features by their univariate partial dependence functions. A value of zero means there is no interaction. A value of α means that about 100α% of the joint effect (variability) comes from the interaction.
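Roughly, for a feature pair (j, k) and with all partial dependence functions mean-centered, the pairwise statistic is

\begin{equation*}
H^2_{jk} = \frac{\sum_{i=1}^{n} \left[ \mathrm{PD}_{jk}(x_{ij}, x_{ik}) - \mathrm{PD}_j(x_{ij}) - \mathrm{PD}_k(x_{ik}) \right]^2}{\sum_{i=1}^{n} \mathrm{PD}_{jk}(x_{ij}, x_{ik})^2}.
\end{equation*}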

    This measure is shown on the right hand side of the plot. More than 15% of the joint effect variability of age and biliary sphincterotomy (bsphinc) comes from their interaction.

    Typically, pairwise H-statistics are calculated only for the most important variables or those with high overall interaction strength. Overall interaction strength (left hand side of the plot) can be measured by a version of Friedman’s H. It shows how much of the prediction variability comes from interactions with that feature.

    Visualize strong interaction

    Interactions can be visualized, e.g., by a stratified PDP. We can see that the treatment effect is associated with age mainly for persons with biliary sphincterotomy.

    SHAP Analysis

    A “modern” way to explain the model is based on SHAP [4]. It decomposes the (centered) predictions into additive contributions of the covariates.
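In formula form, the decomposition of a single (centered) prediction reads

\begin{equation*}
\hat{f}(x) = \phi_0 + \sum_{j=1}^{p} \phi_j(x),
\end{equation*}

where \phi_0 is the baseline (the average prediction over the background data) and \phi_j(x) is the SHAP value of covariate j for this row.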

    Because there is no TreeSHAP shipped with {grf}, we use the much slower Kernel SHAP algorithm implemented in {kernelshap} that works for any model.

    First, we explain the prediction of a single data row, then we decompose many predictions. These decompositions can be analysed by simple descriptive plots to gain insights about the model as a whole.

    # Explaining one CATE
    kernelshap(fit, X = X[1, ], bg_X = X, pred_fun = pred_fun) |> 
      shapviz() |> 
      sv_waterfall() +
      xlab("Prediction")
    
    # Explaining all CATEs globally
    system.time(  # 13 min
      ks <- kernelshap(fit, X = X, pred_fun = pred_fun)  
    )
    shap_values <- shapviz(ks)
    
    sv_importance(shap_values)
    sv_importance(shap_values, kind = "bee")
    sv_dependence(shap_values, v = xvars) +
      plot_layout(ncol = 3) &
      ylim(c(-0.04, 0.03))

    Explain one CATE

    Explaining the CATE corresponding to the feature values of the first patient via waterfall plot.

    SHAP importance plot

The bars show average absolute SHAP values. For instance, we can say that biliary sphincterotomy impacts the treatment effect on average by more than ±0.01 (but we don’t see how).

    SHAP summary plot

    One-dimensional plot of SHAP values with scaled feature values on the color scale, sorted in the same order as the SHAP importance plot. Compared to the SHAP importance barplot, for instance, we can additionally see that biliary sphincterotomy weakens the treatment effect (positive SHAP value).

    SHAP dependence plots

    Scatterplots of SHAP values against corresponding feature values. Vertical scatter (at given x value) indicates presence of interactions. A candidate of an interacting feature is selected on the color scale. For instance, we see a similar pattern in the age effect on the treatment effect as in the partial dependence plot. Thanks to the color scale, we also see that the age effect depends on biliary sphincterotomy.

    Remember that SHAP values are on centered prediction scale. Still, a positive value means a weaker treatment effect.

    Wrap-up

    • {grf} is a fantastic package. You can expect more on it here.
    • Causal forests are an interesting way to directly model treatment effects.
    • Standard explainability methods can be used to explain the black-box.

    References

    1. Athey, Susan, Julie Tibshirani, and Stefan Wager. “Generalized Random Forests”. Annals of Statistics, 47(2), 2019.
    2. Elmunzer BJ et al. A randomized trial of rectal indomethacin to prevent post-ERCP pancreatitis. N Engl J Med. 2012 Apr 12;366(15):1414-22. doi: 10.1056/NEJMoa1111103.
    3. Friedman, Jerome H., and Bogdan E. Popescu. Predictive Learning via Rule Ensembles. The Annals of Applied Statistics 2, no. 3 (2008): 916-54.
    4. Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30 (2017).

    The full R notebook

  • Out-of-sample Imputation with {missRanger}

    {missRanger} is a multivariate imputation algorithm based on random forests, and a fast version of the original missForest algorithm of Stekhoven and Buehlmann (2012). Surprise, surprise: it uses {ranger} to fit random forests. Especially combined with predictive mean matching (PMM), the imputations are often quite realistic.

    Out-of-sample application

The newest CRAN release 2.6.0 offers out-of-sample application. This is useful for avoiding any leakage between train/test data or during cross-validation. Furthermore, it allows filling missing values in user-provided data. By default, it uses the same number of PMM donors as during training, but you can change this via the pmm.k argument.

    We distinguish two types of observations to be imputed:

    1. Easy case: Only a single value is missing. Here, we simply apply the corresponding random forest to fill the one missing value.
2. Hard case: Multiple values are missing. Here, we first fill the values univariately, and then repeatedly apply the corresponding random forests, with the hope that the effect of the univariate imputation vanishes. If values of two highly correlated features are missing, then the imputations can be nonsensical. There is no way to mend this. A conceptual sketch of both cases follows below.
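Purely as a conceptual sketch of the two cases (this is generic Python pseudo-code, not missRanger’s implementation; impute_row, forests, and start_value are made-up names, with forests mapping each column to a fitted model and start_value giving a univariate fallback such as the training median/mode):

import pandas as pd

def impute_row(row: pd.Series, forests: dict, start_value: dict, n_iter: int = 4) -> pd.Series:
    """Fill the missing entries of one row, given one fitted model per column."""
    row = row.copy()
    missing = [c for c in row.index if pd.isna(row[c])]
    for c in missing:
        row[c] = start_value[c]                      # univariate start values
    # Easy case: one pass is enough; hard case: iterate so the start values wash out
    for _ in range(1 if len(missing) <= 1 else n_iter):
        for c in missing:
            X_other = row.drop(labels=[c]).to_frame().T
            row[c] = forests[c].predict(X_other)[0]
    return row

In {missRanger} itself, all of this happens inside predict(), as shown in the example below.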

    Example

    To illustrate the technique with a simple example, we use the iris data.

    1. First, we randomly add 10% missing values.
    2. Then, we make a train/test split.
    3. Next, we “fit” missRanger() to the training data.
    4. Finally, we use its new predict() method to fill the test data.

    library(missRanger)
    
    # 10% missings
    ir <- iris |> 
      generateNA(p = 0.1, seed = 11)
    
    # Train/test split stratified by Species
    oos <- c(1:10, 51:60, 101:110)
    train <- ir[-oos, ]
    test <- ir[oos, ]
    
    head(test)
    
    #   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
    # 1          5.1         3.5          1.4         0.2  setosa
    # 2          4.9         3.0          1.4         0.2  setosa
    # 3          4.7         3.2          1.3          NA  setosa
    # 4          4.6         3.1          1.5         0.2  setosa
    # 5          5.0         3.6          1.4         0.2  setosa
    # 6          5.4          NA          1.7          NA  setosa
    
    mr <- missRanger(train, pmm.k = 5, keep_forests = TRUE, seed = 1)
    test_filled <- predict(mr, test, seed = 1)
    head(test_filled)
    
    #   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
    # 1          5.1         3.5          1.4         0.2  setosa
    # 2          4.9         3.0          1.4         0.2  setosa
    # 3          4.7         3.2          1.3         0.2  setosa
    # 4          4.6         3.1          1.5         0.2  setosa
    # 5          5.0         3.6          1.4         0.2  setosa
    # 6          5.4         4.0          1.7         0.4  setosa
    
    # Original
    head(iris)
    
    #   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
    # 1          5.1         3.5          1.4         0.2  setosa
    # 2          4.9         3.0          1.4         0.2  setosa
    # 3          4.7         3.2          1.3         0.2  setosa
    # 4          4.6         3.1          1.5         0.2  setosa
    # 5          5.0         3.6          1.4         0.2  setosa
    # 6          5.4         3.9          1.7         0.4  setosa
    

    The results look reasonable, in this case even for the “hard case” row 6 with missing values in two variables. Here, it is probably the strong association with Species that helped to create good values.

    The new predict() also works with single row input.

    Learn more about {missRanger}

    The full R script

  • SHAP Values of Additive Models

Within only a few years, SHAP (Shapley additive explanations) has emerged as the number one way to investigate black-box models. The basic idea is to decompose model predictions into additive contributions of the features in a fair way. Studying decompositions of many predictions allows us to derive global properties of the model.

    What happens if we apply SHAP algorithms to additive models? Why would this ever make sense?

    In the spirit of our “Lost In Translation” series, we provide both high-quality Python and R code.

    The models

    Let’s build the models using a dataset with three highly correlated covariates and a (deterministic) response.

    library(lightgbm)
    library(kernelshap)
    library(shapviz)
    
    #===================================================================
    # Make small data
    #===================================================================
    
    make_data <- function(n = 100) {
      x1 <- seq(0.01, 1, length = n)
      data.frame(
        x1 = x1,
        x2 = log(x1),
        x3 = x1 > 0.7
      ) |>
        transform(y = 1 + 0.2 * x1 + 0.5 * x2 + x3 + sin(2 * pi * x1))
    }
    df <- make_data()
    head(df)
    cor(df) |>
      round(2)
    
    #      x1   x2   x3    y
    # x1 1.00 0.90 0.80 0.46
    # x2 0.90 1.00 0.58 0.58
    # x3 0.80 0.58 1.00 0.51
    # y  0.46 0.58 0.51 1.00
    
    #===================================================================
    # Additive linear model and additive boosted trees
    #===================================================================
    
    # Linear regression
    fit_lm <- lm(y ~ poly(x1, 3) + poly(x2, 3) + x3, data = df)
    summary(fit_lm)
    
    # Boosted trees
    xvars <- setdiff(colnames(df), "y")
    X <- data.matrix(df[xvars])
    
    params <- list(
      learning_rate = 0.05,
      objective = "mse",
      max_depth = 1,
      colsample_bynode = 0.7
    )
    
    fit_lgb <- lgb.train(
      params = params,
      data = lgb.Dataset(X, label = df$y),
      nrounds = 300
    )
    import numpy as np
    import lightgbm as lgb
    import shap
    from sklearn.preprocessing import PolynomialFeatures
    from sklearn.compose import ColumnTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.linear_model import LinearRegression
    
    #===================================================================
    # Make small data
    #===================================================================
    
    def make_data(n=100):
        x1 = np.linspace(0.01, 1, n)
        x2 = np.log(x1)
        x3 = x1 > 0.7
        X = np.column_stack((x1, x2, x3))
    
        y = 1 + 0.2 * x1 + 0.5 * x2 + x3 + np.sin(2 * np.pi * x1)
        
        return X, y
    
    X, y = make_data()
    
    #===================================================================
    # Additive linear model and additive boosted trees
    #===================================================================
    
    # Linear model with polynomial terms
    poly = PolynomialFeatures(degree=3, include_bias=False)
    
    preprocessor =  ColumnTransformer(
        transformers=[
            ("poly0", poly, [0]),
            ("poly1", poly, [1]),
            ("other", "passthrough", [2]),
        ]
    )
    
    model_lm = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("lm", LinearRegression()),
        ]
    )
    _ = model_lm.fit(X, y)
    
    # Boosted trees with single-split trees
    params = dict(
        learning_rate=0.05,
        objective="mse",
        max_depth=1,
        colsample_bynode=0.7,
    )
    
    model_lgb = lgb.train(
        params=params,
        train_set=lgb.Dataset(X, label=y),
        num_boost_round=300,
    )

    SHAP

    For both models, we use exact permutation SHAP and exact Kernel SHAP. Furthermore, the linear model is analyzed with “additive SHAP”, and the tree-based model with TreeSHAP.

Do the algorithms provide the same results?

    system.time({  # 1s
      shap_lm <- list(
        add = shapviz(additive_shap(fit_lm, df)),
        kern = kernelshap(fit_lm, X = df[xvars], bg_X = df),
        perm = permshap(fit_lm, X = df[xvars], bg_X = df)
      )
    
      shap_lgb <- list(
        tree = shapviz(fit_lgb, X),
        kern = kernelshap(fit_lgb, X = X, bg_X = X),
        perm = permshap(fit_lgb, X = X, bg_X = X)
      )
    })
    
    # Consistent SHAP values for linear regression
    all.equal(shap_lm$add$S, shap_lm$perm$S)
    all.equal(shap_lm$kern$S, shap_lm$perm$S)
    
    # Consistent SHAP values for boosted trees
all.equal(shap_lgb$tree$S, shap_lgb$perm$S)
all.equal(shap_lgb$kern$S, shap_lgb$perm$S)
    
    # Linear coefficient of x3 equals slope of SHAP values
    tail(coef(fit_lm), 1)                # 1.112096
    diff(range(shap_lm$kern$S[, "x3"]))  # 1.112096
    
sv_dependence(shap_lm$add, xvars)
sv_dependence(shap_lm$add, xvars, color_var = NULL)
    shap_lm = {
        "add": shap.Explainer(model_lm.predict, masker=X, algorithm="additive")(X),
        "perm": shap.Explainer(model_lm.predict, masker=X, algorithm="exact")(X),
        "kern": shap.KernelExplainer(model_lm.predict, data=X).shap_values(X),
    }
    
    shap_lgb = {
        "tree": shap.Explainer(model_lgb)(X),
        "perm": shap.Explainer(model_lgb.predict, masker=X, algorithm="exact")(X),
        "kern": shap.KernelExplainer(model_lgb.predict, data=X).shap_values(X),
    }
    
    # Consistency for additive linear regression
    eps = 1e-12
    assert np.abs(shap_lm["add"].values - shap_lm["perm"].values).max() < eps
    assert np.abs(shap_lm["perm"].values - shap_lm["kern"]).max() < eps
    
    # Consistency for additive boosted trees
    assert np.abs(shap_lgb["tree"].values - shap_lgb["perm"].values).max() < eps
    assert np.abs(shap_lgb["perm"].values - shap_lgb["kern"]).max() < eps
    
    # Linear effect of last feature in the fitted model
    model_lm.named_steps["lm"].coef_[-1]  # 1.112096
    
    # Linear effect of last feature derived from SHAP values (ignore the sign)
    shap_lm["perm"][:, 2].values.ptp()    # 1.112096
    
    shap.plots.scatter(shap_lm["add"])
    SHAP dependence plot of the additive linear model and the additive explainer (Python).

Yes – within each model, the three algorithms provide the same SHAP values. Furthermore, the SHAP values reconstruct the additive components of the features.
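In formulas: if the model is additive, say \hat{f}(x) = \beta_0 + \sum_j f_j(x_j) without interactions, then the SHAP value of feature j is (up to estimation error) simply its centered component,

\begin{equation*}
\phi_j(x) = f_j(x_j) - \frac{1}{n_{\mathrm{bg}}} \sum_{i=1}^{n_{\mathrm{bg}}} f_j(x_{ij}),
\end{equation*}

with the average taken over the background data. This is also why the range of the x3 SHAP values equals the corresponding linear coefficient in the code above.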

    Didactically, this is very helpful when introducing SHAP as a method: Pick a white-box and a black-box model and compare their SHAP dependence plots. For the white-box model, you simply see the additive components, while the dependence plots of the black-box model show scatter due to interactions.

Remark: The exact equivalence between algorithms is lost when

    • there are too many features for exact procedures (~10+ features), and/or when
    • the background data of Kernel/Permutation SHAP does not agree with the training data. This leads to slightly different estimates of the baseline value, which itself influences the calculation of SHAP values.

    Final words

    • SHAP algorithms applied to additive models typically give identical results. Slight differences might occur because sampling versions of the algos are used, or a different baseline value is estimated.
    • The resulting SHAP values describe the additive components.
    • Didactically, it helps to see SHAP analyses of white-box and black-box models side by side.

R script, Python notebook

• A Tweedie Trilogy — Part III: From Wright’s Generalized Bessel Function to Tweedie’s Compound Poisson Distribution

TLDR: The scipy 1.7.0 release introduced Wright’s generalized Bessel function to the Python ecosystem. It is an important ingredient for the density and log-likelihood of Tweedie probability distributions. In this last part of the trilogy, I’d like to point out why it was important to have this function and share the endeavor of implementing this inconspicuous but highly intractable special function. The fun part is exploiting a free parameter in an integral representation, which can be optimized by curve fitting to the minimal arc length.

    This trilogy celebrates the 40th birthday of Tweedie distributions in 2024 and highlights some of their very special properties.

See part I and part II.

    Tweedie Distributions

    As pointed out in part I and part II, the family of Tweedie distributions is a very special one with outstanding properties. They are central for estimating expectations with GLMs. The probability distributions have mainly positive (non-negative) support and are skewed, e.g. Poisson, Gamma, Inverse Gaussian and compound Poisson-Gamma.

    As members of the exponential dispersion family, a slight extension of the exponential family, the probability density can be written as

    \begin{align*}
    f(y; \theta, \phi) &= c(y, \phi) \exp\left(\frac{y\theta - \kappa(\theta)}{\phi}\right)
    \\
    \kappa(\theta) &= \kappa_p(\theta) = \frac{1}{2-p}((1-p)\theta)^{\frac{2-p}{1-p}}
    \end{align*}

    It is often more instructive to parametrise the distribution with p, \mu and \phi, using

    \begin{align*}
    \theta &= \begin{cases}
    \frac{\mu^{1-p}}{1-p}\,,\quad p\neq 1\\
    \log(\mu)\,,\quad p=1
    \end{cases}
    \\
    \kappa(\theta) &= \begin{cases}
    \frac{\mu^{2-p}}{2-p}\,,\quad p\neq 2\\
    \log(\mu)\,,\quad p=2
    \end{cases}
    \end{align*}

    and write

    \begin{align*}
    Y &\sim \mathrm{Tw}_p(\mu, \phi)
    \end{align*}
    Probability density of several Tweedie distributions.

    Compound Poisson Gamma

A very special domain for the power parameter is between Poisson and Gamma: 1<p<2. This range results in the compound Poisson-Gamma distribution, which is suitable if you have a random count process where each count itself carries a random amount. A well-known example is insurance claims: typically, there is a random number of claims, and each claim has a random claim amount.

    \begin{align*}
    N &\sim \mathrm{Poisson}(\lambda)\\
    X_i &\sim \mathrm{Gamma}(a, b)\\
Y &= \sum_{i=1}^N X_i \sim \mathrm{CompPois}(\lambda, a, b)
    \end{align*}

    For Poisson count we have \operatorname{E}[N]=\lambda and \operatorname{Var}[N]=\lambda=\operatorname{E}[N], for the Gamma amount \operatorname{E}[X]=\frac{a}{b} and \operatorname{Var}[X]=\frac{a}{b^2}=\frac{1}{a}\operatorname{E}[X]^2. For the compound Poisson-Gamma variable, we obtain

    \begin{align*}
    \operatorname{E}[Y] &= \operatorname{E}[N] \operatorname{E}[X] = \lambda\frac{a}{b}=\mu\\
    \operatorname{Var}[Y] &=  \operatorname{Var}[N] \operatorname{E}[X]^2 +  \operatorname{E}[N] \operatorname{Var}[X] =  \phi \mu^p\\
    p &= \frac{a + 2}{a+1} \in (1, 2)\\
    \phi &= \frac{(\lambda a)^{1-p}}{(p-1)b^{2-p}}
    \end{align*}

    What’s so special here is that there is a point mass at zero, i.e., P(Y=0)=\exp(-\frac{\mu^{2-p}}{\phi(2-p)}) > 0. Hence, it is a suitable distribution for non-negative quantities with some exact zeros.
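A quick numerical sanity check of this point mass (an illustrative simulation sketch; the parameter values below are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
lam, a, b = 1.0, 2.0, 1.0              # Poisson rate, Gamma shape and rate
p = (a + 2) / (a + 1)                  # power parameter in (1, 2)
mu = lam * a / b
phi = (lam * a) ** (1 - p) / ((p - 1) * b ** (2 - p))

# Simulate Y: the sum of n iid Gamma(a, b) variables is Gamma(a * n, b)
N = rng.poisson(lam, size=100_000)
Y = rng.gamma(shape=a * np.maximum(N, 1), scale=1 / b)
Y[N == 0] = 0.0                        # no counts -> zero total

print(np.mean(Y == 0))                               # simulated P(Y = 0)
print(np.exp(-mu ** (2 - p) / (phi * (2 - p))))      # point mass formula from the text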

    Probability density for compound Poisson Gamma, point masses at zero are marked as points.

    Code

    The rest of this post is about how to compute the density for this parameter range. The easy part is \exp\left(\frac{y\theta - \kappa(\theta)}{\phi}\right) which can be directly implemented. The real obstacle is the term c(y, \phi) which is given by

    \begin{align*}
    c(y, \phi) &= \frac{\Phi(-\alpha, 0, t)}{y}
    \\
    \alpha &= \frac{2 - p}{1 - p}
    \\
    t &= \frac{\left(\frac{(p - 1)\phi}{y}\right)^{\alpha}}{(2-p)\phi}
    \end{align*}

    This depends on Wright’s (generalized Bessel) function \Phi(a, b, z) as introduced in a 1933 paper by E. Wright.

    Wright’s Generalized Bessel Function

    According to DLMF 10.46, the function is defined as

    \begin{equation*}
    \Phi(a, b, z) = \sum_{k=0}^{\infty} \frac{z^k}{k!\Gamma(ak+b)}, \quad a > -1, b \in R, z \in C
    \end{equation*}

    which converges everywhere because it is an entire function. We will focus on the positive real axis z=x\geq 0 and the range a\geq 0, b\geq 0 (note that a=-\alpha \in (0,\infty) for 1<p<2). For the compound Poisson-Gamma, we even have b=0.

Implementation of such a function, as done in scipy.special.wright_bessel, poses tremendous challenges even for the restricted parameter range. The first one is that it has three parameters, which is quite a lot. Then the series representation above, for instance, can always be used, but depending on the parameters, it will require a huge number of terms, particularly for large x. As each term involves the Gamma function, this becomes expensive very fast. One ends up using different representations and strategies for different parameter regions:

    • Small x: Taylor series according to definition
    • Small a: Taylor series in a=0
    • Large x: Asymptotic series due to Wright (1935)
    • Large a: Taylor series according to definition for a few terms around the approximate maximum term k_{max} due to Dunn & Smyth (2005)
• General: Integral representation due to Luchko (2008)

Dunn & Smyth investigated several evaluation strategies for the simpler Tweedie density, which amounts to Wright’s function with b=0, see Dunn & Smyth (2005). Luchko (2008) lists most of the above strategies for the full Wright function.

    Note that Dunn & Smyth (2008) provide another strategy to evaluate the Tweedie distribution function by means of the inverse Fourier transform. This does not involve Wright’s function, but also encounters complicated numerical integration of oscillatory functions.
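Since SciPy ≥ 1.7, the function is available as scipy.special.wright_bessel(a, b, x). As a quick illustration (arbitrary parameter point), we can compare it against the truncated series definition above:

import numpy as np
from scipy.special import gammaln, wright_bessel

def wright_series(a, b, x, n_terms=100):
    """Naive truncated series; fine for small x, hopeless for large x."""
    k = np.arange(n_terms)
    return np.sum(np.exp(k * np.log(x) - gammaln(k + 1) - gammaln(a * k + b)))

print(wright_bessel(1.5, 2.0, 3.0))   # SciPy implementation
print(wright_series(1.5, 2.0, 3.0))   # should agree to many digits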

    The Integral Representation

    This brings us deep into complex analysis: We start with Hankel’s contour integral representation of the reciprocal Gamma function.

    \begin{equation*}
    \frac{1}{\Gamma(z)} = \frac{1}{2\pi i} \int_{Ha^-} \zeta^{-z} e^\zeta \; d\zeta
    \end{equation*}

    with the Hankel path Ha^- from negative infinity (A) just below the real axis, counter-clockwise with radius \epsilon>0 around the origin and just above the real axis back to minus infinity (D).

    Hankel contour Ha in the complex plane.

In principle, one is free to choose any such path with the same start (A) and end point (D), as long as one does not cross the negative real axis. One usually lets AB and CD lie infinitesimally close to the negative real axis. Very importantly, the radius \epsilon>0 is a free parameter! That is real magic🪄

    By interchanging sum and integral and using the series of the exponential, Wright’s function becomes

    \begin{align*}
    \Phi(a, b, z) &= \sum_{k=0}^{\infty} \frac{z^k}{k!} \frac{1}{2\pi i} \int_{Ha^-} \zeta^{-(ak+b)} e^\zeta \; d\zeta
    \\
    &= \frac{1}{2\pi i} \int_{Ha^-} \zeta^{-b} e^{\zeta + z\zeta^{-a}} \; d\zeta
    \end{align*}

    Now, one needs to do the tedious work and split the integral into the 3 path sections AB, BC, CD. Putting AB and CD together gives an integral over K, the circle BC gives an integral over P:

    \begin{align*}
    \Phi(a, b, x) &= \frac{1}{\pi} \int_{\epsilon}^\infty K(a, b, x, r) \; dr
    \\
     &+ \frac{\epsilon^{1-b}}{\pi} \int_0^\pi P(\epsilon, a, b, x, \varphi) \; d\varphi
    \\
    K(a, b, x, r) &= r^{-b}\exp(-r + x  r^{-a} \cos(\pi a)) 
    \\
    &\quad \sin(x \cdot r^{-a} \sin(\pi a) + \pi b)
    \\
    P(\epsilon, a, b, x, \varphi) &= \exp(\epsilon \cos(\varphi) + x  \epsilon^{-a}\cos(a \varphi))
    \\
    &\quad \cos(\epsilon \sin(\varphi) - x \cdot \epsilon^{-a} \sin(a \varphi) + (1-b) \varphi)
    \end{align*}

What remains is to carry out the numerical integration, also known as quadrature. While this is an interesting topic in its own right, let’s move on to the magic part.

    Arc Length Minimization

    If you have come so far and say, wow, puh, uff, crazy, 🤯😱 Just keep on a little bit because here comes the real fun part🤞

It turns out that most of the time, the integral over P is the most difficult. The worst behaviour an integrand can have is being wildly oscillatory. Here is one of my favorite examples:

    Integrands for a=5, b=1, x=100 and two choices of epsilon.

    With the naive choice of \epsilon=1, both integrands (blue) are—well—crazy. There is basically no chance the most sophisticated quadrature rule will work. And then look at the other choice of \epsilon\approx 4. Both curves seem well behaved (for P, we would need a closer look).

    So the idea is to find a good choice of \epsilon to make P well behaved. Well behaved here means most boring, if possible a straight line. What makes a straight line unique? In flat space, it is the shortest path between two points. Therefore, well behaved integrands have minimal arc length. That is what we want to minimize.

    The arc length S from x=a to x=b of a 1-dimensional function f is given by

    \begin{equation*}
    S = \int_a^b \sqrt{1 + f^\prime(x)^2} \; dx
    \end{equation*}

Instead of f=P, we only take the oscillatory part of P and approximate the arc length with f(\varphi) = \epsilon \sin(\varphi) - x \epsilon^{-a} \sin(a \varphi) + (1-b) \varphi. For a single parameter point a, b, x, this looks like

    Arc length and integrand P for different epsilon, given a=0.1, b=5, x=100.

    Note the logarithmic y-scale for the right plot of P. The optimal \epsilon=10 is plotted in red and behaves clearly better than smaller values of \epsilon.
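As a toy illustration of this idea (not the actual SciPy implementation; numbers and bounds are arbitrary), one can minimize the arc length of the oscillatory part numerically over \epsilon:

import numpy as np
from scipy.integrate import quad
from scipy.optimize import minimize_scalar

def arc_length(eps, a, b, x):
    """Arc length of f(phi) = eps*sin(phi) - x*eps^(-a)*sin(a*phi) + (1-b)*phi on [0, pi]."""
    def f_prime(phi):
        return eps * np.cos(phi) - x * eps ** (-a) * a * np.cos(a * phi) + (1 - b)
    value, _ = quad(lambda phi: np.sqrt(1 + f_prime(phi) ** 2), 0, np.pi)
    return value

a, b, x = 0.1, 5.0, 100.0   # the parameter point from the figure above
res = minimize_scalar(arc_length, bounds=(0.5, 50), args=(a, b, x), method="bounded")
print(res.x)                # epsilon with (numerically) minimal arc length

In the actual implementation, such an optimization is not run at every call; instead, the minimal \epsilon values are precomputed on a grid and approximated by a fitted curve, as described next.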

    What remains to be done for an actual implementation is

• Calculate the arc-length-minimizing \epsilon for a large grid of values a, b, x.
• Choose a parametric function of a, b, x for \epsilon.
• Curve fitting (so again optimisation): Fit this function to the minimal \epsilon values of the grid by minimising least squares.
• Implement some quadrature rules and use this fitted \epsilon in the hope that it interpolates and extrapolates well.

    This strategy turns out to work well in practice and is implemented in scipy. As the parameter space of 3 variables is huge, the integral representation breaks down in certain areas, e.g. huge values of \epsilon where the integrands just overflow numerically (in 64-bit floating point precision). But we have our other evaluation strategies for that.

    Conclusion

    An extensive notebook for Wright’s function, with all implementation strategies can be found here.

    After an adventurous journey, we arrived at one implementation strategy of Wright’s generalised Bessel function, namely the integral representation. The path went deep into complex analysis and contour integration, then further to the arc length of a function and finally curve fitting via optimisation. I am really astonished how connected all those different areas of mathematics can be.

    Wright’s function is the missing piece to compute full likelihoods and probability functions of the Tweedie distribution family and is now available in the Python ecosystem via scipy.

    We are at the very end of this Tweedie trilogy. I hope it has been entertaining and it has become clear why Tweedie deserves to be celebrated.


  • Building Strong GLMs in Python via ML + XAI

    In our latest post, we explained how to use ML + XAI to build strong generalized linear models with R. Let’s do the same with Python.

    Insurance pricing data

We will again use a synthetic dataset with 1 million insurance policies, with reference:

    Mayer, M., Meier, D. and Wuthrich, M.V. (2023),
    SHAP for Actuaries: Explain any Model.
    https://doi.org/10.2139/ssrn.4389797

    Let’s start by loading and describing the data:

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    import shap
    from sklearn.datasets import fetch_openml
    from sklearn.inspection import PartialDependenceDisplay
    from sklearn.metrics import mean_poisson_deviance
    from sklearn.dummy import DummyRegressor
    from lightgbm import LGBMRegressor
    # We need preview version of glum that adds formulaic API
    # !pip install git+https://github.com/Quantco/glum@glum-v3#egg=glum
    from glum import GeneralizedLinearRegressor
    
    # Load data
    
    df = fetch_openml(data_id=45106, parser="pandas").frame
    df.head()
    
    # Continuous features
    df.hist(["driver_age", "car_weight", "car_power", "car_age"])
    _ = plt.suptitle("Histograms of continuous features", fontsize=15)
    
    # Response and discrete features
    fig, axes = plt.subplots(figsize=(8, 3), ncols=3)
    
    for v, ax in zip(["claim_nb", "year", "town"], axes):
        df[v].value_counts(sort=False).sort_index().plot(kind="bar", ax=ax, rot=0, title=v)
    plt.suptitle("Barplots of response and discrete features", fontsize=15)
    plt.tight_layout()
    plt.show()
    
    # Rank correlations
    corr = df.corr("spearman")
    mask = np.triu(np.ones_like(corr, dtype=bool))
    plt.suptitle("Rank-correlogram", fontsize=15)
    _ = sns.heatmap(
        corr, mask=mask, vmin=-0.7, vmax=0.7, center=0, cmap="vlag", square=True
    )
    
    

    Modeling

    1. We fit a tuned Boosted Trees model to model log(E(claim count)) via Poisson deviance loss.
    2. And perform a SHAP analysis to derive insights.
    from sklearn.model_selection import train_test_split
    
    X_train, X_test, y_train, y_test = train_test_split(
        df.drop("claim_nb", axis=1), df["claim_nb"], test_size=0.1, random_state=30
    )
    
    # Tuning step not shown. Number of boosting rounds found via early stopping on CV performance
    params = dict(
        learning_rate=0.05,
        objective="poisson",
        num_leaves=7,
        min_child_samples=50,
        min_child_weight=0.001,
        colsample_bynode=0.8,
        subsample=0.8,
        reg_alpha=3,
        reg_lambda=5,
        verbose=-1,
    )
    
    model_lgb = LGBMRegressor(n_estimators=360, **params)
    model_lgb.fit(X_train, y_train)
    
    # SHAP analysis
    X_explain = X_train.sample(n=2000, random_state=937)
    explainer = shap.Explainer(model_lgb)
    shap_val = explainer(X_explain)
    
    plt.suptitle("SHAP importance", fontsize=15)
    shap.plots.bar(shap_val)
    
    for s in [shap_val[:, 0:3], shap_val[:, 3:]]:
        shap.plots.scatter(s, color=shap_val, ymin=-0.5, ymax=1)

Here, we would come to the following conclusions:

1. car_weight and year might be dropped, depending on the specific aim of the model.
    2. Add a regression spline for driver_age.
    3. Add an interaction between car_power and town.

    Build strong GLM

    Let’s build a GLM with these insights. Two important things:

    1. Glum is an extremely powerful GLM implementation that was inspired by a pull request of our Christian Lorentzen.
2. In the upcoming version 3.0, it adds a formula API based on formulaic, a very performant formula parser. This gives a very easy way to add interaction effects, regression splines, dummy encodings, etc.
    model_glm = GeneralizedLinearRegressor(
        family="poisson",
        l1_ratio=1.0,
        alpha=1e-10,
        formula="car_power * C(town) + bs(driver_age, 7) + car_age",
    )
    model_glm.fit(X_train, y=y_train)  # 1 second on old laptop
    
    # PDPs of both models
    fig, ax = plt.subplots(2, 2, figsize=(7, 5))
    cols = ("tab:blue", "tab:orange")
    for color, name, model in zip(cols, ("GLM", "LGB"), (model_glm, model_lgb)):
        disp = PartialDependenceDisplay.from_estimator(
            model,
            features=["driver_age", "car_age", "car_power", "town"],
            X=X_explain,
            ax=ax if name == "GLM" else disp.axes_,
            line_kw={"label": name, "color": color},
        )
    fig.suptitle("PDPs of both models", fontsize=15)
    fig.tight_layout()
    
    # Stratified PDP of car_power
    for color, town in zip(("tab:blue", "tab:orange"), (0, 1)):
        mask = X_explain.town == town
        disp = PartialDependenceDisplay.from_estimator(
            model_glm,
            features=["car_power"],
            X=X_explain[mask],
            ax=None if town == 0 else disp.axes_,
            line_kw={"label": town, "color": color},
        )
    plt.suptitle("PDP of car_power stratified by town (0 vs 1)", fontsize=15)
    _ = plt.ylim(0, 0.2)

In this relatively simple situation, the mean Poisson deviances of our models are now very similar:

    model_dummy = DummyRegressor().fit(X_train, y=y_train)
    deviance_null = mean_poisson_deviance(y_test, model_dummy.predict(X_test)) 
    
    dev_imp = []
    for name, model in zip(("GLM", "LGB", "Null"), (model_glm, model_lgb, model_dummy)):
        dev_imp.append((name, mean_poisson_deviance(y_test, model.predict(X_test))))
    pd.DataFrame(dev_imp, columns=["Model", "Mean_Poisson_Deviance"])
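For reference, the mean Poisson deviance compared here is

\begin{equation*}
D = \frac{2}{n} \sum_{i=1}^{n} \left( y_i \log\frac{y_i}{\hat{y}_i} - y_i + \hat{y}_i \right),
\end{equation*}

with the convention y_i \log(y_i / \hat{y}_i) = 0 for y_i = 0. Smaller is better, and the null model serves as a simple benchmark.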

    Final words

    • Glum is an extremely powerful GLM implementation – we have only scratched its surface. You can expect more blogposts on Glum…
    • Having a formula interface is especially useful for adding interactions. Fingers crossed that the upcoming version 3.0 will soon be released.
    • Building GLMs via ML + XAI is so smooth, especially when you work with large data. For small data, you need to be careful to not add hidden overfitting to the model.

    Click here for the full Python notebook

  • ML + XAI -> Strong GLM

    My last post was using {hstats}, {kernelshap} and {shapviz} to explain a binary classification random forest. Here, we use the same package combo to improve a Poisson GLM with insights from a boosted trees model.

    Insurance pricing data

This time, we work with a synthetic but quite realistic dataset. It describes 1 million insurance policies and their corresponding claim counts. A reference for the data is:

    Mayer, M., Meier, D. and Wuthrich, M.V. (2023),
    SHAP for Actuaries: Explain any Model.
    http://dx.doi.org/10.2139/ssrn.4389797

    library(OpenML)
    library(lightgbm)
    library(splines)
    library(ggplot2)
    library(patchwork)
    library(hstats)
    library(kernelshap)
    library(shapviz)
    
    #===================================================================
    # Load and describe data
    #===================================================================
    
    df <- getOMLDataSet(data.id = 45106)$data
    
    dim(df)  # 1000000       7
    head(df)
    
    # year town driver_age car_weight car_power car_age claim_nb
    # 2018    1         51       1760       173       3        0
    # 2019    1         41       1760       248       2        0
    # 2018    1         25       1240       111       2        0
    # 2019    0         40       1010        83       9        0
    # 2018    0         43       2180       169       5        0
    # 2018    1         45       1170       149       1        1
    
    summary(df)
    
    # Response
    ggplot(df, aes(claim_nb)) +
      geom_bar(fill = "chartreuse4") +
      ggtitle("Distribution of the response")
    
    # Features
    xvars <- c("year", "town", "driver_age", "car_weight", "car_power", "car_age")
    
    df[xvars] |> 
      stack() |> 
    ggplot(aes(values)) +
      geom_histogram(fill = "chartreuse4", bins = 19) +
      facet_wrap(~ind, scales = "free", ncol = 2) +
      ggtitle("Distribution of the features")
    
    # car_power and car_weight are correlated 0.68, car_age and driver_age 0.28
    df[xvars] |> 
      cor() |> 
      round(2)
    #            year  town driver_age car_weight car_power car_age
    # year          1  0.00       0.00       0.00      0.00    0.00
    # town          0  1.00      -0.16       0.00      0.00    0.00
    # driver_age    0 -0.16       1.00       0.09      0.10    0.28
    # car_weight    0  0.00       0.09       1.00      0.68    0.00
    # car_power     0  0.00       0.10       0.68      1.00    0.09
    # car_age       0  0.00       0.28       0.00      0.09    1.00
    
    

    Modeling

    1. We fit a naive additive linear GLM and a tuned Boosted Trees model.
    2. We combine the models and specify their predict function.
    # Train/test split
    set.seed(8300)
    ix <- sample(nrow(df), 0.9 * nrow(df))
    train <- df[ix, ]
    valid <- df[-ix, ]
    
    # Naive additive linear Poisson regression model
    (fit_glm <- glm(claim_nb ~ ., data = train, family = poisson()))
    
# Boosted trees with LightGBM. The parameters (incl. number of rounds) have been chosen
# by combining early stopping with random search CV (not shown here)
    
    dtrain <- lgb.Dataset(data.matrix(train[xvars]), label = train$claim_nb)
    
    params <- list(
      learning_rate = 0.05, 
      objective = "poisson", 
      num_leaves = 7, 
      min_data_in_leaf = 50, 
      min_sum_hessian_in_leaf = 0.001, 
      colsample_bynode = 0.8, 
      bagging_fraction = 0.8, 
      lambda_l1 = 3, 
      lambda_l2 = 5
    )
    
    fit_lgb <- lgb.train(params = params, data = dtrain, nrounds = 300)  
    
    # {hstats} works for multi-output predictions,
    # so we can combine all models to a list, which simplifies the XAI part.
    models <- list(GLM = fit_glm, LGB = fit_lgb)
    
    # Custom predictions on response scale
    pf <- function(m, X) {
      cbind(
        GLM = predict(m$GLM, X, type = "response"),
        LGB = predict(m$LGB, data.matrix(X[xvars]))
      )
    }
    pf(models, head(valid, 2))
    #       GLM        LGB
    # 0.1082285 0.08580529
    # 0.1071895 0.09181466
    
    # And on log scale
    pf_log <- function(m, X) {
      log(pf(m = m, X = X))
    }
    pf_log(models, head(valid, 2))
    #       GLM       LGB
    # -2.223510 -2.455675
# -2.233157 -2.387983

    Traditional XAI

    Performance

Comparing the average Poisson deviance on the validation data shows that the LGB model is clearly better than the naively built GLM, so there is room for improvement!

    perf <- average_loss(
      models, X = valid, y = "claim_nb", loss = "poisson", pred_fun = pf
    )
    perf
    #       GLM       LGB 
    # 0.4362407 0.4331857
    

    Feature importance

Next, we calculate permutation importance on the validation data with respect to mean Poisson deviance loss. The results make sense, and we note that year and car_weight seem to be negligible.

    imp <- perm_importance(
      models, v = xvars, X = valid, y = "claim_nb", loss = "poisson", pred_fun = pf
    )
    plot(imp)

    Main effects

    Next, we visualize estimated main effects by partial dependence plots on log link scale. The differences between the models are quite small, with one big exception: Investing more parameters into driver_age via spline will greatly improve the performance and usefulness of the GLM.

    partial_dep(models, v = "driver_age", train, pred_fun = pf_log) |> 
      plot(show_points = FALSE)
    
    pdp <- function(v) {
      partial_dep(models, v = v, X = train, pred_fun = pf_log) |> 
        plot(show_points = FALSE)
    }
    wrap_plots(lapply(xvars, pdp), guides = "collect") &
      ylim(-2.8, -1.7)

    Interaction effects

Friedman’s H-squared, calculated per feature and per feature pair on the log link scale, shows that – unsurprisingly – our GLM does not contain interactions, and that the strongest relative interaction happens between town and car_power. The stratified PDP visualizes this interaction. Let’s add a corresponding interaction effect to our GLM later.

    system.time(  # 5 sec
      H <- hstats(models, v = xvars, X = train, pred_fun = pf_log)
    )
    H
    plot(H)
    
    # Visualize strongest interaction by stratified PDP
    partial_dep(models, v = "car_power", X = train, pred_fun = pf_log, BY = "town") |> 
      plot(show_points = FALSE)

    SHAP

    As an elegant alternative to studying feature importance, PDPs and Friedman’s H, we can simply run a SHAP analysis on the LGB model.

    set.seed(22)
    X_explain <- train[sample(nrow(train), 1000), xvars]
     
    shap_values_lgb <- shapviz(fit_lgb, data.matrix(X_explain))
    sv_importance(shap_values_lgb)
    sv_dependence(shap_values_lgb, v = xvars) &
      ylim(-0.35, 0.8)

    Here, we would come to the same conclusions:

    1. car_weight and year might be dropped.
2. Add a regression spline for driver_age.
    3. Add an interaction between car_power and town.

    Pimp the GLM

    In the final section, we apply the three insights from above with very good results.

fit_glm2 <- glm(
  claim_nb ~ car_power * town + ns(driver_age, df = 7) + car_age, 
  data = train, 
  family = poisson()
)
    # Performance now as good as LGB
    perf_glm2 <- average_loss(
      fit_glm2, X = valid, y = "claim_nb", loss = "poisson", type = "response"
    )
    perf_glm2  # 0.432962
    
    # Effects similar as LGB, and smooth
    partial_dep(fit_glm2, v = "driver_age", X = train) |> 
      plot(show_points = FALSE)
    
    partial_dep(fit_glm2, v = "car_power", X = train, BY = "town") |> 
      plot(show_points = FALSE)

    Or even via permutation or kernel SHAP:

    set.seed(1)
    bg <- train[sample(nrow(train), 200), ]
    xvars2 <- setdiff(xvars, c("year", "car_weight"))
    
    system.time(  # 4 sec
      ks_glm2 <- permshap(fit_glm2, X = X_explain[xvars2], bg_X = bg)
    )
    shap_values_glm2 <- shapviz(ks_glm2)
    sv_dependence(shap_values_glm2, v = xvars2) &
      ylim(-0.3, 0.8)

    Final words

    • Improving naive GLMs with insights from ML + XAI is fun.
• In practice, the gap between GLM and a boosted trees model can’t be closed that easily. (The true model behind our synthetic dataset contains a single interaction, unlike real data/models, which typically have many more interactions.)
    • {hstats} can work with multiple regression models in parallel. This helps to keep the workflow smooth. Similar for {kernelshap}.
    • A SHAP analysis often brings the same qualitative insights as multiple other XAI tools together.

    The full R script

  • Explain that tidymodels blackbox!

    Let’s explain a {tidymodels} random forest by classic explainability methods (permutation importance, partial dependence plots (PDP), Friedman’s H statistics), and also fancy SHAP.

    Disclaimer: {hstats}, {kernelshap} and {shapviz} are three of my own packages.

    Diabetes data

    We will use the diabetes prediction dataset of Kaggle to model diabetes (yes/no) as a function of six demographic features (age, gender, BMI, hypertension, heart disease, and smoking history). It has 100k rows.

Note: The data additionally contains the typical diabetes indicators HbA1c level and blood glucose level, but we won’t use them, both to avoid potential causality issues and to gain insights also for people who do not know these values.

    # https://www.kaggle.com/datasets/iammustafatz/diabetes-prediction-dataset
    
    library(tidyverse)
    library(tidymodels)
    library(hstats)
    library(kernelshap)
    library(shapviz)
    library(patchwork)
    
    df0 <- read.csv("diabetes_prediction_dataset.csv")  # from above Kaggle link
    dim(df0)  # 100000 9
    head(df0)
    # gender age hypertension heart_disease smoking_history   bmi HbA1c_level blood_glucose_level diabetes
    # Female  80            0             1           never 25.19         6.6                 140        0
    # Female  54            0             0         No Info 27.32         6.6                  80        0
    #   Male  28            0             0           never 27.32         5.7                 158        0
    # Female  36            0             0         current 23.45         5.0                 155        0
    #   Male  76            1             1         current 20.14         4.8                 155        0
    # Female  20            0             0           never 27.32         6.6                  85        0
    
    summary(df0)
    anyNA(df0)  # FALSE
    table(df0$smoking_history, useNA = "ifany")
    
    # DATA PREPARATION
    
    # Note: tidymodels needs a factor response for classification
    df1 <- df0 |>
      transform(
        y = factor(diabetes, levels = 0:1, labels = c("No", "Yes")),
        female = (gender == "Female") * 1,
        smoking_history = factor(
          smoking_history, 
          levels = c("No Info", "never", "former", "not current", "current", "ever")
        ),
        bmi = pmin(bmi, 50)
      )
    
    # UNIVARIATE ANALYSIS
    
    ggplot(df1, aes(diabetes)) +
      geom_bar(fill = "chartreuse4")
    
    df1  |>  
      select(age, bmi, HbA1c_level, blood_glucose_level) |> 
      pivot_longer(everything()) |> 
      ggplot(aes(value)) +
      geom_histogram(fill = "chartreuse4", bins = 19) +
      facet_wrap(~ name, scale = "free_x")
    
    ggplot(df1, aes(smoking_history)) +
      geom_bar(fill = "chartreuse4")
    
    df1 |> 
      select(heart_disease, hypertension, female) |>
      pivot_longer(everything()) |> 
      ggplot(aes(name, value)) +
      stat_summary(fun = mean, geom = "bar", fill = "chartreuse4") +
      xlab(element_blank())
    
    “yes” proportion of binary variables (including the response)
    Distribution of numeric variables
    Distribution of smoking_history

    Modeling

    Let’s fit a random forest via tidymodels with {ranger} backend.

    We add a predict function pf() that outputs only the probability of the “Yes” class.

    set.seed(1)
    ix <- initial_split(df1, strata = diabetes, prop = 0.8)
    train <- training(ix)
    test <- testing(ix)
    
    xvars <- c("age", "bmi", "smoking_history", "heart_disease", "hypertension", "female")
    
    rf_spec <- rand_forest(trees = 500) |> 
      set_mode("classification") |> 
      set_engine("ranger", num.threads = NULL, seed = 49)
    
    rf_wf <- workflow() |> 
      add_model(rf_spec) |>
      add_formula(reformulate(xvars, "y"))
    
    model <- rf_wf |> 
        fit(train)
    
    # predict() gives No/Yes columns
    predict(model, head(test), type = "prob")
    # .pred_No .pred_Yes
    #    0.981    0.0185
    
    # We need to extract only the "Yes" probabilities
    pf <- function(m, X) {
      predict(m, X, type = "prob")$.pred_Yes
    }
    pf(model, head(test))  # 0.01854290 ...
    

    Classic explanation methods

    # 4 times repeated permutation importance wrt test logloss
    imp <- perm_importance(
      model, X = test, y = "diabetes", v = xvars, pred_fun = pf, loss = "logloss"
    )
    plot(imp) +
      xlab("Increase in test logloss")
    
    # Partial dependence of age
    partial_dep(model, v = "age", train, pred_fun = pf) |> 
      plot()
    
    # All PDP in one patchwork
    p <- lapply(xvars, function(x) plot(partial_dep(model, v = x, X = train, pred_fun = pf)))
    wrap_plots(p) &
      ylim(0, 0.23) &
      ylab("Probability")
    
    # Friedman's H stats
    system.time( # 20 s
      H <- hstats(model, train[xvars], approx = TRUE, pred_fun = pf)
    )
    H  # 15% of prediction variability comes from interactions
    plot(H)
    
    # Stratified PDP of strongest interaction
    partial_dep(model, "age", BY = "bmi", X = train, pred_fun = pf) |> 
      plot(show_points = FALSE)

    Feature importance

    Permutation importance measures by how much the average test loss (in our case log loss) increases when a feature is shuffled before calculating the losses. We repeat the process four times and also show standard errors.
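    
    To make the idea more concrete, here is a minimal sketch of a single permutation round (a hypothetical helper, not the {hstats} implementation; loss and pred_fun stand for whatever loss and prediction functions are used):
    
    perm_importance_once <- function(model, X, y, v, loss, pred_fun = predict) {
      loss_orig <- loss(y, pred_fun(model, X))
      X_perm <- X
      X_perm[[v]] <- sample(X_perm[[v]])  # break the link between feature v and the response
      loss(y, pred_fun(model, X_perm)) - loss_orig  # increase in loss
    }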

    Permutation importance: Age and BMI are the two main risk factors.

    Main effects

    Main effects are estimated by PDPs. They show how the average prediction changes with a feature, keeping every other feature fixed. Using a fixed vertical axis helps to grasp the strength of the effect.
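    
    For illustration, here is a minimal sketch of how a single partial dependence value could be computed by hand (in the post itself we use hstats::partial_dep()):
    
    pd_at <- function(model, X, v, grid_value, pred_fun = predict) {
      X_mod <- X
      X_mod[[v]] <- grid_value      # fix the feature of interest at one grid value
      mean(pred_fun(model, X_mod))  # average prediction over all observations
    }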

    PDPs: The diabetes risk tends to increase with age, high (and very low) BMI, and the presence of heart disease/hypertension, and it is a bit lower for females and non-smokers.

    Interaction strength

    Interaction strength can be measured by Friedman’s H statistics, see the earlier blog post. A specific interaction can then be visualized by a stratified PDP.
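    
    As a sketch of the definition: with the (mean-centered) estimated partial dependence functions of Friedman and Popescu (2008), the pairwise statistic evaluated at the data points is
    
    \begin{equation*}
    H^2_{jk} = \frac{\sum_{i=1}^{n} \big[\hat F_{jk}(x_{ij}, x_{ik}) - \hat F_j(x_{ij}) - \hat F_k(x_{ik})\big]^2}{\sum_{i=1}^{n} \hat F_{jk}(x_{ij}, x_{ik})^2},
    \end{equation*}
    
    i.e., the proportion of the joint effect variability of the two features that is not captured by their main effects.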

    Friedman’s H statistics: Left: BMI and age are the two features with clearly strongest interactions. Right: Their pairwise interaction explains about 10% of their joint effect variability.
    Stratified PDP: The strong interaction between age and BMI is clearly visible. A high BMI makes the age effect on diabetes stronger.

    SHAP

    What insights does a SHAP analysis bring?

    We will crunch slow exact permutation SHAP values via kernelshap::permshap(). If we had more features, we could switch to

    • kernelshap::kernelshap()
    • Brandon Greenwell’s {fastshap}, or to the
    • {treeshap} package of my colleagues from TU Warsaw.
    
    set.seed(1)
    X_explain <- train[sample(1:nrow(train), 1000), xvars]
    X_background <- train[sample(1:nrow(train), 200), ]
    
    system.time(  # 10 minutes
      shap_values <- permshap(model, X = X_explain, bg_X = X_background, pred_fun = pf)
    )
    shap_values <- shapviz(shap_values)
    shap_values  # 'shapviz' object representing 1000 x 6 SHAP matrix
    saveRDS(shap_values, file = "shap_values.rds")
    # shap_values <- readRDS("shap_values.rds")
    
    sv_importance(shap_values, show_numbers = TRUE)
    sv_importance(shap_values, kind = "bee")
    sv_dependence(shap_values, v = xvars) &
      ylim(-0.14, 0.24) &
      ylab("Probability")

    SHAP importance

    SHAP importance: On average, age increases or decreases the predicted diabetes probability by 4.7 percentage points, etc. In this case, the top three features are the same as in permutation importance.
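    
    In case you wonder what the bars represent: a sketch of the calculation, assuming the {shapviz} accessor get_shap_values() and the shap_values object created above:
    
    # SHAP importance by hand: mean absolute SHAP value per feature
    colMeans(abs(get_shap_values(shap_values)))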

    SHAP “summary” plot

    SHAP “summary” plot: In addition to the bar plot, we see that higher age, higher BMI, hypertension, smoking, being male, and having a heart disease are associated with higher diabetes risk.

    SHAP dependence plots

    SHAP dependence plots: We see similar shapes as in the PDPs. Thanks to the vertical scatter, we can, e.g., spot that the BMI effect strongly depends on age. As in the PDPs, we have selected a common vertical scale to also see the effect strength.

    Final words

    • {hstats}, {kernelshap} and {shapviz} can explain any model with XAI methods like permutation importance, PDPs, Friedman’s H, and SHAP. This, obviously, also includes models developed with {tidymodels}.
    • They would actually even work for multi-output models, e.g., classification with more than two categories.
    • Studying a blackbox with XAI methods is always worth the effort, even if the methods have their issues. That is, an imperfect explanation is still better than no explanation.
    • Model-agnostic SHAP takes a little bit of time, but it is usually worth the effort.

    The full R script

  • Permutation SHAP versus Kernel SHAP

    SHAP is the predominant way to interpret black-box ML models, especially for tree-based models with the blazingly fast TreeSHAP algorithm.

    For general models, two slower SHAP algorithms exist:

    1. Permutation SHAP (Štrumbelj and Kononenko, 2010)
    2. Kernel SHAP (Lundberg and Lee, 2017)

    Kernel SHAP was introduced in [2] as an approximation to permutation SHAP.

    The 0.4.0 CRAN release of our {kernelshap} package now contains an exact permutation SHAP algorithm for up to 14 features, making it easy to run experiments comparing the two approaches.

    Some initial statements about permutation SHAP and Kernel SHAP

    1. Exact permutation SHAP and exact Kernel SHAP have the same computational complexity.
    2. Technically, exact Kernel SHAP is still an approximation of exact permutation SHAP, so you should prefer the latter.
    3. Kernel SHAP assumes feature independence. Since features are never independent in practice: does this mean we should never use Kernel SHAP?
    4. Kernel SHAP can be calculated almost exactly for any number of features, while permutation SHAP approximations get more and more imprecise as the number of features grows.

    Simulation 1

    We will first work with the iris data because it has extremely strong correlations between features. To see the impact of having models with and without interactions, we work with a random forest model of increasing tree depth. Depth 1 means no interactions, depth 2 means pairwise interactions etc.

    library(kernelshap)
    library(ranger)
    
    differences <- numeric(4)
    
    set.seed(1)
    
    for (depth in 1:4) {
      fit <- ranger(
        Sepal.Length ~ ., 
        mtry = 3,
        data = iris, 
        max.depth = depth
      )
      ps <- permshap(fit, iris[2:5], bg_X = iris)
      ks <- kernelshap(fit, iris[2:5], bg_X = iris)
      differences[depth] <- mean(abs(ks$S - ps$S))
    }
    
    differences  # for tree depth 1, 2, 3, 4
    # 5.053249e-17 9.046443e-17 2.387905e-04 4.403375e-04
    
    # SHAP values of first two rows with tree depth 4
    ps
    #      Sepal.Width Petal.Length Petal.Width      Species
    # [1,]  0.11377616   -0.7130647  -0.1956012 -0.004437022
    # [2,] -0.06852539   -0.7596562  -0.2259017 -0.006575266
      
    ks
    #      Sepal.Width Petal.Length Petal.Width      Species
    # [1,]  0.11463191   -0.7125194  -0.1951810 -0.006258208
    # [2,] -0.06828866   -0.7597391  -0.2259833 -0.006647530
    • Up to pairwise interactions (tree depth 2), the mean absolute difference between the two (150 x 4) SHAP matrices is 0.
    • Even for interactions of order three or higher, the differences are small. This is unexpected – after all, the iris features are strongly correlated!

    Simulation 2

    Let’s now use a different data set with more features: the Miami house price data. As modeling technique, we use XGBoost, where we would normally use TreeSHAP. Also here, we increase the tree depth from 1 to 3 for increasing interaction depth.

    library(xgboost)
    library(shapviz)
    
    colnames(miami) <- tolower(colnames(miami))
    miami$log_ocean <- log(miami$ocean_dist)
    x <- c("log_ocean", "tot_lvg_area", "lnd_sqfoot", "structure_quality", "age", "month_sold")
    
    # Train/valid split
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    
    y_train <- log(miami$sale_prc[ix])
    y_valid <- log(miami$sale_prc[-ix])
    X_train <- data.matrix(miami[ix, x])
    X_valid <- data.matrix(miami[-ix, x])
    
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    # Fit via early stopping (depth 1 to 3)
    differences <- numeric(3)
    
    for (i in 1:3) {
      fit <- xgb.train(
        params = list(learning_rate = 0.15, objective = "reg:squarederror", max_depth = i),
        data = dtrain,
        watchlist = list(valid = dvalid),
        early_stopping_rounds = 20,
        nrounds = 1000,
        callbacks = list(cb.print.evaluation(period = 100))
      )
      ps <- permshap(fit, X = head(X_valid, 500), bg_X = head(X_valid, 500))
      ks <- kernelshap(fit, X = head(X_valid, 500), bg_X = head(X_valid, 500))
      differences[i] <- mean(abs(ks$S - ps$S))
    }
    differences # for tree depth 1, 2, 3
    # 2.904010e-09 5.158383e-09 6.586577e-04
    
    # SHAP values of top two rows for tree depth 3
    ps
    # log_ocean tot_lvg_area lnd_sqfoot structure_quality        age  month_sold
    # 0.2224359   0.04941044  0.1266136         0.1360166 0.01036866 0.005557032
    # 0.3674484   0.01045079  0.1192187         0.1180312 0.01426247 0.005465283
    
    ks
    # log_ocean tot_lvg_area lnd_sqfoot structure_quality        age  month_sold
    # 0.2245202  0.049520308  0.1266020         0.1349770 0.01142703 0.003355770
    # 0.3697167  0.009575195  0.1198201         0.1168738 0.01544061 0.003450425

    Again the same picture as with iris: Essentially no differences for interactions up to order two, and only small differences with interactions of higher order.

    Wrap-Up

    1. Use kernelshap::permshap() to crunch exact permutation SHAP values for models with not too many features.
    2. In real-world applications, exact Kernel SHAP and exact permutation SHAP start to differ (slightly) with models containing interactions of order three or higher.
    3. Since Kernel SHAP can be calculated almost exactly also for many features, it remains an excellent way to crunch SHAP values for arbitrary models.

    What is your experience?

    The R code is here.

  • Interactions – where are you?

    This question sends shivers down the poor modeler’s spine…

    The {hstats} R package introduced in our last post measures the strength of interactions using Friedman’s H-statistics, a collection of statistics based on partial dependence functions.

    On GitHub, the preview version of {hstats} 1.0.0 is out – I will try to bring it to CRAN in about one week (October 2023). Until then, try it via devtools::install_github("mayer79/hstats").

    The current version offers:

    • H statistics per feature, feature pair, and feature triple
    • Multivariate predictions at no additional cost
    • A convenient API
    • Other important tools from explainable ML:
      • performance calculations
      • permutation importance (e.g., to select features for calculating H-statistics)
      • partial dependence plots (including grouping, multivariate, multivariable)
      • individual conditional expectations (ICE)
    • Case-weights are available for all methods, which is important, e.g., in insurance applications
    • The option for fast quantile approximation of H-statistics

    This post has two parts:

    1. Example with house-prices and XGBoost
    2. Naive benchmark against {iml}, {DALEX}, and my old {flashlight}.

    1. Example

    Let’s model logarithmic sales prices of houses sold in Miami Dade County, a dataset prepared by Prof. Dr. Steven Bourassa and available in {shapviz}. We use XGBoost with interaction constraints to obtain a model that is additive in all structure information but allows interactions between latitude/longitude for a flexible representation of geographic effects.

    The following code prepares the data, splits the data into train and validation, and then fits an XGBoost model.

    library(hstats)
    library(shapviz)
    library(xgboost)
    library(ggplot2)
    
    # Data preparation
    colnames(miami) <- tolower(colnames(miami))
    miami <- transform(miami, log_price = log(sale_prc))
    x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
           "structure_quality", "age", "month_sold")
    coord <- c("longitude", "latitude")
    
    # Modeling
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    train <- data.frame(miami[ix, ])
    valid <- data.frame(miami[-ix, ])
    y_train <- train$log_price
    y_valid <- valid$log_price
    X_train <- data.matrix(train[x])
    X_valid <- data.matrix(valid[x])
    
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    ic <- c(
      list(which(x %in% coord) - 1),
      as.list(which(!x %in% coord) - 1)
    )
    
    # Fit via early stopping
    fit <- xgb.train(
      params = list(
        learning_rate = 0.15, 
        objective = "reg:squarederror", 
        max_depth = 5,
        interaction_constraints = ic
      ),
      data = dtrain,
      watchlist = list(valid = dvalid),
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )

    Now it is time for a compact analysis with {hstats} to interpret the model:

    average_loss(fit, X = X_valid, y = y_valid)  # 0.0247 MSE -> 0.157 RMSE
    
    perm_importance(fit, X = X_valid, y = y_valid) |> 
      plot()
    
    # Or combining some features
    v_groups <- list(
      coord = c("longitude", "latitude"),
      size = c("lnd_sqfoot", "tot_lvg_area"),
      condition = c("age", "structure_quality")
    )
    perm_importance(fit, v = v_groups, X = X_valid, y = y_valid) |> 
      plot()
    
    H <- hstats(fit, v = x, X = X_valid)
    H
    plot(H)
    plot(H, zero = FALSE)
    h2_pairwise(H, zero = FALSE, squared = FALSE, normalize = FALSE)
    partial_dep(fit, v = "tot_lvg_area", X = X_valid) |> 
      plot()
    partial_dep(fit, v = "tot_lvg_area", X = X_valid, BY = "structure_quality") |> 
      plot(show_points = FALSE)
    plot(ii <- ice(fit, v = "tot_lvg_area", X = X_valid))
    plot(ii, center = TRUE)
    
    # Spatial plots
    g <- unique(X_valid[, coord])
    pp <- partial_dep(fit, v = coord, X = X_valid, grid = g)
    plot(pp, d2_geom = "point", alpha = 0.5, size = 1) + 
      coord_equal()
    
    # Takes some seconds because it generates the last plot per structure quality
    partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
      plot(d2_geom = "point", alpha = 0.5) +
      coord_equal()

    Results summarized by plots

    Permutation importance

    Figure 1: Permutation importance (4 repetitions) on the validation data. Error bars show standard errors of the estimated increase in MSE from shuffling feature values.
    Figure 2: Feature groups can be shuffled together – accounting for issues of permutation importance with highly correlated features

    H-Statistics

    Let’s now move on to interaction statistics.

    Figure 3: Overall and pairwise H-statistics. Overall H^2 gives the proportion of prediction variability explained by all interactions of the feature. By default, {hstats} picks the five features with largest H^2 and calculates their pairwise H^2. This explains why not all 21 feature pairs appear in the figure on the right-hand side. Pairwise H^2 is differently scaled than overall H^2: It gives the proportion of joint effect variability of the two features explained by their interaction.
    Figure 4: Use “zero = FALSE” to drop variable (pairs) with value 0.

    PDPs and ICEs

    Figure 5: A partial dependence plot of living area.
    Figure 6: Stratification shows indeed: no interactions between structure quality and living area.
    Figure 7: ICE plots also show no interactions with any other feature. The interaction constraints of XGBoost did a good job.
    Figure 8: This two-dimensional PDP evaluated over all unique coordinates shows a realistic profile of house prices in Miami Dade County (mind the log scale).
    Figure 9: Same, but grouped by structure quality (5 is best). Since there is no interaction between location and structure quality, the plots are just shifted versions of each other. (You can’t really see it on the plots.)

    Naive Benchmark

    All methods in {hstats} are optimized for speed. But how fast are they compared to other implementations? Note: this is just a simple benchmark run on a Windows notebook with an Intel i7-8650U CPU.

    Note that {iml} offers a parallel backend, but we could not make it run with XGBoost and Windows. Let me know how fast it is using parallelism and Linux!

    Setup + benchmark on permutation importance

    Always using the full validation dataset and 10 repetitions.

    library(iml)  # Might benefit from multiprocessing, but on Windows with XGB models, this is not easy
    library(DALEX)
    library(ingredients)
    library(flashlight)
    library(bench)
    
    set.seed(1)
    
    # iml
    predf <- function(object, newdata) predict(object, data.matrix(newdata[x]))
    mod <- Predictor$new(fit, data = as.data.frame(X_valid), y = y_valid, 
                         predict.function = predf)
    
    # DALEX
    ex <- DALEX::explain(fit, data = X_valid, y = y_valid)
    
    # flashlight (my slightly old fashioned package)
    fl <- flashlight(
      model = fit, data = valid, y = "log_price", predict_function = predf, label = "lm"
    )
      
    # Permutation importance: 10 repeats over full validation data (~2700 rows)
    bench::mark(
      iml = FeatureImp$new(mod, n.repetitions = 10, loss = "mse", compare = "difference"),
      dalex = feature_importance(ex, B = 10, type = "difference", n_sample = Inf),
      flashlight = light_importance(fl, v = x, n_max = Inf, m_repetitions = 10),
      hstats = perm_importance(fit, X = X_valid, y = y_valid, m_rep = 10, verbose = FALSE),
      check = FALSE,
      min_iterations = 3
    )
    
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml           1.58s    1.58s     0.631   209.4MB    2.73      3    13      4.76s
    # dalex      566.21ms 586.91ms     1.72     34.6MB    0.572     3     1      1.75s
    # flashlight 587.03ms 613.15ms     1.63     27.1MB    1.63      3     3      1.84s
    # hstats     353.78ms 360.57ms     2.79     27.2MB    0         3     0      1.08s

    {hstats} is about 30% faster than the second-fastest package, {DALEX}.

    Partial dependence

    Here, we study the time for crunching partial dependence of a continuous feature and a discrete feature.

    # Partial dependence (cont)
    v <- "tot_lvg_area"
    bench::mark(
      iml = FeatureEffect$new(mod, feature = v, grid.size = 50, method = "pdp"),
      dalex = partial_dependence(ex, variables = v, N = Inf, grid_points = 50),
      flashlight = light_profile(fl, v = v, pd_n_max = Inf, n_bins = 50),
      hstats = partial_dep(fit, v = v, X = X_valid, grid_size = 50, n_max = Inf),
      check = FALSE,
      min_iterations = 3
    )
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml           1.11s    1.13s     0.887   376.3MB     3.84     3    13      3.38s
    # dalex      782.13ms 783.08ms     1.24    192.8MB     2.90     3     7      2.41s
    # flashlight 367.73ms  372.5ms     2.68     67.9MB     2.68     3     3      1.12s
    # hstats     220.88ms  222.5ms     4.50     14.2MB     0        3     0   666.33ms
     
    # Partial dependence (discrete)
    v <- "structure_quality"
    bench::mark(
      iml = FeatureEffect$new(mod, feature = v, method = "pdp", grid.points = 1:5),
      dalex = partial_dependence(ex, variables = v, N = Inf, variable_type = "categorical", grid_points = 5),
      flashlight = light_profile(fl, v = v, pd_n_max = Inf),
      hstats = partial_dep(fit, v = v, X = X_valid, n_max = Inf),
      check = FALSE,
      min_iterations = 3
    )
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml            90ms     96ms     10.6    13.29MB     7.06     3     2      283ms
    # dalex       170.6ms  174.4ms      5.73   20.55MB     2.87     2     1      349ms
    # flashlight   40.8ms   43.8ms     23.1     6.36MB     2.10    11     1      476ms
    # hstats       23.5ms   24.4ms     40.6     1.53MB     2.14    19     1      468ms

    {hstats} is 1.5 to 2 times faster than {flashlight}, and about four times as fast as the other packages. Its memory footprint is much lower.

    H-statistics

    How fast can overall H-statistics be computed? How fast can it do pairwise calculations?

    {DALEX} does not offer these statistics yet. {iml} was the first model-agnostic implementation of H-statistics I am aware of. It uses quantile approximation by default, but we purposely force it to calculate exact statistics in order to compare the numbers. This makes it slower than it would normally be.

    # H-Stats -> we use a subset of 500 rows
    X_v500 <- X_valid[1:500, ]
    mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf)
    fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))
    
    # iml  # 225s total, using slow exact calculations
    system.time(  # 90s
      iml_overall <- Interaction$new(mod500, grid.size = 500)
    )
    system.time(  # 135s for all combinations of latitude
      iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude")
    )
    
    # flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
    system.time(  # 12s
      fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
    )
    system.time(  # 2s
      fl_pairwise <- light_interaction(
        fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
      )
    )
    
    # hstats: 3s total
    system.time({
      H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
      hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
      hstats_pairwise <- h2_pairwise(H, squared = FALSE, zero = FALSE)
    })
    
    # Overall statistics correspond exactly
    iml_overall$results |> filter(.interaction > 1e-6)
    #     .feature .interaction
    # 1:  latitude    0.2458269
    # 2: longitude    0.2458269
    
    fl_overall$data |> subset(value > 0, select = c(variable, value))
    #   variable  value
    # 1 latitude  0.246
    # 2 longitude 0.246
    
    hstats_overall
    # longitude  latitude 
    # 0.2458269 0.2458269 
    
    # Pairwise results match as well
    iml_pairwise$results |> filter(.interaction > 1e-6)
    #              .feature .interaction
    # 1: longitude:latitude    0.3942526
    
    fl_pairwise$data |> subset(value > 0, select = c(variable, value))
    # latitude:longitude 0.394
    
    hstats_pairwise
    # latitude:longitude 
    # 0.3942526 
    • {hstats} is about four times as fast as {flashlight}.
    • Since one often wants to study both relative and absolute H-statistics, the speed-up in practice would be about a factor of eight.
    • In multi-classification/multi-output settings with m categories, the speed-up would be even m times larger.
    • The fast approximation via quantile binning is again a factor of four faster. The difference would diminish if we calculated many pairwise or three-way H-statistics.
    • Forcing all three packages to calculate exact statistics, all results match.

    Wrap-Up

    • {hstats} is much faster than other XAI packages, at least in our use-case. This includes H-statistics, permutation importance, and partial dependence. Note that making good benchmarks is not my strength, so forgive any bias in the results.
    • The memory footprint is lower as well.
    • With multivariate output, the potential is even larger.
    • H-Statistics match other implementations.

    Try it out!

    The full R code in one piece is here.

  • It’s the interactions

    What makes an ML model a black box? It is the interactions. Without any interactions, an ML model is additive and can be described exactly.

    Studying interaction effects of ML models is challenging. The main XAI approaches are:

    1. Looking at ICE plots, stratified PDP, and/or 2D PDP.
    2. Study vertical scatter in SHAP dependence plots, or even consider SHAP interaction values.
    3. Check partial-dependence based H-statistics introduced in Friedman and Popescu (2008), or related statistics.

    This post is mainly about the third approach. Its beauty is that we get information about all interactions. The downside: it is only as good/bad as the underlying partial dependence functions. And: the statistics are computationally very expensive (of order n^2).

    Different R packages offer some of these H-statistics, including {iml}, {gbm}, {flashlight}, and {vivid}. They all have their limitations. This is why I wrote the new R package {hstats}:

    • It is very efficient.
    • Has a clean API. DALEX explainers and meta-learners (mlr3, Tidymodels, caret) work out-of-the-box.
    • Supports multivariate predictions, including classification models.
    • Allows to calculate unnormalized H-statistics. They help to compare pairwise and three-way statistics.
    • Contains fast multivariate ICE/PDPs with optional grouping variable.

    In Python, there is the very interesting project artemis. I will write a post on it later.

    Statistics supported by {hstats}

    The package calculates H-statistics per feature (overall), per feature pair, and per feature triple. Furthermore, a global measure of non-additivity (proportion of prediction variability unexplained by main effects) and a measure of feature importance are available. For technical details and references, check the following pdf or GitHub.
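    
    As a quick sanity check of the idea (a sketch, not part of the original analysis): a strictly additive model should produce H-statistics of essentially zero.
    
    library(hstats)
    
    # Linear model without interaction terms -> additive in its features
    fit_lm <- lm(Sepal.Length ~ ., data = iris)
    
    # All interaction statistics should be (essentially) zero
    hstats(fit_lm, v = setdiff(colnames(iris), "Sepal.Length"), X = iris)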

    Classification example

    Let’s fit a probability random forest on iris species.

    library(ranger)
    library(ggplot2)
    library(hstats)
    
    v <- setdiff(colnames(iris), "Species")
    fit <- ranger(Species ~ ., data = iris, probability = TRUE, seed = 1)
    s <- hstats(fit, v = v, X = iris)  # 8 seconds run-time
    s
    # Proportion of prediction variability unexplained by main effects of v:
    #      setosa  versicolor   virginica 
    # 0.002705945 0.065629375 0.046742035
    
    plot(s, normalize = FALSE, squared = FALSE) +
      ggtitle("Unnormalized statistics") +
      scale_fill_viridis_d(begin = 0.1, end = 0.9)
    
    ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> 
      plot(center = TRUE) +
      ggtitle("Centered ICE plots")
    
    Unnormalized H-statistics, i.e., values are roughly on the scale of the predictions (here: probabilities).
    Centered ICE plots per class.

    Interpretation:

    • The features with strongest interactions are Petal Length and Petal Width. These interactions mainly affect species “virginica” and “versicolor”. The effect for “setosa” is almost additive.
    • Unnormalized pairwise statistics show that the strongest absolute interaction happens indeed between Petal Length and Petal Width.
    • The centered ICE plots show how the interaction manifests itself: The effect of Petal Length heavily depends on Petal Width, except for species “setosa”. Would a SHAP analysis show the same?

    DALEX example

    Here, we consider a random forest regression on “Sepal.Length”.

    library(DALEX)
    library(ranger)
    library(hstats)
    
    set.seed(1)
    
    fit <- ranger(Sepal.Length ~ ., data = iris)
    ex <- explain(fit, data = iris[-1], y = iris[, 1])
    
    s <- hstats(ex)  # 2 seconds
    s  # Non-additivity index 0.054
    plot(s)
    plot(ice(ex, v = "Sepal.Width", BY = "Petal.Width"), center = TRUE)
    
    H-statistics
    Centered ICE plot of strongest relative interactions.

    Interpretation

    • Petal Length and Width show the strongest overall associations. Since we are considering normalized statistics, we can say: “About 3.5% of prediction variability comes from interactions with Petal Length”.
    • The strongest relative pairwise interaction happens between Sepal Width and Petal Width: Again, because we study normalized H-statistics, we can say: “About 4% of total prediction variability of the two features Sepal Width and Petal Width can be attributed to their interactions.”
    • Overall, all interactions explain only about 5% of prediction variability (see text output).

    Try it out!

    The complete R script can be found here. More examples and background can be found on the Github page of the project.

  • Model Diagnostics in Python

    🚀Version 1.0.0 of the new Python package for model-diagnostics was just released on PyPI. If you use (machine learning or statistical or other) models to predict a mean, median, quantile or expectile, this library offers tools to assess the calibration of your models and to compare and decompose predictive model performance scores.🚀

    pip install model-diagnostics

    After having finished our paper (or better: user guide) “Model Comparison and Calibration Assessment: User Guide for Consistent Scoring Functions in Machine Learning and Actuarial Practice” last year, I realised that there is no Python package that supports the proposed diagnostic tools (which are not completely new). Most of the required building blocks are there, but putting them together to get a result quickly amounts to a lot of code. Therefore, I decided to publish a new package.

    By the way, I really never wanted to write a plotting library. But it turned out that arranging results until they are ready to be visualised amounts to quite a large part of the source code. I hope this was worth the effort. Your feedback is very welcome, either here in the comments or as feature request or bug report under https://github.com/lorentzenchr/model-diagnostics/issues.

    For a jump start, I recommend going directly to the two examples:

    To give a glimpse of the functionality, here are some short code snippets.

    from model_diagnostics.calibration import compute_bias
    from model_diagnostics.calibration import plot_reliability_diagram
    
    
    y_obs = list(range(10))
    y_pred = [2, 1, 3, 3, 6, 8, 5, 5, 8, 9.]
    plot_reliability_diagram(
        y_obs=y_obs,
        y_pred=y_pred,
        n_bootstrap=1000,
        confidence_level=0.9,
    )
    compute_bias(y_obs=y_obs, y_pred=y_pred)
    bias_mean  bias_count  bias_weights  bias_stderr  p_value
    f64        u32         f64           f64          f64
    0.5        10          10.0          0.477261     0.322121
    from model_diagnostics.scoring import SquaredError, decompose
    
    
    decompose(
        y_obs=y_obs,
        y_pred=y_pred,
        scoring_function=SquaredError(),
    )
    miscalibration  discrimination  uncertainty  score
    f64             f64             f64          f64
    1.283333        7.233333        8.25         2.3

    This score decomposition is additive (and unique):

    \begin{equation*}
    \mathrm{score} = \mathrm{miscalibration} - \mathrm{discrimination} + \mathrm{uncertainty}
    \end{equation*}
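    
    Plugging in the numbers from the decomposition above as a quick check: 1.283333 - 7.233333 + 8.25 = 2.3, which is exactly the reported squared error score.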

    As usual, the code snippets are collected in a notebook: https://github.com/lorentzenchr/notebooks/blob/master/blogposts/2023-07-16%20model-diagnostics.ipynb.

  • Geographic SHAP

    Lost in Translation between R and Python 10

    This is the next article in our series “Lost in Translation between R and Python”. The aim of this series is to provide high-quality R and Python code to achieve some non-trivial tasks. If you are learning R, check out the R tab below. Similarly, if you are learning Python, the Python tab will be your friend.

    This post is heavily based on the new {shapviz} vignette.

    Setting

    Besides other features, a model with geographic components contains features like

    • latitude and longitude,
    • postal code, and/or
    • other features that depend on location, e.g., distance to next restaurant.

    Like any feature, the effect of a single geographic feature can be described using SHAP dependence plots. However, studying the effect of latitude (or any other location dependent feature) alone is often not very illuminating – simply due to strong interaction effects and correlations with other geographic features.

    That’s where the additivity of SHAP values comes into play: The sum of the SHAP values of all geographic components represents the total geographic effect, and this sum can be visualized as a heatmap or 3D scatterplot against latitude/longitude (or any other geographic representation).
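    
    As a sketch of the idea (assuming a fitted shapviz object sv and the accessor get_shap_values() of {shapviz}):
    
    geo <- c("LATITUDE", "LONGITUDE")
    S <- get_shap_values(sv)                      # SHAP matrix (rows x features)
    total_geo <- rowSums(S[, geo, drop = FALSE])  # combined geographic effect per observation
    # total_geo can now be plotted on the color scale against longitude/latitude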

    A first example

    For illustration, we will use a beautiful house price dataset containing information on about 14’000 houses sold in 2016 in Miami-Dade County. Some of the columns are as follows:

    • SALE_PRC: Sale price in USD: Its logarithm will be our model response.
    • LATITUDE, LONGITUDE: Coordinates
    • CNTR_DIST: Distance to central business district
    • OCEAN_DIST: Distance (ft) to the ocean
    • RAIL_DIST: Distance (ft) to the next railway track
    • HWY_DIST: Distance (ft) to next highway
    • TOT_LVG_AREA: Living area in square feet
    • LND_SQFOOT: Land area in square feet
    • structure_quality: Measure of building quality (1: worst to 5: best)
    • age: Age of the building in years

    (Italic features are geographic components.) For more background on this dataset, see Mayer et al [2].

    We will fit an XGBoost model to explain log(price) as a function of lat/long, size, and quality/age.

    devtools::install_github("ModelOriented/shapviz", dependencies = TRUE)
    library(xgboost)
    library(ggplot2)
    library(shapviz)  # Needs development version 0.9.0 from github
    
    head(miami)
    
    x_coord <- c("LATITUDE", "LONGITUDE")
    x_nongeo <- c("TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age")
    x <- c(x_coord, x_nongeo)
    
    # Train/valid split
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    X_train <- data.matrix(miami[ix, x])
    X_valid <- data.matrix(miami[-ix, x])
    y_train <- log(miami$SALE_PRC[ix])
    y_valid <- log(miami$SALE_PRC[-ix])
    
    # Fit XGBoost model with early stopping
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    params <- list(learning_rate = 0.2, objective = "reg:squarederror", max_depth = 5)
    
    fit <- xgb.train(
      params = params, 
      data = dtrain, 
      watchlist = list(valid = dvalid), 
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )
    %load_ext lab_black
    
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.datasets import fetch_openml
    
    df = fetch_openml(data_id=43093, as_frame=True)
    X, y = df.data, np.log(df.target)
    X.head()
    
    # Data split and model
    from sklearn.model_selection import train_test_split
    import xgboost as xgb
    
    x_coord = ["LONGITUDE", "LATITUDE"]
    x_nongeo = ["TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age"]
    x = x_coord + x_nongeo
    
    X_train, X_valid, y_train, y_valid = train_test_split(
        X[x], y, test_size=0.2, random_state=30
    )
    
    # Fit XGBoost model with early stopping
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_valid, label=y_valid)
    
    params = dict(learning_rate=0.2, objective="reg:squarederror", max_depth=5)
    
    fit = xgb.train(
        params=params,
        dtrain=dtrain,
        evals=[(dvalid, "valid")],
        verbose_eval=100,
        early_stopping_rounds=20,
        num_boost_round=1000,
    )
    

    SHAP dependence plots

    Let’s first study selected SHAP dependence plots, evaluated on the validation dataset with around 2800 observations. Note that we could just as well use the training data for this purpose, but it is a bit large.

    sv <- shapviz(fit, X_pred = X_valid)
    sv_dependence(
      sv, 
      v = c("TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"), 
      alpha = 0.2
    )
    import shap
    
    xgb_explainer = shap.Explainer(fit)
    shap_values = xgb_explainer(X_valid)
    
    v = ["TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"]
    shap.plots.scatter(shap_values[:, v], color=shap_values[:, v])
    SHAP dependence plots of selected features (Python output).

    Total coordinate effect

    And now the two-dimensional plot of the sum of SHAP values:

    sv_dependence2D(sv, x = "LONGITUDE", y = "LATITUDE") +
      coord_equal()
    shap_coord = shap_values[:, x_coord]
    plt.scatter(*list(shap_coord.data.T), c=shap_coord.values.sum(axis=1), s=4)
    ax = plt.gca()
    ax.set_aspect("equal", adjustable="box")
    plt.colorbar()
    plt.title("Total location effect")
    plt.show()
    Sum of SHAP values on color scale against coordinates (Python output).

    The last plot gives a good impression of price levels, but note:

    1. Since we have modeled logarithmic prices, the effects are on a relative scale (a value of 0.1 means about 10% above average, since exp(0.1) ≈ 1.105).
    2. Due to interaction effects with non-geographic components, the location effects might depend on features like living area. This is not visible in the plot above. We will now modify the model to improve this aspect.

    Two modifications

    We will now change the above model in two ways, not unlike the model in Mayer et al [2].

    1. We will use additional geographic features like distance to railway track or to the ocean.
    2. We will use interaction constraints to allow only interactions between geographic features.

    The second step leads to a model that is additive in each non-geographic component and also additive in the combined location effect. According to the technical report of Mayer [1], SHAP dependence plots of additive components in a boosted trees model are shifted versions of corresponding partial dependence plots (evaluated at observed values). This allows a “Ceteris Paribus” interpretation of SHAP dependence plots of corresponding components.

    # Extend the feature set
    more_geo <- c("CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST")
    x2 <- c(x, more_geo)
    
    X_train2 <- data.matrix(miami[ix, x2])
    X_valid2 <- data.matrix(miami[-ix, x2])
    
    dtrain2 <- xgb.DMatrix(X_train2, label = y_train)
    dvalid2 <- xgb.DMatrix(X_valid2, label = y_valid)
    
    # Build interaction constraint vector
    ic <- c(
      list(which(x2 %in% c(x_coord, more_geo)) - 1),
      as.list(which(x2 %in% x_nongeo) - 1)
    )
    
    # Modify parameters
    params$interaction_constraints <- ic
    
    fit2 <- xgb.train(
      params = params, 
      data = dtrain2, 
      watchlist = list(valid = dvalid2), 
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )
    
    # SHAP analysis
    sv2 <- shapviz(fit2, X_pred = X_valid2)
    
    # Two selected features: Thanks to additivity, structure_quality can be read as 
    # Ceteris Paribus
    sv_dependence(sv2, v = c("structure_quality", "LONGITUDE"), alpha = 0.2)
    
    # Total geographic effect (Ceteris Paribus thanks to additivity)
    sv_dependence2D(sv2, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
      coord_equal()
    # Extend the feature set
    more_geo = ["CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST"]
    x2 = x + more_geo
    
    X_train2, X_valid2 = train_test_split(X[x2], test_size=0.2, random_state=30)
    
    dtrain2 = xgb.DMatrix(X_train2, label=y_train)
    dvalid2 = xgb.DMatrix(X_valid2, label=y_valid)
    
    # Build interaction constraint vector
    ic = [x_coord + more_geo, *[[z] for z in x_nongeo]]
    
    # Modify parameters
    params["interaction_constraints"] = ic
    
    fit2 = xgb.train(
        params=params,
        dtrain=dtrain2,
        evals=[(dvalid2, "valid")],
        verbose_eval=100,
        early_stopping_rounds=20,
        num_boost_round=1000,
    )
    
    # SHAP analysis
    xgb_explainer2 = shap.Explainer(fit2)
    shap_values2 = xgb_explainer2(X_valid2)
    
    v = ["structure_quality", "LONGITUDE"]
    shap.plots.scatter(shap_values2[:, v], color=shap_values2[:, v])
    
    # Total location effect
    shap_coord2 = shap_values2[:, x_coord]
    c = shap_values2[:, x_coord + more_geo].values.sum(axis=1)
    plt.scatter(*list(shap_coord2.data.T), c=c, s=4)
    ax = plt.gca()
    ax.set_aspect("equal", adjustable="box")
    plt.colorbar()
    plt.title("Total location effect")
    plt.show()
    SHAP dependence plots of an additive feature (structure quality, no vertical scatter per unique feature value) and one of the geographic features (Python output).
    Sum of all geographic features (color) against coordinates. There are no interactions to non-geographic features, so the effect can be read Ceteris Paribus (Python output).

    Again, the resulting total geographic effect looks reasonable.

    Wrap-Up

    • SHAP values of all geographic components in a model can be summed up and plotted on the color scale against coordinates (or some other geographic representation). This gives a lightning fast impression of the location effects.
    • Interaction constraints between geographic and non-geographic features lead to Ceteris Paribus interpretation of total geographic effects.

    The Python and R notebooks can be found here:

    References

    1. Mayer, Michael. 2022. “SHAP for Additively Modeled Features in a Boosted Trees Model.” https://arxiv.org/abs/2207.14490.
    2. Mayer, Michael, Steven C. Bourassa, Martin Hoesli, and Donato Flavio Scognamiglio. 2022. “Machine Learning Applications to Land and Structure Valuation.” Journal of Risk and Financial Management.

  • SHAP + XGBoost + Tidymodels = LOVE

    In this recent post, we have explained how to use Kernel SHAP for interpreting complex linear models. As plotting backend, we used our fresh CRAN package “shapviz”.

    “shapviz” has direct connectors to a couple of packages such as XGBoost, LightGBM, H2O, kernelshap, and more. People have asked me multiple times how to combine shapviz with an XGBoost model fitted via Tidymodels. The workflow was not 100% clear to me either, but the answer is actually very simple, thanks to Julia’s post where the plots were made with SHAPforxgboost, another cool package for the visualization of SHAP values.

    Example with shiny diamonds

    Step 1: Preprocessing

    We first write the data preprocessing recipe and apply it to the data rows that we want to explain. In our case, it’s 1000 randomly sampled diamonds.

    library(tidyverse)
    library(tidymodels)
    library(shapviz)
    
    # Integer encode factors
    dia_recipe <- diamonds %>%
      recipe(price ~ carat + cut + clarity + color) %>% 
      step_integer(all_nominal())
    
    # Will explain THIS dataset later
    set.seed(2)
    dia_small <- diamonds[sample(nrow(diamonds), 1000), ]
    dia_small_prep <- bake(
      prep(dia_recipe), 
      has_role("predictor"),
      new_data = dia_small, 
      composition = "matrix"
    )
    head(dia_small_prep)
    
    #     carat cut clarity color
    #[1,]  0.57   5       4     4
    #[2,]  1.01   5       2     1
    #[3,]  0.45   1       4     3
    #[4,]  1.04   4       6     5
    #[5,]  0.90   3       6     4
    #[6,]  1.20   3       4     6
    

    Step 2: Fit Model

    The next step is to tune and build the model. For simplicity, we skipped the tuning part. Bad, bad 🙂

    # Just for illustration - in practice needs tuning!
    xgboost_model <- boost_tree(
      mode = "regression",
      trees = 200,
      tree_depth = 5,
      learn_rate = 0.05,
      engine = "xgboost"
    )
    
    dia_wf <- workflow() %>%
      add_recipe(dia_recipe) %>%
      add_model(xgboost_model)
    
    fit <- dia_wf %>%
      fit(diamonds)

    Step 3: SHAP Analysis

    We now need to call shapviz() on the fitted model. In order to have neat interpretations with the original factor labels, we not only pass the prediction data prepared in Step 1 via bake(), but also the original data structure.

    shap <- shapviz(extract_fit_engine(fit), X_pred = dia_small_prep, X = dia_small)
    
    sv_importance(shap, kind = "both", show_numbers = TRUE)
    sv_dependence(shap, "carat", color_var = "auto")
    sv_dependence(shap, "clarity", color_var = "auto")
    sv_force(shap, row_id = 1)
    sv_waterfall(shap, row_id = 1)
    
    Variable importance plot overlaid with SHAP summary beeswarms
    Dependence plot for carat. Note that clarity is shown with original labels, not only integers.
    Dependence plot for clarity. Note again that the x-scale uses the original factor levels, not the integer encoded values.
    Force plot of the first observation
    Waterfall plot for the first observation

    Summary

    Making SHAP analyses with XGBoost and Tidymodels is super easy.

    The complete R script can be found here.

  • Dplyr-style without dplyr

    One of the reasons why we love the “dplyr” package: it plays so well together with the forward pipe operator `%>%` from the “magrittr” package. Actually, it is not a coincidence that both packages were released at about the same time, in 2014.

    What does the pipe do? It puts the object on its left as the first argument into the function on its right: iris %>% head() is a funny way of writing head(iris). It helps to avoid long function chains like f(g(h(x))), or repeated assignments.

    In 2021 and version 4.1, R has received its native forward pipe operator |> so that we can write nice code like this:
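    
    Here a small sketch of such a chain, using only base R functions and the built-in mtcars data:
    
    mtcars |>
      subset(cyl == 4, select = c(mpg, wt, hp)) |>
      transform(power_to_weight = hp / wt) |>
      head()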

    Imagine this without pipe…

    Since version 4.2, the piped object can be referenced by the underscore _, but just once for now, see an example below.

    To use the native pipe via CTRL-SHIFT-M in Posit/RStudio, tick this:

    Combined with the many great functions from the standard distribution of R, we can get a real “dplyr” feeling without even loading dplyr. Don’t get me wrong: I am a huge fan of the whole Tidyverse! But it is a great way to learn “Standard R”.

    Data chains

    Here a small selection of standard functions playing well together with the pipe: They take a data frame and return a modified data frame:

    • subset(): Select rows and columns of data frame
    • transform(): Add or overwrite columns in data frame
    • aggregate(): Grouped calculations
    • rbind(), cbind(): Bind rows/columns of data frame/matrix
    • merge(): Join data frames by key
    • head(), tail(): First/last few elements of object
    • reshape(): Transposition/Reshaping of data frame (no, I don’t understand the interface)
    library(ggplot2)  # Need diamonds
    
    # What does the native pipe do?
    quote(diamonds |> head())
    
    # OUTPUT
    # head(diamonds)
    
    # Grouped statistics
    diamonds |> 
      aggregate(cbind(price, carat) ~ color, FUN = mean)
    
    # OUTPUT
    #   color    price     carat
    # 1     D 3169.954 0.6577948
    # 2     E 3076.752 0.6578667
    # 3     F 3724.886 0.7365385
    # 4     G 3999.136 0.7711902
    # 5     H 4486.669 0.9117991
    # 6     I 5091.875 1.0269273
    # 7     J 5323.818 1.1621368
    
    # Join back grouped stats to relevant columns
    diamonds |> 
      subset(select = c(price, color, carat)) |> 
      transform(price_per_color = ave(price, color)) |> 
      head()
    
    # OUTPUT
    #   price color carat price_per_color
    # 1   326     E  0.23        3076.752
    # 2   326     E  0.21        3076.752
    # 3   327     E  0.23        3076.752
    # 4   334     I  0.29        5091.875
    # 5   335     J  0.31        5323.818
    # 6   336     J  0.24        5323.818
    
    # Plot transformed values
    diamonds |> 
      transform(
        log_price = log(price),
        log_carat = log(carat)
      ) |> 
      plot(log_price ~ log_carat, col = "chartreuse4", pch = ".", data = _)
    A simple scatterplot

    The plot does not look quite as sexy as “ggplot2”, but it’s a start.

    Other chains

    The pipe not only works perfectly with functions that modify a data frame. It also shines with many other functions often applied in a nested way. Here two examples:

    # Distribution of color within clarity
    diamonds |> 
      subset(select = c(color, clarity)) |> 
      table() |> 
      prop.table(margin = 2) |> 
      addmargins(margin = 1) |> 
      round(3)
    
    # OUTPUT
    # clarity
    # color      I1   SI2   SI1   VS2   VS1  VVS2  VVS1    IF
    #     D   0.057 0.149 0.159 0.138 0.086 0.109 0.069 0.041
    #     E   0.138 0.186 0.186 0.202 0.157 0.196 0.179 0.088
    #     F   0.193 0.175 0.163 0.180 0.167 0.192 0.201 0.215
    #     G   0.202 0.168 0.151 0.191 0.263 0.285 0.273 0.380
    #     H   0.219 0.170 0.174 0.134 0.143 0.120 0.160 0.167
    #     I   0.124 0.099 0.109 0.095 0.118 0.072 0.097 0.080
    #     J   0.067 0.052 0.057 0.060 0.066 0.026 0.020 0.028
    #     Sum 1.000 1.000 1.000 1.000 1.000 1.000 1.000 1.000
    
    # Barplot from discrete column
    diamonds$color |> 
      table() |> 
      prop.table() |> 
      barplot(col = "chartreuse4", main = "Color")

    Wrap up

    • Piping is fun with and without dplyr.
    • It is a great motivation to learn standard R.

    The complete R script can be found here.