Category: Programming

  • Interactions – where are you?

    This question sends shivers down the poor modeler’s spine…

    The {hstats} R package introduced in our last post measures their strength using Friedman’s H-statistics, a collection of statistics based on partial dependence functions.

    On GitHub, the preview version of {hstats} 1.0.0 is out. I will try to bring it to CRAN in about one week (October 2023). Until then, try it via devtools::install_github("mayer79/hstats").

    The current version offers:

    • H-statistics per feature, feature pair, and feature triple
    • Multivariate predictions at no additional cost
    • A convenient API
    • Other important tools from explainable ML:
      • performance calculations
      • permutation importance (e.g., to select features for calculating H-statistics)
      • partial dependence plots (including grouping, multivariate, multivariable)
      • individual conditional expectations (ICE)
    • Case-weights are available for all methods, which is important, e.g., in insurance applications
    • The option for fast quantile approximation of H-statistics

    This post has two parts:

    1. Example with house-prices and XGBoost
    2. Naive benchmark against {iml}, {DALEX}, and my old {flashlight}.

    1. Example

    Let’s model logarithmic sales prices of houses sold in Miami Dade County, a dataset prepared by Prof. Dr. Steven Bourassa and available in {shapviz}. We use XGBoost with interaction constraints to get a model that is additive in all structural features but allows interactions between latitude and longitude, giving a flexible representation of geographic effects.

    The following code prepares the data, splits the data into train and validation, and then fits an XGBoost model.

    library(hstats)
    library(shapviz)
    library(xgboost)
    library(ggplot2)
    
    # Data preparation
    colnames(miami) <- tolower(colnames(miami))
    miami <- transform(miami, log_price = log(sale_prc))
    x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
           "structure_quality", "age", "month_sold")
    coord <- c("longitude", "latitude")
    
    # Modeling
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    train <- data.frame(miami[ix, ])
    valid <- data.frame(miami[-ix, ])
    y_train <- train$log_price
    y_valid <- valid$log_price
    X_train <- data.matrix(train[x])
    X_valid <- data.matrix(valid[x])
    
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    ic <- c(
      list(which(x %in% coord) - 1),
      as.list(which(!x %in% coord) - 1)
    )
    
    # Fit via early stopping
    fit <- xgb.train(
      params = list(
        learning_rate = 0.15, 
        objective = "reg:squarederror", 
        max_depth = 5,
        interaction_constraints = ic
      ),
      data = dtrain,
      watchlist = list(valid = dvalid),
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )

    Now it is time for a compact analysis with {hstats} to interpret the model:

    average_loss(fit, X = X_valid, y = y_valid)  # 0.0247 MSE -> 0.157 RMSE
    
    perm_importance(fit, X = X_valid, y = y_valid) |> 
      plot()
    
    # Or combining some features
    v_groups <- list(
      coord = c("longitude", "latitude"),
      size = c("lnd_sqfoot", "tot_lvg_area"),
      condition = c("age", "structure_quality")
    )
    perm_importance(fit, v = v_groups, X = X_valid, y = y_valid) |> 
      plot()
    
    H <- hstats(fit, v = x, X = X_valid)
    H
    plot(H)
    plot(H, zero = FALSE)
    h2_pairwise(H, zero = FALSE, squared = FALSE, normalize = FALSE)
    partial_dep(fit, v = "tot_lvg_area", X = X_valid) |> 
      plot()
    partial_dep(fit, v = "tot_lvg_area", X = X_valid, BY = "structure_quality") |> 
      plot(show_points = FALSE)
    plot(ii <- ice(fit, v = "tot_lvg_area", X = X_valid))
    plot(ii, center = TRUE)
    
    # Spatial plots
    g <- unique(X_valid[, coord])
    pp <- partial_dep(fit, v = coord, X = X_valid, grid = g)
    plot(pp, d2_geom = "point", alpha = 0.5, size = 1) + 
      coord_equal()
    
    # Takes a few seconds because it creates the spatial plot once per structure quality
    partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
      plot(d2_geom = "point", alpha = 0.5) +
      coord_equal()

    Results summarized by plots

    Permutation importance

    Figure 1: Permutation importance (4 repetitions) on the validation data. Error bars show standard errors of the estimated increase in MSE from shuffling feature values.
    Figure 2: Feature groups can be shuffled together – accounting for issues of permutation importance with highly correlated features
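    The computation behind Figures 1 and 2 can be sketched in a few lines of Python. This is a minimal illustration under our own assumptions, not the {hstats} implementation: `toy_perm_importance()` is a hypothetical helper, the loss is the MSE, and the standard error is taken as the standard deviation of the repeated MSE increases divided by the square root of the number of repetitions. A feature group is shuffled with one shared row permutation, so correlated columns stay together.

```python
import numpy as np


def toy_perm_importance(predict, X, y, groups, m_rep=4, seed=1):
    """Permutation importance as mean increase in MSE, with standard errors.

    `groups` maps a name to a list of column indices. All columns of a group
    are shuffled with the same row permutation, so they move together.
    """
    rng = np.random.default_rng(seed)
    base = np.mean((y - predict(X)) ** 2)
    result = {}
    for name, cols in groups.items():
        increases = np.empty(m_rep)
        for r in range(m_rep):
            perm = rng.permutation(len(X))
            X_perm = X.copy()
            X_perm[:, cols] = X[perm][:, cols]  # joint shuffle of the group
            increases[r] = np.mean((y - predict(X_perm)) ** 2) - base
        # Mean increase and its standard error over the m_rep repetitions
        result[name] = (increases.mean(), increases.std(ddof=1) / np.sqrt(m_rep))
    return result


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = 2 * X[:, 0] + rng.normal(scale=0.1, size=500)


def predict(Z):
    """Toy model that only uses the first column."""
    return 2 * Z[:, 0]


imp = toy_perm_importance(predict, X, y, {"x0": [0], "x1": [1], "x1+x2": [1, 2]})
for name, (mean, se) in imp.items():
    print(f"{name}: {mean:.3f} (SE {se:.3f})")
```

    Shuffling columns the model never uses leaves the predictions unchanged, so their importance is exactly zero, while shuffling the first column inflates the MSE considerably.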

    H-Statistics

    Let’s now move on to interaction statistics.

    Figure 3: Overall and pairwise H-statistics. Overall H^2 gives the proportion of prediction variability explained by all interactions of the feature. By default, {hstats} picks the five features with largest H^2 and calculates their pairwise H^2. This explains why not all 21 feature pairs appear in the figure on the right-hand side. Pairwise H^2 is differently scaled than overall H^2: It gives the proportion of joint effect variability of the two features explained by their interaction.
    Figure 4: Use “zero = FALSE” to drop variables (or variable pairs) with value 0.

    PDPs and ICEs

    Figure 5: A partial dependence plot of living area.
    Figure 6: Stratification indeed shows no interactions between structure quality and living area.
    Figure 7: ICE plots also show no interactions with any other feature. The interaction constraints of XGBoost did a good job.
    Figure 8: This two-dimensional PDP evaluated over all unique coordinates shows a realistic profile of house prices in Miami Dade County (mind the log scale).
    Figure 9: Same, but grouped by structure quality (5 is best). Since there is no interaction between location and structure quality, the plots are just shifted versions of each other. (You can’t really see it on the plots.)

    Naive Benchmark

    All methods in {hstats} are optimized for speed. But how fast are they compared to other implementations? Note that this is just a simple benchmark run on a Windows notebook with an Intel i7-8650U CPU.

    Note that {iml} offers a parallel backend, but we could not make it run with XGBoost and Windows. Let me know how fast it is using parallelism and Linux!

    Setup + benchmark on permutation importance

    Always using the full validation dataset and 10 repetitions.

    library(iml)  # Might benefit from multiprocessing, but on Windows with XGB models, this is not easy
    library(DALEX)
    library(ingredients)
    library(flashlight)
    library(bench)
    
    set.seed(1)
    
    # iml
    predf <- function(object, newdata) predict(object, data.matrix(newdata[x]))
    mod <- Predictor$new(fit, data = as.data.frame(X_valid), y = y_valid, 
                         predict.function = predf)
    
    # DALEX
    ex <- DALEX::explain(fit, data = X_valid, y = y_valid)
    
    # flashlight (my slightly old fashioned package)
    fl <- flashlight(
      model = fit, data = valid, y = "log_price", predict_function = predf, label = "lm"
    )
      
    # Permutation importance: 10 repeats over full validation data (~2700 rows)
    bench::mark(
      iml = FeatureImp$new(mod, n.repetitions = 10, loss = "mse", compare = "difference"),
      dalex = feature_importance(ex, B = 10, type = "difference", n_sample = Inf),
      flashlight = light_importance(fl, v = x, n_max = Inf, m_repetitions = 10),
      hstats = perm_importance(fit, X = X_valid, y = y_valid, m_rep = 10, verbose = FALSE),
      check = FALSE,
      min_iterations = 3
    )
    
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml           1.58s    1.58s     0.631   209.4MB    2.73      3    13      4.76s
    # dalex      566.21ms 586.91ms     1.72     34.6MB    0.572     3     1      1.75s
    # flashlight 587.03ms 613.15ms     1.63     27.1MB    1.63      3     3      1.84s
    # hstats     353.78ms 360.57ms     2.79     27.2MB    0         3     0      1.08s

    {hstats} is about 1.6 times as fast as the runner-up, {DALEX} (median 361 ms vs. 587 ms).

    Partial dependence

    Here, we study the time for crunching partial dependence of a continuous feature and a discrete feature.

    # Partial dependence (cont)
    v <- "tot_lvg_area"
    bench::mark(
      iml = FeatureEffect$new(mod, feature = v, grid.size = 50, method = "pdp"),
      dalex = partial_dependence(ex, variables = v, N = Inf, grid_points = 50),
      flashlight = light_profile(fl, v = v, pd_n_max = Inf, n_bins = 50),
      hstats = partial_dep(fit, v = v, X = X_valid, grid_size = 50, n_max = Inf),
      check = FALSE,
      min_iterations = 3
    )
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml           1.11s    1.13s     0.887   376.3MB     3.84     3    13      3.38s
    # dalex      782.13ms 783.08ms     1.24    192.8MB     2.90     3     7      2.41s
    # flashlight 367.73ms  372.5ms     2.68     67.9MB     2.68     3     3      1.12s
    # hstats     220.88ms  222.5ms     4.50     14.2MB     0        3     0   666.33ms
     
    # Partial dependence (discrete)
    v <- "structure_quality"
    bench::mark(
      iml = FeatureEffect$new(mod, feature = v, method = "pdp", grid.points = 1:5),
      dalex = partial_dependence(ex, variables = v, N = Inf, variable_type = "categorical", grid_points = 5),
      flashlight = light_profile(fl, v = v, pd_n_max = Inf),
      hstats = partial_dep(fit, v = v, X = X_valid, n_max = Inf),
      check = FALSE,
      min_iterations = 3
    )
    # expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    # iml            90ms     96ms     10.6    13.29MB     7.06     3     2      283ms
    # dalex       170.6ms  174.4ms      5.73   20.55MB     2.87     2     1      349ms
    # flashlight   40.8ms   43.8ms     23.1     6.36MB     2.10    11     1      476ms
    # hstats       23.5ms   24.4ms     40.6     1.53MB     2.14    19     1      468ms

    {hstats} is 1.5 to 2 times as fast as {flashlight}, and roughly 3.5 to 7 times as fast as {iml} and {DALEX}. Its memory footprint is much lower.

    H-statistics

    How fast can overall H-statistics be computed? And how fast are the pairwise calculations?

    {DALEX} does not offer these statistics yet. {iml} was the first model-agnostic implementation of H-statistics I am aware of. It uses quantile approximation by default, but we purposely force it to calculate exact statistics in order to compare the numbers. Thus, we made it slower than it actually is.

    # H-Stats -> we use a subset of 500 rows
    X_v500 <- X_valid[1:500, ]
    mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf)
    fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))
    
    # iml  # 225s total, using slow exact calculations
    system.time(  # 90s
      iml_overall <- Interaction$new(mod500, grid.size = 500)
    )
    system.time(  # 135s for all combinations of latitude
      iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude")
    )
    
    # flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
    system.time(  # 12s
      fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
    )
    system.time(  # 2s
      fl_pairwise <- light_interaction(
        fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
      )
    )
    
    # hstats: 3s total
    system.time({
      H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
      hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
      hstats_pairwise <- h2_pairwise(H, squared = FALSE, zero = FALSE)
    }
    )
    
    # Overall statistics correspond exactly
    iml_overall$results |> subset(.interaction > 1e-6)
    #     .feature .interaction
    # 1:  latitude    0.2458269
    # 2: longitude    0.2458269
    
    fl_overall$data |> subset(value > 0, select = c(variable, value))
    #   variable  value
    # 1 latitude  0.246
    # 2 longitude 0.246
    
    hstats_overall
    # longitude  latitude 
    # 0.2458269 0.2458269 
    
    # Pairwise results match as well
    iml_pairwise$results |> subset(.interaction > 1e-6)
    #              .feature .interaction
    # 1: longitude:latitude    0.3942526
    
    fl_pairwise$data |> subset(value > 0, select = c(variable, value))
    # latitude:longitude 0.394
    
    hstats_pairwise
    # latitude:longitude 
    # 0.3942526 

    • {hstats} is about four times as fast as {flashlight}.
    • Since one often wants to study both relative and absolute H-statistics, the speed-up in practice would be about a factor of eight.
    • In multiclass classification or other multi-output settings with m outputs, the speed-up would be another factor of m larger.
    • The fast approximation via quantile binning is again a factor of four faster. The difference would diminish if we calculated many pairwise or three-way H-statistics.
    • Forcing all three packages to calculate exact statistics, all results match.

    Wrap-Up

    • {hstats} is much faster than other XAI packages, at least in our use-case. This includes H-statistics, permutation importance, and partial dependence. Note that making good benchmarks is not my strength, so forgive any bias in the results.
    • The memory footprint is lower as well.
    • With multivariate output, the potential is even larger.
    • H-Statistics match other implementations.

    Try it out!

    The full R code in one piece is here.

  • It’s the interactions

    What makes an ML model a black box? It is the interactions. Without any interactions, the ML model is additive and can be described exactly by its main effects.

    Studying interaction effects of ML models is challenging. The main XAI approaches are:

    1. Looking at ICE plots, stratified PDP, and/or 2D PDP.
    2. Study vertical scatter in SHAP dependence plots, or even consider SHAP interaction values.
    3. Check partial-dependence based H-statistics introduced in Friedman and Popescu (2008), or related statistics.

    This post is mainly about the third approach. Its beauty is that we get information about all interactions. The downside: the statistics are only as good (or bad) as the partial dependence functions they are built on. Moreover, they are very expensive to compute (of order n^2 in the number of observations n).
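    To make the definition concrete, here is a minimal NumPy sketch of the exact pairwise H^2 of Friedman and Popescu (2008): the squared difference between the two-dimensional partial dependence and the sum of the two one-dimensional ones, relative to the variability of the two-dimensional partial dependence, with all functions centered and evaluated at the observed data. The helper names are hypothetical and this is not the {hstats} code. The loop over rows, each averaging predictions over all rows, is what makes the exact statistic O(n^2).

```python
import numpy as np


def pd_at_observed(predict, X, cols):
    """Centered partial dependence of `predict` on the columns `cols`,
    evaluated at the observed values of those columns (one value per row).

    Each of the n evaluations averages n predictions: O(n^2) model calls.
    """
    out = np.empty(len(X))
    for i in range(len(X)):
        X_mod = X.copy()
        X_mod[:, cols] = X[i, cols]  # fix `cols` at row i's values for all rows
        out[i] = predict(X_mod).mean()
    return out - out.mean()


def friedman_h2_pairwise(predict, X, j, k):
    """Friedman-Popescu pairwise H^2 of features j and k (exact, normalized)."""
    pd_jk = pd_at_observed(predict, X, [j, k])
    pd_j = pd_at_observed(predict, X, [j])
    pd_k = pd_at_observed(predict, X, [k])
    # Share of the joint effect variability not explained by the main effects
    return np.sum((pd_jk - pd_j - pd_k) ** 2) / np.sum(pd_jk**2)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
X -= X.mean(axis=0)  # exactly centered columns keep the toy results clean


def additive(Z):
    return Z[:, 0] + 2 * Z[:, 1] + Z[:, 2]


def product(Z):
    return Z[:, 0] * Z[:, 1] + Z[:, 2]


print(friedman_h2_pairwise(additive, X, 0, 1))  # practically 0: no interaction
print(friedman_h2_pairwise(product, X, 0, 1))   # practically 1: pure interaction
```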

    Different R packages offer some of these H-statistics, including {iml}, {gbm}, {flashlight}, and {vivid}. They all have their limitations. This is why I wrote the new R package {hstats}:

    • It is very efficient.
    • Has a clean API. DALEX explainers and meta-learners (mlr3, Tidymodels, caret) work out-of-the-box.
    • Supports multivariate predictions, including classification models.
    • Allows calculating unnormalized H-statistics. They help to compare pairwise and three-way statistics.
    • Contains fast multivariate ICE/PDPs with optional grouping variable.

    In Python, there is the very interesting project artemis. I will write a post on it later.

    Statistics supported by {hstats}

    Besides H-statistics, {hstats} offers a global measure of non-additivity (proportion of prediction variability unexplained by main effects) and a measure of feature importance. For technical details and references, check the following pdf or GitHub.
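    The non-additivity measure can be sketched in the same spirit: compare the centered predictions with the sum of all centered one-dimensional partial dependence functions. This sketch assumes exactly that definition, read off the description "proportion of prediction variability unexplained by main effects"; the function names are hypothetical and it is not the package implementation.

```python
import numpy as np


def centered_pd(predict, X, j):
    """Centered 1D partial dependence of column j at the observed values."""
    out = np.empty(len(X))
    for i in range(len(X)):
        X_mod = X.copy()
        X_mod[:, j] = X[i, j]
        out[i] = predict(X_mod).mean()
    return out - out.mean()


def non_additivity(predict, X):
    """Share of centered prediction variability not explained by the sum of
    all one-dimensional (main effect) partial dependence functions."""
    f = predict(X)
    f = f - f.mean()
    main_effects = sum(centered_pd(predict, X, j) for j in range(X.shape[1]))
    return np.sum((f - main_effects) ** 2) / np.sum(f**2)


rng = np.random.default_rng(1)
X = rng.normal(size=(300, 2))
X -= X.mean(axis=0)


def additive(Z):
    return Z[:, 0] + Z[:, 1] ** 2


def product(Z):
    return Z[:, 0] * Z[:, 1]


def mixed(Z):
    return Z[:, 0] + Z[:, 0] * Z[:, 1]


for model in (additive, product, mixed):
    print(model.__name__, round(non_additivity(model, X), 3))
```

    An additive model scores (practically) 0, a pure interaction scores 1, and a mix of main effect and interaction lands in between.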

    Classification example

    Let’s fit a probability random forest on iris species.

    library(ranger)
    library(ggplot2)
    library(hstats)
    
    v <- setdiff(colnames(iris), "Species")
    fit <- ranger(Species ~ ., data = iris, probability = TRUE, seed = 1)
    s <- hstats(fit, v = v, X = iris)  # 8 seconds run-time
    s
    # Proportion of prediction variability unexplained by main effects of v:
    #      setosa  versicolor   virginica 
    # 0.002705945 0.065629375 0.046742035
    
    plot(s, normalize = FALSE, squared = FALSE) +
      ggtitle("Unnormalized statistics") +
      scale_fill_viridis_d(begin = 0.1, end = 0.9)
    
    ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width", n_max = 150) |> 
      plot(center = TRUE) +
      ggtitle("Centered ICE plots")
    
    Unnormalized H-statistics, i.e., values are roughly on the scale of the predictions (here: probabilities).
    Centered ICE plots per class.

    Interpretation:

    • The features with strongest interactions are Petal Length and Petal Width. These interactions mainly affect species “virginica” and “versicolor”. The effect for “setosa” is almost additive.
    • Unnormalized pairwise statistics show that the strongest absolute interaction happens indeed between Petal Length and Petal Width.
    • The centered ICE plots show how the interaction manifests: The effect of Petal Length heavily depends on Petal Width, except for species “setosa”. Would a SHAP analysis show the same?
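    The logic behind centered ICE plots is simple enough to sketch: compute one prediction curve per observation over a grid, then shift every curve to start at zero. Without interactions involving the feature, all centered curves coincide; vertical spread signals interactions. In this hypothetical sketch, the curves are anchored at the first grid value, which is one common convention; {hstats} may anchor them differently.

```python
import numpy as np


def ice_curves(predict, X, j, grid):
    """ICE matrix: row i holds the predictions for observation i when
    column j is set to each grid value in turn."""
    curves = np.empty((len(X), len(grid)))
    for g, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, j] = value
        curves[:, g] = predict(X_mod)
    return curves


def center(curves):
    """Shift each curve so that it equals 0 at the first grid point."""
    return curves - curves[:, [0]]


def curve_spread(curves):
    """Largest vertical range across observations at any grid point."""
    return np.max(curves.max(axis=0) - curves.min(axis=0))


rng = np.random.default_rng(0)
X = rng.normal(size=(150, 2))
grid = np.linspace(-2, 2, 9)


def additive(Z):
    return Z[:, 0] ** 2 + Z[:, 1]


def interacting(Z):
    return Z[:, 0] * Z[:, 1]


print(curve_spread(center(ice_curves(additive, X, 0, grid))))     # ~0: parallel curves
print(curve_spread(center(ice_curves(interacting, X, 0, grid))))  # > 0: curves fan out
```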

    DALEX example

    Here, we consider a random forest regression on “Sepal.Length”.

    library(DALEX)
    library(ranger)
    library(hstats)
    
    set.seed(1)
    
    fit <- ranger(Sepal.Length ~ ., data = iris)
    ex <- explain(fit, data = iris[-1], y = iris[, 1])
    
    s <- hstats(ex)  # 2 seconds
    s  # Non-additivity index 0.054
    plot(s)
    plot(ice(ex, v = "Sepal.Width", BY = "Petal.Width"), center = TRUE)
    
    H-statistics
    Centered ICE plot of strongest relative interactions.

    Interpretation

    • Petal Length and Width show the strongest overall interactions. Since we are considering normalized statistics, we can say: “About 3.5% of prediction variability comes from interactions with Petal Length”.
    • The strongest relative pairwise interaction happens between Sepal Width and Petal Width: Again, because we study normalized H-statistics, we can say: “About 4% of total prediction variability of the two features Sepal Width and Petal Width can be attributed to their interactions.”
    • Overall, all interactions explain only about 5% of prediction variability (see text output).

    Try it out!

    The complete R script can be found here. More examples and background can be found on the Github page of the project.

  • Model Diagnostics in Python

    🚀Version 1.0.0 of the new Python package for model-diagnostics was just released on PyPI. If you use (machine learning or statistical or other) models to predict a mean, median, quantile or expectile, this library offers tools to assess the calibration of your models and to compare and decompose predictive model performance scores.🚀

    pip install model-diagnostics

    After having finished our paper (or better: user guide) “Model Comparison and Calibration Assessment: User Guide for Consistent Scoring Functions in Machine Learning and Actuarial Practice” last year, I realised that there is no Python package that supports the proposed diagnostic tools (which are not completely new). Most of the required building blocks exist, but putting them together to get a result quickly requires a lot of code. Therefore, I decided to publish a new package.

    By the way, I really never wanted to write a plotting library. But it turned out that arranging results until they are ready to be visualised amounts to quite a large part of the source code. I hope this was worth the effort. Your feedback is very welcome, either here in the comments or as feature request or bug report under https://github.com/lorentzenchr/model-diagnostics/issues.

    For a jump start, I recommend going directly to the two examples:

    To give a glimpse of the functionality, here are some short code snippets.

    from model_diagnostics.calibration import compute_bias
    from model_diagnostics.calibration import plot_reliability_diagram
    
    
    y_obs = list(range(10))
    y_pred = [2, 1, 3, 3, 6, 8, 5, 5, 8, 9.]
    plot_reliability_diagram(
        y_obs=y_obs,
        y_pred=y_pred,
        n_bootstrap=1000,
        confidence_level=0.9,
    )
    compute_bias(y_obs=y_obs, y_pred=y_pred)
    # bias_mean  bias_count  bias_weights  bias_stderr  p_value
    # f64        u32         f64           f64          f64
    # 0.5        10          10.0          0.477261     0.322121
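    The output above is consistent with a one-sample t-test on the residuals y_pred − y_obs: the bias is their mean, the standard error is their sample standard deviation divided by √n, and the p-value would come from a two-sided t-test with n − 1 = 9 degrees of freedom. Assuming that is what compute_bias() does (we have not checked its source), the first two numbers can be reproduced with plain NumPy:

```python
import numpy as np

y_obs = np.arange(10, dtype=float)
y_pred = np.array([2, 1, 3, 3, 6, 8, 5, 5, 8, 9.0])

residual = y_pred - y_obs  # positive values mean over-prediction
n = residual.size
bias_mean = residual.mean()
bias_stderr = residual.std(ddof=1) / np.sqrt(n)
t_stat = bias_mean / bias_stderr  # compare to a t distribution with n - 1 df

print(bias_mean, round(bias_stderr, 6), round(t_stat, 4))
```

    With SciPy available, `2 * scipy.stats.t.sf(abs(t_stat), df=n - 1)` should give a two-sided p-value close to the 0.322 shown above.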
    from model_diagnostics.scoring import SquaredError, decompose
    
    
    decompose(
        y_obs=y_obs,
        y_pred=y_pred,
        scoring_function=SquaredError(),
    )
    # miscalibration  discrimination  uncertainty  score
    # f64             f64             f64          f64
    # 1.283333        7.233333        8.25         2.3

    This score decomposition is additive (and unique):

    \begin{equation*}
    \mathrm{score} = \mathrm{miscalibration} - \mathrm{discrimination} + \mathrm{uncertainty}
    \end{equation*}
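    For the squared error, the three terms can be computed by hand once recalibrated predictions are available. The sketch below assumes (our assumption, not a statement about the package internals) that recalibration is done by isotonic regression of y_obs on y_pred, with tied predictions pooled; `isotonic_fit()` and `decompose_squared_error()` are hypothetical helper names. Miscalibration is the score improvement from recalibrating, discrimination is how much the recalibrated forecast beats the constant mean forecast, and uncertainty is the score of that constant forecast. On the toy data above, this reproduces the printed values.

```python
import numpy as np


def isotonic_fit(x, y):
    """Least-squares non-decreasing fit of y on x (pool adjacent violators).

    Observations with tied x are pooled first, so they share one fitted value.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    _, inv = np.unique(x, return_inverse=True)
    inv = inv.ravel()
    weights = np.bincount(inv).astype(float)
    means = np.bincount(inv, weights=y) / weights
    # Stack of blocks: [mean, weight, number of unique x values covered]
    blocks = []
    for m, w in zip(means, weights):
        blocks.append([m, w, 1])
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            m2, w2, c2 = blocks.pop()
            m1, w1, c1 = blocks.pop()
            blocks.append([(m1 * w1 + m2 * w2) / (w1 + w2), w1 + w2, c1 + c2])
    fitted_unique = np.repeat([b[0] for b in blocks], [b[2] for b in blocks])
    return fitted_unique[inv]


def decompose_squared_error(y_obs, y_pred):
    """score = miscalibration - discrimination + uncertainty (squared error)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    def mse(pred):
        return np.mean((y_obs - pred) ** 2)

    score = mse(y_pred)
    score_recalibrated = mse(isotonic_fit(y_pred, y_obs))
    uncertainty = mse(np.full_like(y_obs, y_obs.mean()))  # constant forecast
    return {
        "miscalibration": score - score_recalibrated,
        "discrimination": uncertainty - score_recalibrated,
        "uncertainty": uncertainty,
        "score": score,
    }


y_obs = list(range(10))
y_pred = [2, 1, 3, 3, 6, 8, 5, 5, 8, 9.0]
print(decompose_squared_error(y_obs, y_pred))
```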

    As usual, the code snippets are collected in a notebook: https://github.com/lorentzenchr/notebooks/blob/master/blogposts/2023-07-16%20model-diagnostics.ipynb.

  • Geographic SHAP

    Lost in Translation between R and Python 10

    This is the next article in our series “Lost in Translation between R and Python”. The aim of this series is to provide high-quality R and Python code to achieve some non-trivial tasks. If you want to learn R, check out the R tab below. Similarly, if you want to learn Python, the Python tab will be your friend.

    This post is heavily based on the new {shapviz} vignette.

    Setting

    Besides other features, a model with geographic components contains features like

    • latitude and longitude,
    • postal code, and/or
    • other features that depend on location, e.g., distance to the nearest restaurant.

    Like any feature, the effect of a single geographic feature can be described using SHAP dependence plots. However, studying the effect of latitude (or any other location dependent feature) alone is often not very illuminating – simply due to strong interaction effects and correlations with other geographic features.

    That’s where the additivity of SHAP values comes into play: The sum of SHAP values of all geographic components represents the total geographic effect, and this sum can be visualized as a heatmap or 3D scatterplot against latitude/longitude (or any other geographic representation).

    A first example

    For illustration, we will use a beautiful house price dataset containing information on about 14’000 houses sold in 2016 in Miami-Dade County. Some of the columns are as follows:

    • SALE_PRC: Sale price in USD: Its logarithm will be our model response.
    • LATITUDE, LONGITUDE: Coordinates
    • CNTR_DIST: Distance to central business district
    • OCEAN_DIST: Distance (ft) to the ocean
    • RAIL_DIST: Distance (ft) to the next railway track
    • HWY_DIST: Distance (ft) to next highway
    • TOT_LVG_AREA: Living area in square feet
    • LND_SQFOOT: Land area in square feet
    • structure_quality: Measure of building quality (1: worst to 5: best)
    • age: Age of the building in years

    (LATITUDE, LONGITUDE, CNTR_DIST, OCEAN_DIST, RAIL_DIST, and HWY_DIST are the geographic components.) For more background on this dataset, see Mayer et al. [2].

    We will fit an XGBoost model to explain log(price) as a function of lat/long, size, and quality/age.

    devtools::install_github("ModelOriented/shapviz", dependencies = TRUE)
    library(xgboost)
    library(ggplot2)
    library(shapviz)  # Needs development version 0.9.0 from github
    
    head(miami)
    
    x_coord <- c("LATITUDE", "LONGITUDE")
    x_nongeo <- c("TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age")
    x <- c(x_coord, x_nongeo)
    
    # Train/valid split
    set.seed(1)
    ix <- sample(nrow(miami), 0.8 * nrow(miami))
    X_train <- data.matrix(miami[ix, x])
    X_valid <- data.matrix(miami[-ix, x])
    y_train <- log(miami$SALE_PRC[ix])
    y_valid <- log(miami$SALE_PRC[-ix])
    
    # Fit XGBoost model with early stopping
    dtrain <- xgb.DMatrix(X_train, label = y_train)
    dvalid <- xgb.DMatrix(X_valid, label = y_valid)
    
    params <- list(learning_rate = 0.2, objective = "reg:squarederror", max_depth = 5)
    
    fit <- xgb.train(
      params = params, 
      data = dtrain, 
      watchlist = list(valid = dvalid), 
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )
    
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.datasets import fetch_openml
    
    df = fetch_openml(data_id=43093, as_frame=True)
    X, y = df.data, np.log(df.target)
    X.head()
    
    # Data split and model
    from sklearn.model_selection import train_test_split
    import xgboost as xgb
    
    x_coord = ["LONGITUDE", "LATITUDE"]
    x_nongeo = ["TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age"]
    x = x_coord + x_nongeo
    
    X_train, X_valid, y_train, y_valid = train_test_split(
        X[x], y, test_size=0.2, random_state=30
    )
    
    # Fit XGBoost model with early stopping
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_valid, label=y_valid)
    
    params = dict(learning_rate=0.2, objective="reg:squarederror", max_depth=5)
    
    fit = xgb.train(
        params=params,
        dtrain=dtrain,
        evals=[(dvalid, "valid")],
        verbose_eval=100,
        early_stopping_rounds=20,
        num_boost_round=1000,
    )
    

    SHAP dependence plots

    Let’s first study selected SHAP dependence plots, evaluated on the validation dataset with around 2800 observations. Note that we could also use the training data for this purpose, but it is a bit large.

    sv <- shapviz(fit, X_pred = X_valid)
    sv_dependence(
      sv, 
      v = c("TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"), 
      alpha = 0.2
    )
    import shap
    
    xgb_explainer = shap.Explainer(fit)
    shap_values = xgb_explainer(X_valid)
    
    v = ["TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"]
    shap.plots.scatter(shap_values[:, v], color=shap_values[:, v])
    SHAP dependence plots of selected features (Python output).

    Total coordinate effect

    And now the two-dimensional plot of the sum of SHAP values:

    sv_dependence2D(sv, x = "LONGITUDE", y = "LATITUDE") +
      coord_equal()
    shap_coord = shap_values[:, x_coord]
    plt.scatter(*list(shap_coord.data.T), c=shap_coord.values.sum(axis=1), s=4)
    ax = plt.gca()
    ax.set_aspect("equal", adjustable="box")
    plt.colorbar()
    plt.title("Total location effect")
    plt.show()
    Sum of SHAP values on color scale against coordinates (Python output).

    The last plot gives a good impression of price levels, but note:

    1. Since we have modeled logarithmic prices, the effects are on relative scale (0.1 means about 10% above average).
    2. Due to interaction effects with non-geographic components, the location effects might depend on features like living area. This is not visible in the plot above. We will now modify the model to improve this aspect.

    Two modifications

    We will now change the above model in two ways, not unlike the model in Mayer et al. [2].

    1. We will use additional geographic features like distance to railway track or to the ocean.
    2. We will use interaction constraints to allow only interactions between geographic features.

    The second step leads to a model that is additive in each non-geographic component and also additive in the combined location effect. According to the technical report of Mayer [1], SHAP dependence plots of additive components in a boosted trees model are shifted versions of corresponding partial dependence plots (evaluated at observed values). This allows a “Ceteris Paribus” interpretation of SHAP dependence plots of corresponding components.
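    The "shifted PDP" statement can be checked numerically on a toy example. Note the assumptions: instead of TreeSHAP on boosted trees (the setting of the report), the sketch uses exact interventional Shapley values, computed by brute force over all coalitions with the data as background, on a hand-written additive function. For such a model, the SHAP value of a feature differs from its partial dependence by the same constant for every observation, namely minus the average prediction. All function names are hypothetical.

```python
import itertools
import math

import numpy as np


def coalition_value(f, x, S, background):
    """Interventional value v(S): features in S are fixed at x, all other
    features come from the background rows; predictions are averaged."""
    Z = background.copy()
    idx = list(S)
    if idx:
        Z[:, idx] = x[idx]
    return f(Z).mean()


def exact_shapley(f, x, background):
    """Exact Shapley values of one observation x (enumerates all coalitions)."""
    p = len(x)
    phi = np.zeros(p)
    for j in range(p):
        others = [k for k in range(p) if k != j]
        for size in range(p):
            weight = math.factorial(size) * math.factorial(p - size - 1) / math.factorial(p)
            for S in itertools.combinations(others, size):
                gain = coalition_value(f, x, S + (j,), background) - coalition_value(f, x, S, background)
                phi[j] += weight * gain
    return phi


def partial_dependence(f, X, j, value):
    """One-dimensional partial dependence of f on feature j at `value`."""
    Z = X.copy()
    Z[:, j] = value
    return f(Z).mean()


def f(Z):
    """Additive toy model: no interactions between the three features."""
    return np.sin(Z[:, 0]) + Z[:, 1] ** 2 + 0.5 * Z[:, 2]


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))

# SHAP value minus PD value of feature 1, for the first 20 observations
j = 1
shift = [exact_shapley(f, x, X)[j] - partial_dependence(f, X, j, x[j]) for x in X[:20]]

print(max(shift) - min(shift))  # ~0: the same shift for every observation
print(shift[0], -f(X).mean())   # the shift equals minus the average prediction
```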

    # Extend the feature set
    more_geo <- c("CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST")
    x2 <- c(x, more_geo)
    
    X_train2 <- data.matrix(miami[ix, x2])
    X_valid2 <- data.matrix(miami[-ix, x2])
    
    dtrain2 <- xgb.DMatrix(X_train2, label = y_train)
    dvalid2 <- xgb.DMatrix(X_valid2, label = y_valid)
    
    # Build interaction constraint vector
    ic <- c(
      list(which(x2 %in% c(x_coord, more_geo)) - 1),
      as.list(which(x2 %in% x_nongeo) - 1)
    )
    
    # Modify parameters
    params$interaction_constraints <- ic
    
    fit2 <- xgb.train(
      params = params, 
      data = dtrain2, 
      watchlist = list(valid = dvalid2), 
      early_stopping_rounds = 20,
      nrounds = 1000,
      callbacks = list(cb.print.evaluation(period = 100))
    )
    
    # SHAP analysis
    sv2 <- shapviz(fit2, X_pred = X_valid2)
    
    # Two selected features: Thanks to additivity, structure_quality can be read as 
    # Ceteris Paribus
    sv_dependence(sv2, v = c("structure_quality", "LONGITUDE"), alpha = 0.2)
    
    # Total geographic effect (Ceteris Paribus thanks to additivity)
    sv_dependence2D(sv2, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
      coord_equal()
    # Extend the feature set
    more_geo = ["CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST"]
    x2 = x + more_geo
    
    X_train2, X_valid2 = train_test_split(X[x2], test_size=0.2, random_state=30)
    
    dtrain2 = xgb.DMatrix(X_train2, label=y_train)
    dvalid2 = xgb.DMatrix(X_valid2, label=y_valid)
    
    # Build interaction constraint vector
    ic = [x_coord + more_geo, *[[z] for z in x_nongeo]]
    
    # Modify parameters
    params["interaction_constraints"] = ic
    
    fit2 = xgb.train(
        params=params,
        dtrain=dtrain2,
        evals=[(dvalid2, "valid")],
        verbose_eval=100,
        early_stopping_rounds=20,
        num_boost_round=1000,
    )
    
    # SHAP analysis
    xgb_explainer2 = shap.Explainer(fit2)
    shap_values2 = xgb_explainer2(X_valid2)
    
    v = ["structure_quality", "LONGITUDE"]
    shap.plots.scatter(shap_values2[:, v], color=shap_values2[:, v])
    
    # Total location effect
    shap_coord2 = shap_values2[:, x_coord]
    c = shap_values2[:, x_coord + more_geo].values.sum(axis=1)
    plt.scatter(*list(shap_coord2.data.T), c=c, s=4)
    ax = plt.gca()
    ax.set_aspect("equal", adjustable="box")
    plt.colorbar()
    plt.title("Total location effect")
    plt.show()
    SHAP dependence plots of an additive feature (structure quality, no vertical scatter per unique feature value) and one of the geographic features (Python output).
    Sum of all geographic features (color) against coordinates. There are no interactions to non-geographic features, so the effect can be read Ceteris Paribus (Python output).

    Again, the resulting total geographic effect looks reasonable.

    Wrap-Up

    • SHAP values of all geographic components in a model can be summed up and plotted on the color scale against coordinates (or some other geographic representation). This gives a lightning fast impression of the location effects.
    • Interaction constraints between geographic and non-geographic features lead to Ceteris Paribus interpretation of total geographic effects.

    The Python and R notebooks can be found here:

    References

    1. Mayer, Michael. 2022. “SHAP for Additively Modeled Features in a Boosted Trees Model.” https://arxiv.org/abs/2207.14490.
    2. Mayer, Michael, Steven C. Bourassa, Martin Hoesli, and Donato Flavio Scognamiglio. 2022. “Machine Learning Applications to Land and Structure Valuation.” Journal of Risk and Financial Management.

  • SHAP + XGBoost + Tidymodels = LOVE

    In this recent post, we have explained how to use Kernel SHAP for interpreting complex linear models. As plotting backend, we used our fresh CRAN package “shapviz”.

    “shapviz” has direct connectors to a couple of packages such as XGBoost, LightGBM, H2O, kernelshap, and more. Multiple times, people asked me how to use shapviz when the XGBoost model was fitted with Tidymodels. The workflow was not 100% clear to me either, but the answer is actually very simple, thanks to Julia’s post where the plots were made with SHAPforxgboost, another cool package for visualizing SHAP values.

    Example with shiny diamonds

    Step 1: Preprocessing

    We first write the data preprocessing recipe and apply it to the data rows that we want to explain. In our case, these are 1000 randomly sampled diamonds.

    library(tidyverse)
    library(tidymodels)
    library(shapviz)
    
    # Integer encode factors
    dia_recipe <- diamonds %>%
      recipe(price ~ carat + cut + clarity + color) %>% 
      step_integer(all_nominal())
    
    # Will explain THIS dataset later
    set.seed(2)
    dia_small <- diamonds[sample(nrow(diamonds), 1000), ]
    dia_small_prep <- bake(
      prep(dia_recipe), 
      has_role("predictor"),
      new_data = dia_small, 
      composition = "matrix"
    )
    head(dia_small_prep)
    
    #     carat cut clarity color
    #[1,]  0.57   5       4     4
    #[2,]  1.01   5       2     1
    #[3,]  0.45   1       4     3
    #[4,]  1.04   4       6     5
    #[5,]  0.90   3       6     4
    #[6,]  1.20   3       4     6
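    Integer encoding replaces each factor level by its position in the level order, which is also why the original labels can be recovered later on. A minimal Python sketch of the idea, assuming 1-based codes in level order (the level order of cut is the one of the diamonds data; integer_encode() and decode() are hypothetical helpers, not recipes functions):

```python
# Level order of "cut" in the diamonds data (worst to best)
CUT_LEVELS = ["Fair", "Good", "Very Good", "Premium", "Ideal"]

def integer_encode(values, levels):
    """Map labels to 1-based integer codes following the order in `levels`."""
    code = {level: i + 1 for i, level in enumerate(levels)}
    return [code[v] for v in values]

def decode(codes, levels):
    """Map 1-based integer codes back to their labels."""
    return [levels[c - 1] for c in codes]

cuts = ["Ideal", "Fair", "Premium", "Ideal"]
codes = integer_encode(cuts, CUT_LEVELS)
print(codes)                      # [5, 1, 4, 5]
print(decode(codes, CUT_LEVELS))  # ['Ideal', 'Fair', 'Premium', 'Ideal']
```

    Keeping the level list around is what makes readable labels possible in the plots, in the same spirit as passing the original data next to the baked matrix.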
    

    Step 2: Fit Model

    The next step is to tune and build the model. For simplicity, we skipped the tuning part. Bad, bad 🙂

    # Just for illustration - in practice needs tuning!
    xgboost_model <- boost_tree(
      mode = "regression",
      trees = 200,
      tree_depth = 5,
      learn_rate = 0.05,
      engine = "xgboost"
    )
    
    dia_wf <- workflow() %>%
      add_recipe(dia_recipe) %>%
      add_model(xgboost_model)
    
    fit <- dia_wf %>%
      fit(diamonds)

    Step 3: SHAP Analysis

    We now need to call shapviz() on the fitted model. In order to have neat interpretations with the original factor labels, we not only pass the prediction data prepared in Step 1 via bake(), but also the original data structure.

    shap <- shapviz(extract_fit_engine(fit), X_pred = dia_small_prep, X = dia_small)
    
    sv_importance(shap, kind = "both", show_numbers = TRUE)
    sv_dependence(shap, "carat", color_var = "auto")
    sv_dependence(shap, "clarity", color_var = "auto")
    sv_force(shap, row_id = 1)
    sv_waterfall(shap, row_id = 1)
    
    Variable importance plot overlaid with SHAP summary beeswarms
    Dependence plot for carat. Note that clarity is shown with original labels, not only integers.
    Dependence plot for clarity. Note again that the x-scale uses the original factor levels, not the integer encoded values.
    Force plot of the first observation
    Waterfall plot for the first observation

    Summary

    Making SHAP analyses of XGBoost models fitted with Tidymodels is super easy.

    The complete R script can be found here.

  • Dplyr-style without dplyr

    One of the reasons why we love the “dplyr” package: it plays so well together with the forward pipe operator `%>%` from the “magrittr” package. Actually, it is not a coincidence that both packages were released at about the same time, in 2014.

    What does the pipe do? It puts the object on its left as the first argument into the function on its right: iris %>% head() is a funny way of writing head(iris). It helps to avoid long function chains like f(g(h(x))), or repeated assignments.
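    In Python terms, the pipe is a left-to-right fold over functions. A minimal sketch, where pipe() is a hypothetical helper (not part of the standard library) that always inserts the running value as the first argument:

```python
from functools import reduce

def pipe(value, *steps):
    """Pass `value` through `steps` from left to right.

    A step is either a function or a tuple (function, *extra_args). The running
    value always becomes the FIRST argument, like `x %>% f(y)` meaning `f(x, y)`.
    """
    def _apply_step(acc, step):
        if isinstance(step, tuple):
            func, *args = step
            return func(acc, *args)
        return step(acc)
    return reduce(_apply_step, steps, value)

def h(x):
    return x + 1

def g(x):
    return x * 2

def f(x):
    return x - 3

print(pipe(5, h, g, f))           # 9, same as f(g(h(5)))
print(pipe(3.14159, (round, 2)))  # 3.14, same as round(3.14159, 2)
```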

    In 2021, with version 4.1, R received its native forward pipe operator |> so that we can write nice code like this:

    Imagine this without pipe…

    Since version 4.2, the piped object can be referenced by the placeholder _, but only once and only as a named argument; see an example below.

    To use the native pipe via CTRL-SHIFT-M in Posit/RStudio, tick this:

    Combined with the many great functions from the standard distribution of R, we can get a real “dplyr” feeling without even loading dplyr. Don’t get me wrong: I am a huge fan of the whole Tidyverse! But this style is a great way to learn “Standard R”.

    Data chains

    Here is a small selection of standard functions that play well with the pipe. They take a data frame and return a modified data frame:

    • subset(): Select rows and columns of data frame
    • transform(): Add or overwrite columns in data frame
    • aggregate(): Grouped calculations
    • rbind(), cbind(): Bind rows/columns of data frame/matrix
    • merge(): Join data frames by key
    • head(), tail(): First/last few elements of object
    • reshape(): Transposition/Reshaping of data frame (no, I don’t understand the interface)
    library(ggplot2)  # Need diamonds
    
    # What does the native pipe do?
    quote(diamonds |> head())
    
    # OUTPUT
    # head(diamonds)
    
    # Grouped statistics
    diamonds |> 
      aggregate(cbind(price, carat) ~ color, FUN = mean)
    
    # OUTPUT
    #   color    price     carat
    # 1     D 3169.954 0.6577948
    # 2     E 3076.752 0.6578667
    # 3     F 3724.886 0.7365385
    # 4     G 3999.136 0.7711902
    # 5     H 4486.669 0.9117991
    # 6     I 5091.875 1.0269273
    # 7     J 5323.818 1.1621368
    
    # Join back grouped stats to relevant columns
    diamonds |> 
      subset(select = c(price, color, carat)) |> 
      transform(price_per_color = ave(price, color)) |> 
      head()
    
    # OUTPUT
    #   price color carat price_per_color
    # 1   326     E  0.23        3076.752
    # 2   326     E  0.21        3076.752
    # 3   327     E  0.23        3076.752
    # 4   334     I  0.29        5091.875
    # 5   335     J  0.31        5323.818
    # 6   336     J  0.24        5323.818
    
    # Plot transformed values
    diamonds |> 
      transform(
        log_price = log(price),
        log_carat = log(carat)
      ) |> 
      plot(log_price ~ log_carat, col = "chartreuse4", pch = ".", data = _)
    A simple scatterplot

    The plot does not look quite as sexy as with “ggplot2”, but it’s a start.
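    The ave() step in the join example returns, for every row, the mean of its group, which is why its result can be attached directly as a new column. A minimal Python sketch of the same logic, using only the six rows printed above (so the group means differ from the full-data means in that output):

```python
from collections import defaultdict

def ave(values, groups):
    """For every element, return the mean of `values` within its group."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for value, group in zip(values, groups):
        sums[group] += value
        counts[group] += 1
    return [sums[group] / counts[group] for group in groups]

price = [326, 326, 327, 334, 335, 336]
color = ["E", "E", "E", "I", "J", "J"]
price_per_color = ave(price, color)
print(price_per_color)
```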

    Other chains

    The pipe not only works perfectly with functions that modify a data frame. It also shines with many other functions that are often applied in a nested way. Here are two examples:

    # Distribution of color within clarity
    diamonds |> 
      subset(select = c(color, clarity)) |> 
      table() |> 
      prop.table(margin = 2) |> 
      addmargins(margin = 1) |> 
      round(3)
    
    # OUTPUT
    # clarity
    # color      I1   SI2   SI1   VS2   VS1  VVS2  VVS1    IF
    #     D   0.057 0.149 0.159 0.138 0.086 0.109 0.069 0.041
    #     E   0.138 0.186 0.186 0.202 0.157 0.196 0.179 0.088
    #     F   0.193 0.175 0.163 0.180 0.167 0.192 0.201 0.215
    #     G   0.202 0.168 0.151 0.191 0.263 0.285 0.273 0.380
    #     H   0.219 0.170 0.174 0.134 0.143 0.120 0.160 0.167
    #     I   0.124 0.099 0.109 0.095 0.118 0.072 0.097 0.080
    #     J   0.067 0.052 0.057 0.060 0.066 0.026 0.020 0.028
    #     Sum 1.000 1.000 1.000 1.000 1.000 1.000 1.000 1.000
    
    # Barplot from discrete column
    diamonds$color |> 
      table() |> 
      prop.table() |> 
      barplot(col = "chartreuse4", main = "Color")
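    The first chain divides each cell by its column total and then appends a row of column sums, which is why every entry of the Sum row equals 1. A minimal Python sketch on a hypothetical small count table (col_proportions() and add_sum_row() are illustrative helpers mimicking prop.table(margin = 2) and addmargins(margin = 1)):

```python
def col_proportions(table):
    """Divide every cell by its column total (like prop.table(margin = 2))."""
    n_cols = len(table[0])
    col_sums = [sum(row[j] for row in table) for j in range(n_cols)]
    return [[row[j] / col_sums[j] for j in range(n_cols)] for row in table]

def add_sum_row(table):
    """Append a row with the column sums (like addmargins(margin = 1))."""
    n_cols = len(table[0])
    return table + [[sum(row[j] for row in table) for j in range(n_cols)]]

# Hypothetical counts: rows = colors, columns = clarity grades
counts = [
    [2, 6],
    [3, 2],
    [5, 2],
]
props = add_sum_row(col_proportions(counts))
for row in props:
    print([round(p, 3) for p in row])
```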

    Wrap up

    • Piping is fun with and without dplyr.
    • It is a great motivation to learn standard R.

    The complete R script can be found here.

  • Interpret Complex Linear Models with SHAP within Seconds

    A linear model with complex interaction effects can be almost as opaque as a typical black-box like XGBoost.

    XGBoost models are often interpreted with SHAP (SHapley Additive exPlanations): each of, e.g., 1000 randomly selected predictions is fairly decomposed into contributions of the features using the extremely fast TreeSHAP algorithm, providing a rich interpretation of the model as a whole. TreeSHAP was introduced in the Nature Machine Intelligence publication by Lundberg et al. (2020).

    Can we do the same for non-tree-based models like a complex GLM or a neural network? Yes, but we have to resort to slower model-agnostic SHAP algorithms, namely Monte-Carlo sampling of SHAP values or Kernel SHAP.

    In the limit, the two algorithms provide the same SHAP values.
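    What both algorithms converge to can be computed by brute force when there are only a few features: the exact Shapley values, where the value of a feature coalition S is estimated by averaging predictions over background rows in which the features of S are set to the values of the observation to explain. A minimal Python sketch under this (interventional) assumption; exact_shap() is a hypothetical helper, and its cost grows like 2^p times the background size:

```python
from itertools import combinations
from math import factorial

def exact_shap(predict, x, bg):
    """Exact interventional SHAP values of one observation `x`.

    predict: function mapping a list of p feature values to a number
    x:       list of p feature values to explain
    bg:      list of background rows (each a list of p values)
    """
    p = len(x)

    def value(coalition):
        # Average prediction when features in `coalition` come from x,
        # all other features from the background rows
        preds = [
            predict([x[k] if k in coalition else row[k] for k in range(p)])
            for row in bg
        ]
        return sum(preds) / len(preds)

    phi = [0.0] * p
    for j in range(p):
        others = [k for k in range(p) if k != j]
        for size in range(p):
            # Shapley weight |S|! (p - |S| - 1)! / p!
            weight = factorial(size) * factorial(p - size - 1) / factorial(p)
            for subset in combinations(others, size):
                s = set(subset)
                phi[j] += weight * (value(s | {j}) - value(s))
    return phi

# Linear model: SHAP value j equals beta_j * (x_j - background mean of x_j)
def linear_model(z):
    return 1 + 2 * z[0] - z[1] + 0.5 * z[2]

phi_lin = exact_shap(linear_model, [1, 2, 3], [[0, 0, 0], [2, 1, 1]])

# Pure interaction: the effect is shared between the two features
def interaction_model(z):
    return z[0] * z[1]

phi_int = exact_shap(interaction_model, [2, 3], [[0, 0], [1, 1]])
print(phi_lin, phi_int)
```

    In both cases, the SHAP values add up to the prediction minus the average prediction on the background data.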

    House prices

    We will use a great dataset with the prices of 14’000 houses sold in Miami in 2016. The dataset was kindly provided by Prof. Steven Bourassa for research purposes and can be found on OpenML.

    The model

    We will model house prices by a Gamma regression with log-link. The model includes factors, linear components and natural cubic splines. The relationship of living area and distance to central district is modeled by letting the spline bases of the two features interact.

    library(OpenML)
    library(tidyverse)
    library(splines)
    library(doFuture)
    library(kernelshap)
    library(shapviz)
    
    raw <- OpenML::getOMLDataSet(43093)$data
    
    # Lump rare level 3 and log transform the land size
    prep <- raw %>%
      mutate(
        structure_quality = factor(structure_quality, labels = c(1, 2, 4, 4, 5)),
        log_landsize = log(LND_SQFOOT)
      )
    
    # 1) Build model
    xvars <- c("TOT_LVG_AREA", "log_landsize", "structure_quality",
               "CNTR_DIST", "age", "month_sold")
    
    fit <- glm(
      SALE_PRC ~ ns(log(CNTR_DIST), df = 4) * ns(log(TOT_LVG_AREA), df = 4) +
        log_landsize + structure_quality + ns(age, df = 4) + ns(month_sold, df = 4),
      family = Gamma("log"),
      data = prep
    )
    summary(fit)
    
    # Selected coefficients:
    # log_landsize: 0.22559  
    # structure_quality4: 0.63517305 
    # structure_quality5: 0.85360956   
    

    The model has 37 parameters. Some of the estimates are shown.
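    The count of 37 can be reproduced term by term, assuming default treatment (dummy) coding and that all five original quality levels occur in the data, so that four levels remain after lumping:

```python
# ns(..., df = 4) contributes 4 columns; the product of two such bases adds
# 4 + 4 main-effect columns plus 4 * 4 interaction columns.
terms = {
    "(Intercept)": 1,
    "ns(log(CNTR_DIST)) * ns(log(TOT_LVG_AREA))": 4 + 4 + 4 * 4,
    "log_landsize": 1,
    "structure_quality (4 levels after lumping)": 4 - 1,
    "ns(age)": 4,
    "ns(month_sold)": 4,
}
n_params = sum(terms.values())
print(n_params)  # 37
```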

    Interpretation

    The workflow of a SHAP analysis is as follows:

    1. Sample 1000 rows to explain
    2. Sample 100 rows as background data used to estimate marginal expectations
    3. Calculate SHAP values. This can be done fully in parallel by looping over the rows selected in Step 1
    4. Analyze the SHAP values

    Step 2 is the only additional step compared with TreeSHAP. It is required both for SHAP sampling values and Kernel SHAP.

    # 1) Select rows to explain
    set.seed(1)
    X <- prep[sample(nrow(prep), 1000), xvars]
    
    # 2) Select small representative background data
    bg_X <- prep[sample(nrow(prep), 100), ]
    
    # 3) Calculate SHAP values in fully parallel mode
    registerDoFuture()
    plan(multisession, workers = 6)  # Windows
    # plan(multicore, workers = 6)   # Linux, macOS, Solaris
    
    system.time( # <10 seconds
      shap_values <- kernelshap(
        fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "splines")
      )
    )

    Thanks to parallel processing and some implementation tricks, we were able to decompose 1000 predictions within 10 seconds! By default, kernelshap() uses exact calculations for up to eight features (exact with respect to the background data); Monte-Carlo sampling would need an infinite number of steps to achieve the same.

    Note that glm() has a very efficient predict() function. GAMs, neural networks, random forests, etc. usually take more time, e.g., 5 minutes for the same crunching.

    Analyze the SHAP values

    # 4) Analyze them
    sv <- shapviz(shap_values)
    
    sv_importance(sv, show_numbers = TRUE) +
      ggtitle("SHAP Feature Importance")
    
    sv_dependence(sv, "log_landsize")
    sv_dependence(sv, "structure_quality")
    sv_dependence(sv, "age")
    sv_dependence(sv, "month_sold")
    sv_dependence(sv, "TOT_LVG_AREA", color_var = "auto")
    sv_dependence(sv, "CNTR_DIST", color_var = "auto")
    
    # Slope of log_landsize: 0.2255946
    diff(sv$S[1:2, "log_landsize"]) / diff(sv$X[1:2, "log_landsize"])
    
    # Difference between structure quality 4 and 5: 0.2184365
    diff(sv$S[2:3, "structure_quality"])
    SHAP Importance: Living area and the distance to the central district are the two most important predictors. The month (within 2016) impacts the predicted prices by ±1.3% on average.
    SHAP dependence plot of “log_landsize”. The effect is linear. The slope 0.22559 agrees with the model coefficient.
    Dependence plot for “structure_quality”: The difference between structure quality 4 and 5 is 0.2184365. This equals the difference in regression coefficients.
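    Dependence plot of “TOT_LVG_AREA” (living area): The effect is very steep. The more central the house, the steeper the effect. We cannot easily compare these numbers with the output of the linear regression, as living area enters the model through an interaction of spline bases.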
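    The agreement for log_landsize and structure_quality is expected if, as the matching numbers suggest, the SHAP values live on the link (log) scale: there, both features enter the model additively, and the SHAP value of an additive term is its value minus its average over the background data. A minimal Python sketch of that argument; only the coefficients are taken from the model summary, the background values and rows are hypothetical:

```python
def additive_shap(component, x_value, bg_values):
    """SHAP value of an additive model term: its value minus its background mean."""
    return component(x_value) - sum(component(b) for b in bg_values) / len(bg_values)

# Linear term beta * log_landsize (coefficient from the summary above)
BETA_LAND = 0.22559

def land_term(v):
    return BETA_LAND * v

bg_land = [8.5, 9.0, 9.7]  # hypothetical background values
x1, x2 = 8.8, 9.4          # two hypothetical rows to explain
slope = (additive_shap(land_term, x2, bg_land)
         - additive_shap(land_term, x1, bg_land)) / (x2 - x1)

# Dummy-coded factor: the term is the coefficient of the row's level
# (reference level "1" has coefficient 0 under treatment coding)
QUALITY_COEF = {"1": 0.0, "4": 0.63517305, "5": 0.85360956}

def quality_term(level):
    return QUALITY_COEF[level]

bg_quality = ["4", "5", "4", "1"]  # hypothetical background levels
diff_4_5 = (additive_shap(quality_term, "5", bg_quality)
            - additive_shap(quality_term, "4", bg_quality))
print(round(slope, 5), round(diff_4_5, 7))  # 0.22559 0.2184365
```

    The background mean cancels in both differences, which is why the choice of background data does not affect the slope or the level difference of an additive term.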

    Summary

    • Interpreting complex linear models with SHAP is an option. There seems to be a correspondence between regression coefficients and SHAP dependence, at least for additive components.
    • Kernel SHAP in R is fast. For models with slower predict() functions (e.g. GAMs, random forests, or neural nets), we often need to wait a couple of minutes.

    The complete R script can be found here.

  • Histograms, Gradient Boosted Trees, Group-By Queries and One-Hot Encoding

    This post shows how filling histograms can be done in very different ways, thereby connecting very different areas: from gradient boosted trees to SQL queries to one-hot encoding. Let’s jump into it!

    Modern gradient boosted trees (GBT) like LightGBM, XGBoost and the HistGradientBoostingRegressor of scikit-learn all use two techniques on top of standard gradient boosting:

    • Second-order Taylor expansion of the loss, which amounts to using gradients and hessians.
    • One histogram per feature: bin the feature and fill the histogram with the gradients and hessians.

    The filling of the histograms is often the bottleneck when fitting GBTs. While filling a single histogram is very fast, this operation is executed many times: for each boosting round, for each tree split and for each feature. This is the reason why GBT implementations have dedicated routines for it. We look into this operation from different angles.

    For the coming (I)Python code snippets to work (# %% indicates a new notebook cell), we need the following imports.

    import duckdb                    # v0.5.1
    import matplotlib.pyplot as plt  # v3.6.1
    from matplotlib.ticker import MultipleLocator
    import numpy as np               # v1.23.4
    import pandas as pd              # v1.5.0
    import pyarrow as pa             # v9.0.0
    import tabmat                    # v3.1.2
    
    from sklearn.ensemble._hist_gradient_boosting.histogram import (
        _build_histogram_root,
    )                                # v1.1.2
    from sklearn.ensemble._hist_gradient_boosting.common import (
        HISTOGRAM_DTYPE,
    )

    Naive Histogram Visualisation

    As a starter, we create a small table with two columns: bin index and value of the hessian.

    def highlight(df):
        if df["bin"] == 0:
            return ["background-color: rgb(255, 128, 128)"] * len(df)
        elif df["bin"] == 1:
            return ["background-color: rgb(128, 255, 128)"] * len(df)
        else:
            return ['background-color: rgb(128, 128, 255)'] * len(df)
    
    df = pd.DataFrame({"bin": [0, 2, 1, 0, 1], "hessian": [1.5, 1, 2, 2.5, 3]})
    df.style.apply(highlight, axis=1)
      bin hessian
    0 0 1.500000
    1 2 1.000000
    2 1 2.000000
    3 0 2.500000
    4 1 3.000000

    A histogram then sums up all the hessian values belonging to the same bin. The result looks like the following.
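    In code, filling a histogram is a single loop that adds each value to the slot given by its bin index. A minimal pure-Python sketch on the five-row table above (fill_histogram() is an illustrative helper, not the scikit-learn routine):

```python
def fill_histogram(bins, values, n_bins):
    """Sum `values` per bin: hist[b] = sum of values[i] with bins[i] == b."""
    hist = [0.0] * n_bins
    for b, v in zip(bins, values):
        hist[b] += v
    return hist

bins = [0, 2, 1, 0, 1]
hessian = [1.5, 1, 2, 2.5, 3]
print(fill_histogram(bins, hessian, n_bins=3))  # [4.0, 5.0, 1.0]
```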

    Above table visualised as histogram

    Dedicated Method

    We simulate filling the histogram of a single feature. To this end, we draw 1,000,000 random values for the gradients, the hessians and the bin indices.

    import duckdb
    import pyarrow as pa
    import numpy as np
    import tabmat
    
    from sklearn.ensemble._hist_gradient_boosting.histogram import (
        _build_histogram_root,
    )
    from sklearn.ensemble._hist_gradient_boosting.common import HISTOGRAM_DTYPE
    
    
    rng = np.random.default_rng(42)
    n_obs = 1_000_000
    n_bins = 256
    binned_feature = rng.integers(0, n_bins, size=n_obs, dtype=np.uint8)
    gradients = rng.normal(size=n_obs).astype(np.float32)
    hessians = rng.lognormal(size=n_obs).astype(np.float32)

    Now we use the dedicated, private (!) and single-threaded function _build_histogram_root from scikit-learn to fill a histogram.

    hist_root = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE)
    %time _build_histogram_root(0, binned_feature, gradients, hessians, hist_root)
    # Wall time: 1.38 ms

    This executes in around 1.4 ms, which is quite fast. But imagine 100 boosting rounds with 10 tree splits on average and 100 features: the operation is then executed around 100,000 times, which would take roughly 2 minutes.
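    The back-of-the-envelope estimate, spelled out (the workload numbers are the illustrative ones from the paragraph, the timing is the measured wall time above):

```python
boosting_rounds = 100
splits_per_round = 10            # average number of tree splits per round
n_features = 100
seconds_per_histogram = 1.38e-3  # measured wall time of a single fill above

n_fills = boosting_rounds * splits_per_round * n_features
total_minutes = n_fills * seconds_per_histogram / 60
print(n_fills, round(total_minutes, 1))  # 100000 2.3
```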

    Let’s have a look at the first 5 bins:

    hist_root[:, 0:5]
    array([[(-79.72386998, 6508.89500265, 3894),
            ( 37.98393589, 6460.63222205, 3998),
            ( 53.54256977, 6492.22722797, 3805),
            ( 21.19542398, 6797.34159299, 3928),
            ( 16.24716742, 6327.03757573, 3875)]],
          dtype=[('sum_gradients', '<f8'), ('sum_hessians', '<f8'), ('count', '<u4')])

    SQL Group-By Query

    Someone familiar with SQL and database queries might immediately see how this task can be formulated as an SQL group-by-aggregate query. To demonstrate it on our simulated data, we use DuckDB as well as Apache Arrow (the file format as well as the Python library pyarrow). You can read more about DuckDB in our post DuckDB: Quacking SQL.

    # %%
    con = duckdb.connect()
    arrow_table = pa.Table.from_pydict(
        {
            "bin": binned_feature,
            "gradients": gradients,
            "hessians": hessians,
        }
    )
    # Read data once to make timing fairer
    arrow_result = con.execute("SELECT * FROM arrow_table")
    
    # %%
    %%time
    arrow_result = con.execute("""
    SELECT
        bin as bin,
        SUM(gradients) as sum_gradients,
        SUM(hessians) as sum_hessians,
        COUNT() as count
    FROM arrow_table
    GROUP BY bin
    """).arrow()
    # Wall time: 6.52 ms

    On my laptop, this takes about 6.5 ms and, upon sorting, gives the same results:

    arrow_result.sort_by("bin").slice(length=5)
    pyarrow.Table
    bin: uint8
    sum_gradients: double
    sum_hessians: double
    count: int64
    ----
    bin: [[0,1,2,3,4]]
    sum_gradients: [[-79.72386997545254,37.98393589106854,53.54256977112527,21.195423980039777,16.247167424764484]]
    sum_hessians: [[6508.895002648234,6460.632222048938,6492.227227974683,6797.341592986137,6327.037575732917]]
    count: [[3894,3998,3805,3928,3875]]

    As we have the table as an Arrow table, we can stay within pyarrow:

    %%time
    arrow_result = arrow_table.group_by("bin").aggregate(
        [
            ("gradients", "sum"),
            ("hessians", "sum"),
            ("bin", "count"),
        ]
    )
    # Wall time: 10.8 ms

    The fact that DuckDB is faster than Arrow on this task might have to do with the large effort invested in parallelised group-by operations; see their post Parallel Grouped Aggregation in DuckDB for more information.
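    Since the DuckDB query is plain SQL, the same group-by formulation also runs on Python's built-in sqlite3 module. This swaps in a different engine and says nothing about DuckDB's speed; it merely shows the formulation. A minimal sketch on the five-row toy table from the beginning of the post, where the hessians come from that table and the gradient values are hypothetical:

```python
import sqlite3

# (bin, gradient, hessian); hessians from the toy table, gradients hypothetical
rows = [(0, 0.5, 1.5), (2, -0.2, 1.0), (1, 0.1, 2.0), (0, -0.3, 2.5), (1, 0.4, 3.0)]

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE toy (bin INTEGER, gradients REAL, hessians REAL)")
con.executemany("INSERT INTO toy VALUES (?, ?, ?)", rows)
histogram = con.execute(
    """
    SELECT bin,
           SUM(gradients) AS sum_gradients,
           SUM(hessians) AS sum_hessians,
           COUNT(*) AS n
    FROM toy
    GROUP BY bin
    ORDER BY bin
    """
).fetchall()
con.close()

for row in histogram:
    print(row)  # (bin, sum_gradients, sum_hessians, count)
```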

    One-Hot encoded Matrix Multiplication

    I think it is very interesting that filling histograms can be written as a matrix multiplication! The trick is to view the feature as a categorical feature and use its one-hot encoded matrix representation. This blows up memory, of course. Note that one-hot encoding is usually encountered in generalized linear models (GLMs), where it incorporates nominal categorical features (with no internal ordering) into the design matrix.

    For our demonstration, we use a numpy index trick to construct the one-hot encoded matrix employing the fact that the binned feature already contains the right indices.

    # %%
    %%time
    m_OHE = np.eye(n_bins)[binned_feature].T
    vec = np.column_stack((gradients, hessians, np.ones_like(gradients)))
    # Wall time: 770 ms
    
    # %%
    %time result_ohe = m_OHE @ vec
    # Wall time: 199 ms
    
    # %%
    result_ohe[:5]
    array([[ -79.72386998, 6508.89500265, 3894.        ],
           [  37.98393589, 6460.63222205, 3998.        ],
           [  53.54256977, 6492.22722797, 3805.        ],
           [  21.19542398, 6797.34159299, 3928.        ],
           [  16.24716742, 6327.03757573, 3875.        ]])

    This is much slower but, perhaps surprisingly, produces the same result.
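    Why the product works: row b of the transposed one-hot matrix has a 1 exactly at the observations falling into bin b, so multiplying it with a column of values sums those values per bin, and a column of ones yields the counts. A pure-Python sketch on the five-row toy table (gradient values hypothetical; one_hot_transposed() and matmul() are illustrative helpers):

```python
def one_hot_transposed(bins, n_bins):
    """One-hot matrix of shape (n_bins, n_obs): entry [b][i] is 1 if bins[i] == b."""
    return [[1.0 if bin_i == b else 0.0 for bin_i in bins] for b in range(n_bins)]

def matmul(a, b):
    """Plain-Python matrix product a @ b (both given as lists of rows)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

bins = [0, 2, 1, 0, 1]
gradients = [0.5, -0.2, 0.1, -0.3, 0.4]  # hypothetical
hessians = [1.5, 1, 2, 2.5, 3]
vec = [[g, h, 1.0] for g, h in zip(gradients, hessians)]  # n_obs x 3

m_ohe = one_hot_transposed(bins, n_bins=3)
result = matmul(m_ohe, vec)  # n_bins x 3: sum_gradients, sum_hessians, count
for row in result:
    print(row)
```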

    The one-hot encoded matrix is very sparse, with only one non-zero value per column, i.e. only one out of 256 (number of bins) values is non-zero. This structure can be exploited to reduce both CPU time as well as memory consumption, with the help of the package tabmat that was built to accelerate GLMs. Unfortunately, tabmat only provides a matrix-vector multiplication (and the sandwich product, of course), but no matrix-matrix multiplication. So we have to do a little extra work.

    # %%
    %time m_categorical = tabmat.CategoricalMatrix(cat_vec=binned_feature)
    # Wall time: 21.5 ms
    
    # %%
    # tabmat needs contiguous arrays with dtype = Python float = float64
    vec = np.asfortranarray(vec, dtype=float)
    
    # %%
    %%time
    tabmat_result = np.column_stack(
        (
            vec[:, 0] @ m_categorical,
            vec[:, 1] @ m_categorical,
            vec[:, 2] @ m_categorical,
        )
    )
    # Wall time: 4.82 ms
    
    # %%
    tabmat_result[0:5]
    array([[ -79.72386998, 6508.89500265, 3894.        ],
           [  37.98393589, 6460.63222205, 3998.        ],
           [  53.54256977, 6492.22722797, 3805.        ],
           [  21.19542398, 6797.34159299, 3928.        ],
           [  16.24716742, 6327.03757573, 3875.        ]])

    While the timing of this approach is quite good, the construction of a CategoricalMatrix requires more time than the matrix-vector multiplication.

    Conclusion

    In the end, the special (Cython) routine of scikit-learn is the fastest of our tested methods for filling histograms. The other GBT libraries have their own, even more specialised routines, which might be one reason for even faster fit times. What we learned in this post is that this seemingly simple task plays a crucial part in modern GBTs and can be accomplished by very different approaches. These different approaches uncover connections between algorithms from quite different domains.

    The full code as ipython notebook can be found at https://github.com/lorentzenchr/notebooks/blob/master/blogposts/2022-10-31%20histogram-GBT-GroupBy-OHE.ipynb.

  • Kernel SHAP in R and Python

    Lost in Translation between R and Python 9

    This is the next article in our series “Lost in Translation between R and Python”. The aim of this series is to provide high-quality R and Python code to achieve some non-trivial tasks. If you want to learn R, check out the R code below. Similarly, if you want to learn Python, the Python code will be your friend.

    Kernel SHAP

    SHAP is one of the most widely used model interpretation techniques in Machine Learning. It decomposes predictions into additive contributions of the features in a fair way. For tree-based methods, the fast TreeSHAP algorithm exists. For general models, one has to resort to computationally expensive Monte-Carlo sampling or the faster Kernel SHAP algorithm. Kernel SHAP uses a regression trick to get the SHAP values of an observation with a comparably small number of calls to the predict function of the model. Still, it is much slower than TreeSHAP.
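    The regression trick is a weighted linear regression over binary on-off vectors z. Following Lundberg and Lee (2017), a vector with k of p features switched on gets the Shapley kernel weight (p - 1) / (C(p, k) k (p - k)). A minimal Python sketch of that weight only (not of the full regression); the empty and the full vector would get infinite weight and are treated as constraints instead:

```python
from math import comb

def shapley_kernel_weight(p, k):
    """Shapley kernel weight of one on-off vector with k of p features switched on."""
    if not 0 < k < p:
        raise ValueError("weight is only finite for 0 < k < p")
    return (p - 1) / (comb(p, k) * k * (p - k))

p = 4
weights = {k: shapley_kernel_weight(p, k) for k in range(1, p)}
print(weights)  # {1: 0.25, 2: 0.125, 3: 0.25}
```

    The weights are symmetric in k and p - k and largest for vectors with very few or very many features switched on.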

    Two good references for Kernel SHAP:

    1. Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30, 2017.
    2. Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.

    In our last post, we introduced our new “kernelshap” package in R. Since then, the package has been substantially improved, not least thanks to the big help of David Watson:

    1. The package now supports multi-dimensional predictions.
    2. It received a massive speed-up.
    3. Additionally, parallel computing can be activated for even faster calculations.
    4. The interface has become more intuitive.
    5. If the number of features is small (up to ten or eleven), it can provide exact Kernel SHAP values just like the reference Python implementation.
    6. For a larger number of features, it now uses partly-exact (“hybrid”) calculations, very similar to the logic in the Python implementation.

    With those changes, the R implementation is about to meet the Python version at eye level.

    Example with four features

    In the following, we use the diamonds data to fit a linear regression with

    • log(price) as response
    • log(carat) as numeric feature
    • clarity, color and cut as categorical features (internally dummy encoded)
    • interactions between log(carat) and the other three “C” variables. Note that the interactions are very weak.

    Then, we calculate SHAP decompositions for about 1000 diamonds (every 53rd diamond), using 120 diamonds as background dataset. In this case, both R and Python use exact calculations based on m = 2^4 - 2 = 14 possible binary on-off vectors (a value of 1 represents a feature value picked from the original observation, a value of 0 a value picked from the background data).
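    The count m can be checked by enumerating all on-off vectors except the all-zero and the all-one vector. A minimal Python sketch, which also gives the count for nine features:

```python
from itertools import product

def on_off_vectors(p):
    """All binary on-off vectors of length p, except all-zero and all-one."""
    return [z for z in product((0, 1), repeat=p) if 0 < sum(z) < p]

m4 = len(on_off_vectors(4))
m9 = len(on_off_vectors(9))
print(m4, m9, round(m9 / m4, 1))  # 14 510 36.4
```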

    library(ggplot2)
    library(kernelshap)
    
    # Turn ordinal factors into unordered
    ord <- c("clarity", "color", "cut")
    diamonds[, ord] <- lapply(diamonds[ord], factor, ordered = FALSE)
    
    # Fit model
    fit <- lm(log(price) ~ log(carat) * (clarity + color + cut), data = diamonds)
    
    # Subset of 120 diamonds used as background data
    bg_X <- diamonds[seq(1, nrow(diamonds), 450), ]
    
    # Subset of 1018 diamonds to explain
    X_small <- diamonds[seq(1, nrow(diamonds), 53), c("carat", ord)]
    
    # Exact KernelSHAP (5 seconds)
    system.time(
      ks <- kernelshap(fit, X_small, bg_X = bg_X)  
    )
    ks
    
    # SHAP values of first 2 observations:
    #          carat     clarity     color        cut
    # [1,] -2.050074 -0.28048747 0.1281222 0.01587382
    # [2,] -2.085838  0.04050415 0.1283010 0.03731644
    
    # Using parallel backend
    library("doFuture")
    
    registerDoFuture()
    plan(multisession, workers = 2)  # Windows
    # plan(multicore, workers = 2)   # Linux, macOS, Solaris
    
    # 3 seconds on second call
    system.time(
      ks3 <- kernelshap(fit, X_small, bg_X = bg_X, parallel = TRUE)  
    )
    
    # Visualization
    library(shapviz)
    
    sv <- shapviz(ks)
    sv_importance(sv, "bee")
    import numpy as np
    import pandas as pd
    from plotnine.data import diamonds
    from statsmodels.formula.api import ols
    from shap import KernelExplainer
    
    # Turn categoricals into integers because, inconveniently, Kernel SHAP
    # requires a numpy array as input
    # requires numpy array as input
    ord = ["clarity", "color", "cut"]
    x = ["carat"] + ord
    diamonds[ord] = diamonds[ord].apply(lambda x: x.cat.codes)
    X = diamonds[x].to_numpy()
    
    # Fit model with interactions and dummy variables
    fit = ols(
      "np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", 
      data=diamonds
    ).fit()
    
    # Background data (120 rows)
    bg_X = X[0:len(X):450]
    
    # Define subset of 1018 diamonds to explain
    X_small = X[0:len(X):53]
    
    # Calculate KernelSHAP values
    ks = KernelExplainer(
      model=lambda X: fit.predict(pd.DataFrame(X, columns=x)), 
      data = bg_X
    )
    sv = ks.shap_values(X_small)  # 74 seconds
    sv[0:2]
    
    # array([[-2.05007406, -0.28048747,  0.12812216,  0.01587382],
    #        [-2.0858379 ,  0.04050415,  0.12830103,  0.03731644]])
    SHAP summary plot (R model)

    The results match, hurray!

    Example with nine features

    The computational effort of running exact Kernel SHAP explodes with the number of features. For nine features, the number of relevant on-off vectors is 2^9 - 2 = 510, i.e., about 36 times larger than with four features.

    We now modify the above example, adding five additional features to the model. Note that the model structure is completely nonsensical; we just use it to get a feeling for the impact of a 36 times larger workload.

    Besides exact calculations, we use an almost exact hybrid approach for both R and Python, using 126 on-off vectors (p*(p+1) for the exact part and 4p for the sampling part, where p is the number of features), resulting in a significant speed-up both in R and Python.
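    The p(p+1) term matches the number of on-off vectors with 1, 2, p - 2 or p - 1 features switched on, i.e., the vectors with the largest Shapley kernel weights. Assuming this is how the exact part is composed (our reading, valid for p >= 5 where these four sizes are distinct), the 126 vectors for nine features add up as follows; exact_part_size() is a hypothetical helper:

```python
from math import comb

def exact_part_size(p):
    """Number of on-off vectors with 1, 2, p-2 or p-1 features switched on (p >= 5)."""
    sizes = {1, 2, p - 2, p - 1}
    return sum(comb(p, k) for k in sizes)

p = 9
n_exact = exact_part_size(p)
n_sampled = 4 * p
print(n_exact, p * (p + 1), n_exact + n_sampled)  # 90 90 126
```

    In general, 2p + 2 C(p, 2) = 2p + p(p - 1) = p(p + 1).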

    fit <- lm(
      log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth, 
      data = diamonds
    )
    
    # Subset of 1018 diamonds to explain
    X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")]
    
    # Exact Kernel SHAP: 61 seconds
    system.time(
      ks <- kernelshap(fit, X_small, bg_X = bg_X, exact = TRUE)  
    )
    ks
    #          carat        cut     color     clarity         depth         table          x           y            z
    # [1,] -1.842799 0.01424231 0.1266108 -0.27033874 -0.0007084443  0.0017787647 -0.1720782 0.001330275 -0.006445693
    # [2,] -1.876709 0.03856957 0.1266546  0.03932912 -0.0004202636 -0.0004871776 -0.1739880 0.001397792 -0.006560624
    
    # Default, using an almost exact hybrid algorithm: 17 seconds
    system.time(
      ks <- kernelshap(fit, X_small, bg_X = bg_X, parallel = TRUE)  
    )
    #          carat        cut     color     clarity         depth         table          x           y            z
    # [1,] -1.842799 0.01424231 0.1266108 -0.27033874 -0.0007084443  0.0017787647 -0.1720782 0.001330275 -0.006445693
    # [2,] -1.876709 0.03856957 0.1266546  0.03932912 -0.0004202636 -0.0004871776 -0.1739880 0.001397792 -0.006560624
    x = ["carat"] + ord + ["table", "depth", "x", "y", "z"]
    X = diamonds[x].to_numpy()
    
    # Fit model with interactions and dummy variables
    fit = ols(
      "np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color)) + table + depth + x + y + z", 
      data=diamonds
    ).fit()
    
    # Background data (120 rows)
    bg_X = X[0:len(X):450]
    
    # Define subset of 1018 diamonds to explain
    X_small = X[0:len(X):53]
    
    # Calculate KernelSHAP values: 12 minutes
    ks = KernelExplainer(
      model=lambda X: fit.predict(pd.DataFrame(X, columns=x)), 
      data = bg_X
    )
    sv = ks.shap_values(X_small)
    sv[0:2]
    # array([[-1.84279897e+00, -2.70338744e-01,  1.26610769e-01,
    #          1.42423108e-02,  1.77876470e-03, -7.08444295e-04,
    #         -1.72078182e-01,  1.33027467e-03, -6.44569296e-03],
    #        [-1.87670887e+00,  3.93291219e-02,  1.26654599e-01,
    #          3.85695742e-02, -4.87177593e-04, -4.20263565e-04,
    #         -1.73988040e-01,  1.39779179e-03, -6.56062359e-03]])
    
    # Now, using a hybrid between exact and sampling: 5 minutes
    sv = ks.shap_values(X_small, nsamples=126)
    sv[0:2]
    # array([[-1.84279897e+00, -2.70338744e-01,  1.26610769e-01,
    #          1.42423108e-02,  1.77876470e-03, -7.08444295e-04,
    #         -1.72078182e-01,  1.33027467e-03, -6.44569296e-03],
    #        [-1.87670887e+00,  3.93291219e-02,  1.26654599e-01,
    #          3.85695742e-02, -4.87177593e-04, -4.20263565e-04,
    #         -1.73988040e-01,  1.39779179e-03, -6.56062359e-03]])

    Again, the results are essentially the same between R and Python, and also between the hybrid and the exact algorithm. This is remarkable, because the hybrid algorithm is significantly faster than the exact one.

    Wrap-Up

    • R is catching up with Python’s superb “shap” package.
    • For two non-trivial linear regressions with interactions, the “kernelshap” package in R provides the same output as Python.
    • The hybrid between exact and sampling KernelSHAP (as implemented in Python and R) offers a very good trade-off between speed and accuracy.
    • kernelshap() in R is fast!

    The Python and R code can be found here:

    The examples were run on a Windows notebook with an Intel i7-8650U 4 core CPU.

  • Kernel SHAP

    Our last posts were on SHAP, one of the major ways to shed light on black-box Machine Learning models. SHAP values decompose predictions in a fair way into additive contributions from each feature. Decomposing many predictions and then analyzing the SHAP values gives a relatively quick and informative picture of the fitted model at hand.

    In their 2017 paper on SHAP, Scott Lundberg and Su-In Lee presented Kernel SHAP, an algorithm to calculate SHAP values for any model with numeric predictions. Compared to Monte-Carlo sampling (e.g., as implemented in the R package “fastshap”), Kernel SHAP is much more efficient.

    I had one problem with Kernel SHAP: I never really understood how it works!

    Then I found this article by Covert and Lee (2021). The article not only explains all the details of Kernel SHAP, it also offers a version that iterates until convergence. As a by-product, standard errors of the SHAP values can be calculated on the fly.

    This article motivated me to implement the “kernelshap” package in R, complementing “shapr” that uses a different logic.

    The new “kernelshap” package in R

    The interface is quite simple: You need to pass three things to its main function kernelshap():

    • X: matrix/data.frame/tibble/data.table of observations to explain. Each column is a feature.
    • pred_fun: function that takes an object like X and provides one number per row.
    • bg_X: matrix/data.frame/tibble/data.table representing the background dataset used to calculate marginal expectations. Typically, it has between 100 and 200 rows.
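Conceptually, these three inputs define the "value" of each coalition S of features: the average of pred_fun over the rows of bg_X after the features in S have been overwritten with those of the row to be explained. Kernel SHAP estimates the Shapley values of this value function by a weighted least-squares fit over coalitions. For a handful of features, they can also be computed exactly by brute force, as in the following Python sketch. The argument names mirror kernelshap(), but the functions are hypothetical illustrations, not the package's code.

```python
from itertools import combinations
from math import factorial


def coalition_value(pred_fun, x, bg_X, coalition):
    """Average prediction over bg_X after overwriting the features in `coalition` with x."""
    rows = [[x[j] if j in coalition else b[j] for j in range(len(x))] for b in bg_X]
    preds = pred_fun(rows)
    return sum(preds) / len(preds)


def exact_shap(pred_fun, x, bg_X):
    """Exact SHAP values of one row x, enumerating all 2^p coalitions."""
    p = len(x)
    # Value of every coalition S (2^p calls of pred_fun, each on len(bg_X) rows)
    v = {
        frozenset(S): coalition_value(pred_fun, x, bg_X, set(S))
        for size in range(p + 1)
        for S in combinations(range(p), size)
    }
    phi = []
    for j in range(p):
        others = [k for k in range(p) if k != j]
        total = 0.0
        for size in range(p):
            # Shapley weight |S|! (p - |S| - 1)! / p!
            w = factorial(size) * factorial(p - size - 1) / factorial(p)
            for S in combinations(others, size):
                S = frozenset(S)
                total += w * (v[S | {j}] - v[S])
        phi.append(total)
    baseline = v[frozenset()]  # average prediction on bg_X
    return phi, baseline


# Toy model with an interaction between the first two features
def pred_fun(X):
    return [2 * r[0] + r[1] + r[0] * r[1] - 3 * r[2] for r in X]


bg_X = [[0, 0, 0], [1, 2, 1], [2, 1, 0], [1, 1, 2]]
x = [2, 3, 1]
phi, baseline = exact_shap(pred_fun, x, bg_X)
print(phi, baseline)
print(baseline + sum(phi), pred_fun([x])[0])  # equal up to rounding (efficiency)
```

The cost grows like 2^p times the number of background rows, which is why Kernel SHAP works with (sampled) coalitions for larger p and why a moderately sized background dataset is used.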

    Example

    We will use Keras to build a deep learning model with 631 parameters on diamonds data. Then we decompose 500 predictions with kernelshap() and visualize them with “shapviz”.

    We will fit a Gamma regression with log link on the four “C” features:

    • carat
    • color
    • clarity
    • cut
    library(tidyverse)
    library(keras)
    
    # Response and covariates
    y <- as.numeric(diamonds$price)
    X <- scale(data.matrix(diamonds[c("carat", "color", "cut", "clarity")]))
    
    # Input layer: we have 4 covariates
    input <- layer_input(shape = 4)
    
    # Two hidden layers with contracting number of nodes
    output <- input %>%
      layer_dense(units = 30, activation = "tanh") %>% 
      layer_dense(units = 15, activation = "tanh") %>% 
      layer_dense(units = 1, activation = k_exp)
    
    # Create and compile model
    nn <- keras_model(inputs = input, outputs = output)
    summary(nn)
    
    # Gamma regression loss
    loss_gamma <- function(y_true, y_pred) {
      -k_log(y_true / y_pred) + y_true / y_pred
    }
    
    nn %>% 
      compile(
        optimizer = optimizer_adam(learning_rate = 0.001),
        loss = loss_gamma
      )
    
    # Callbacks
    cb <- list(
      callback_early_stopping(patience = 20),
      callback_reduce_lr_on_plateau(patience = 5)
    )
    
    # Fit model
    history <- nn %>% 
      fit(
        x = X,
        y = y,
        epochs = 100,
        batch_size = 400, 
        validation_split = 0.2,
        callbacks = cb
      )
    
    history$metrics[c("loss", "val_loss")] %>% 
      data.frame() %>% 
      mutate(epoch = row_number()) %>% 
      filter(epoch >= 3) %>% 
      pivot_longer(cols = c("loss", "val_loss")) %>% 
    ggplot(aes(x = epoch, y = value, group = name, color = name)) +
      geom_line(size = 1.4)
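Why does loss_gamma() give a Gamma regression? It equals half the Gamma unit deviance plus the constant 1, so both are minimized by the same predictions; together with the exponential output activation (the inverse of the log link), this is akin to a Gamma GLM with log link. The identity can be checked numerically with a few lines of Python (the test pairs are arbitrary):

```python
from math import isclose, log


def loss_gamma(y_true, y_pred):
    """Per-observation version of the Keras loss above."""
    return -log(y_true / y_pred) + y_true / y_pred


def gamma_deviance(y_true, y_pred):
    """Unit deviance of the Gamma distribution."""
    return 2 * (-log(y_true / y_pred) + (y_true - y_pred) / y_pred)


for y, mu in [(2.0, 1.0), (500.0, 650.0), (3.0, 3.0)]:
    # loss_gamma = deviance / 2 + 1
    assert isclose(loss_gamma(y, mu), gamma_deviance(y, mu) / 2 + 1)
    print(y, mu, round(loss_gamma(y, mu), 4), round(gamma_deviance(y, mu), 4))
```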

    Interpretation via KernelSHAP

    In order to peek into the fitted model, we apply the Kernel SHAP algorithm to decompose 500 randomly selected diamond predictions. We use the same subset as the background dataset required by the Kernel SHAP algorithm.

    Afterwards, we will study

    • Some SHAP values and their standard errors
    • One waterfall plot
    • A beeswarm summary plot to get a rough picture of variable importance and the direction of the feature effects
    • A SHAP dependence plot for carat
    # Interpretation on 500 randomly selected diamonds
    library(kernelshap)
    library(shapviz)
    
    set.seed(1)
    ind <- sample(nrow(X), 500)
    
    dia_small <- X[ind, ]
    
    # 77 seconds
    system.time(
      ks <- kernelshap(
        dia_small, 
        pred_fun = function(X) as.numeric(predict(nn, X, batch_size = nrow(X))), 
        bg_X = dia_small
      )
    )
    ks
    
    # Output
    # 'kernelshap' object representing 
    # - SHAP matrix of dimension 500 x 4 
    # - feature data.frame/matrix of dimension 500 x 4 
    # - baseline value of 3744.153
    # 
    # SHAP values of first 2 observations:
    #         carat     color       cut   clarity
    # [1,] -110.738 -240.2758  5.254733 -720.3610
    # [2,] 2379.065  263.3112 56.413680  452.3044
    # 
    # Corresponding standard errors:
    #         carat      color       cut  clarity
    # [1,] 2.064393 0.05113337 0.1374942 2.150754
    # [2,] 2.614281 0.84934844 0.9373701 0.827563
    
    sv <- shapviz(ks, X = diamonds[ind, c("carat", "color", "cut", "clarity")])
    sv_waterfall(sv, 1)
    sv_importance(sv, "both")
    sv_dependence(sv, "carat", "auto")

    Note the small standard errors of the SHAP values of the first two diamonds. They are only approximate because the background data is only a sample from an unknown population. Still, they give a good impression of the stability of the results.

    The waterfall plot shows a diamond with not super nice clarity and color, pulling down the value of this diamond. Note that, even if the model is working with scaled numeric feature values, the plot shows the original feature values.

    SHAP waterfall plot of one diamond. Note its bad clarity.
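Behind a waterfall plot is simple arithmetic: starting at the baseline, each feature's SHAP value is added in turn, and the last bar ends at the prediction. A small Python sketch with the numbers printed above for the first diamond (in the printed feature order; shapviz may order the bars differently):

```python
baseline = 3744.153
shap_first = {"carat": -110.738, "color": -240.2758, "cut": 5.254733, "clarity": -720.3610}

running = baseline
for feature, value in shap_first.items():
    start, running = running, running + value  # each bar starts where the previous one ended
    print(f"{feature:>8}: {start:8.2f} -> {running:8.2f}")
print(f"prediction: {running:.2f}")
```

The end point, about 2678 USD, should equal the model's prediction for this diamond, because the Kernel SHAP values add up to the prediction minus the baseline.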

    The SHAP summary plot shows that “carat” is, unsurprisingly, the most important variable and that high carat means high value. “cut” is not very important, except if it is extremely bad.

    SHAP summary plot with bars representing average absolute values as measure of importance.

    Our last plot is a SHAP dependence plot for “carat”: the effect makes sense, and we can spot some interaction with color. For worse colors (H-J), the effect of carat is a bit less strong than for the very white diamonds.

    Dependence plot for “carat”

    Short wrap-up

    • Standard Kernel SHAP in R, yeahhhhh 🙂
    • The Github version is relatively fast, so you can even decompose 500 observations of a deep learning model within 1-2 minutes.

    The complete R script can be found here.

  • shapviz goes H2O

    In a recent post, I introduced the initial version of the “shapviz” package. Its motto: do one thing, but do it well: visualize SHAP values.

    The initial community feedback was very positive, and a couple of things have been improved in version 0.2.0. Here are the main changes:

    1. “shapviz” now works with tree-based models of the h2o package in R.
    2. Additionally, it wraps the shapr package, which implements an improved version of Kernel SHAP taking into account feature dependence.
    3. A simple interface to collapse SHAP values of dummy variables was added.
    4. The default importance plot is now a bar plot, instead of the (slower) beeswarm plot. In later releases, the latter might be moved to a separate function sv_summary() for consistency with other packages.
    5. Importance plot and dependence plot now work neatly with ggplotly(). The other plot types cannot be translated with ggplotly() because they use geoms from outside ggplot. At least I do not know how to do this…
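Collapsing dummy variables is possible because SHAP values are additive: the contribution of a one-hot encoded variable is the sum of the contributions of its dummy columns. A minimal Python sketch of this idea, with a hypothetical function name and made-up numbers (not how shapviz implements it):

```python
def sum_dummy_shap(S, groups):
    """Sum SHAP columns that belong to the same original variable.

    S:      list of dicts, one per observation, mapping column name -> SHAP value
    groups: dict mapping a new (collapsed) name -> list of original columns
    """
    grouped = {col for cols in groups.values() for col in cols}
    out = []
    for row in S:
        new = {name: sum(row[c] for c in cols) for name, cols in groups.items()}
        # Columns not mentioned in `groups` are kept as they are
        new.update({c: v for c, v in row.items() if c not in grouped})
        out.append(new)
    return out


S = [
    {"carat": 1200.0, "cut_Good": -15.0, "cut_Ideal": 40.0, "cut_Premium": 5.0},
    {"carat": -300.0, "cut_Good": 10.0, "cut_Ideal": -25.0, "cut_Premium": 2.0},
]
collapsed = sum_dummy_shap(S, {"cut": ["cut_Good", "cut_Ideal", "cut_Premium"]})
print(collapsed)
```

The row sums, and hence the decomposed predictions, stay unchanged.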

    Example

    Let’s build an H2O gradient boosted trees model to explain diamond prices. Then, we explain the model with our “shapviz” package. Note that H2O itself also offers some SHAP plots. “shapviz” is directly applied to the fitted H2O model. This means you don’t have to write a single superfluous line of code.

    library(shapviz)
    library(tidyverse)
    library(h2o)
    
    h2o.init()
    
    set.seed(1)
    
    # Get rid of those darn ordinals
    ord <- c("clarity", "cut", "color")
    diamonds[, ord] <- lapply(diamonds[, ord], factor, ordered = FALSE)
    
    # Minimally tuned GBM with 260 trees, determined by early-stopping with CV
    dia_h2o <- as.h2o(diamonds)
    fit <- h2o.gbm(
      c("carat", "clarity", "color", "cut"),
      y = "price",
      training_frame = dia_h2o,
      nfolds = 5,
      learn_rate = 0.05,
      max_depth = 4,
      ntrees = 10000,
      stopping_rounds = 10,
      score_each_iteration = TRUE
    )
    fit
    
    # SHAP analysis on about 2000 diamonds
    X_small <- diamonds %>%
      filter(carat <= 2.5) %>%
      sample_n(2000) %>%
      as.h2o()
    
    shp <- shapviz(fit, X_pred = X_small)
    
    sv_importance(shp, show_numbers = TRUE)
    sv_importance(shp, show_numbers = TRUE, kind = "bee")
    sv_dependence(shp, "color", "auto", alpha = 0.5)
    sv_force(shp, row_id = 1)
    sv_waterfall(shp, row_id = 1)

    Summary and importance plots

    The SHAP importance and SHAP summary plots clearly show that carat is the most important variable. On average, it impacts the prediction by 3247 USD. The effect of “cut” is much smaller. Its impact on the predictions, on average, is plus or minus 112 USD.

    SHAP summary plot
    SHAP importance plot
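The bars of the importance plot are mean absolute SHAP values; that is how the 3247 USD for carat should be read. A Python sketch of the calculation with made-up SHAP values of three diamonds:

```python
def shap_importance(S):
    """Mean absolute SHAP value per feature, sorted from most to least important."""
    n = len(S)
    imp = {f: sum(abs(row[f]) for row in S) / n for f in S[0]}
    return dict(sorted(imp.items(), key=lambda kv: kv[1], reverse=True))


# Made-up SHAP values (USD), one dict per diamond
S = [
    {"carat": -2500.0, "clarity": 300.0, "color": -150.0, "cut": 20.0},
    {"carat": 4100.0, "clarity": -600.0, "color": 250.0, "cut": -90.0},
    {"carat": -900.0, "clarity": 120.0, "color": 80.0, "cut": 40.0},
]
print(shap_importance(S))
```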

    SHAP dependence plot

    The SHAP dependence plot shows the effect of “color” on the prediction: The better the color (close to “D”), the higher the price. Using a correlation-based heuristic, the plot selected carat on the color scale to show that the color effect is highly influenced by carat in the sense that the impact of color increases with larger diamond weight. This clearly makes sense!

    Dependence plot for “color”

    Waterfall and force plot

    Finally, the waterfall and force plots show how a single prediction is decomposed into contributions from each feature. While this does not tell much about the model itself, it might be helpful to explain what SHAP values are and to debug strange predictions.

    Waterfall plot
    Force plot

    Short wrap-up

    • Combining “shapviz” and H2O is fun. Okay, that one was subjective :-).
    • Good visualization of ML models is extremely helpful and reassuring.

    The complete R script can be found here.

  • Visualize SHAP Values without Tears

    SHAP (SHapley Additive exPlanations, Lundberg and Lee, 2017) is an ingenious way to study black box models. SHAP values decompose predictions, as fairly as possible, into additive feature contributions.

    When it comes to SHAP, the Python implementation is the de-facto standard. It not only offers many SHAP algorithms, but also provides beautiful plots. In R, the situation is a bit more confusing. Different packages contain implementations of SHAP algorithms, e.g.,

    some of which come with great visualizations. Plus there is SHAPforxgboost (see my recent post), originally designed to visualize SHAP values calculated from XGBoost, but by now it can also be used more generally.

    The shapviz package

    To disentangle calculation from visualization, the shapviz package was designed. It focuses solely on the visualization of SHAP values. Closely following its README, it currently provides these plots:

    • sv_waterfall(): Waterfall plots to study single predictions.
    • sv_force(): Force plots as an alternative to waterfall plots.
    • sv_importance(): Importance plots (bar and/or beeswarm plots) to study variable importance.
    • sv_dependence(): Dependence plots to study feature effects (optionally colored by heuristically strongest interacting feature).

    They require a “shapviz” object, which is built from two things only:

    1. S: Matrix of SHAP values
    2. X: Dataset with corresponding feature values

    Furthermore, a “baseline” can be passed to represent an average prediction on the scale of the SHAP values.
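As a rough picture of what such an object holds, here is a hypothetical Python analogue (not the R implementation): it stores S, X and the baseline and only checks that S and X fit together. X may hold arbitrary values such as strings, since it is only displayed, never fed to a model.

```python
from dataclasses import dataclass


@dataclass
class ShapExplanation:
    """Hypothetical container mirroring the ingredients of a "shapviz" object."""
    S: list              # SHAP values: one list per observation, one entry per feature
    X: list              # feature values of the same observations (any types)
    feature_names: list
    baseline: float = 0.0

    def __post_init__(self):
        p = len(self.feature_names)
        if len(self.S) != len(self.X):
            raise ValueError("S and X must have the same number of rows")
        if any(len(r) != p for r in self.S) or any(len(r) != p for r in self.X):
            raise ValueError("each row of S and X needs one entry per feature")

    def prediction(self, i):
        """Baseline plus SHAP values of row i (assumes the SHAP values are additive)."""
        return self.baseline + sum(self.S[i])


shp = ShapExplanation(
    S=[[1200.0, -35.5], [-400.0, 10.0]],
    X=[[1.01, "Ideal"], [0.31, "Fair"]],  # factor-like values can stay as strings
    feature_names=["carat", "cut"],
    baseline=3900.0,
)
print(shp.prediction(0))  # 5064.5
```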

    A key feature of the “shapviz” package is that X is used for visualization only. Thus it is perfectly fine to use factor variables, even if the underlying model would not accept these.

    To further simplify the use of shapviz, direct connectors to the packages

    are available.

    Installation

    The package shapviz can be installed from CRAN or Github:

    • install.packages("shapviz")
    • devtools::install_github("mayer79/shapviz")

    Example

    Shiny diamonds… let’s model their prices by four “c” variables with XGBoost, and create an explanation dataset with 2000 randomly picked diamonds.

    library(shapviz)
    library(ggplot2)
    library(xgboost)
    
    set.seed(3653)
    
    X <- diamonds[c("carat", "cut", "color", "clarity")]
    dtrain <- xgb.DMatrix(data.matrix(X), label = diamonds$price)
    
    fit <- xgb.train(
      params = list(learning_rate = 0.1, objective = "reg:squarederror"), 
      data = dtrain,
      nrounds = 65L
    )
    
    # Explanation dataset
    X_small <- X[sample(nrow(X), 2000L), ]

    Create “shapviz” object

    One line of code creates a shapviz object. It contains SHAP values and feature values for the set of observations we are interested in. Note again that X is solely used as explanation dataset, not for calculating SHAP values.

    In this example we construct the shapviz object directly from the fitted XGBoost model. Thus we also need to pass a corresponding prediction dataset X_pred used for calculating SHAP values by XGBoost.

    shp <- shapviz(fit, X_pred = data.matrix(X_small), X = X_small)

    Explaining one single prediction

    Let’s start by explaining a single prediction by a waterfall plot or, alternatively, a force plot.

    # Two types of visualizations
    sv_waterfall(shp, row_id = 1)
    sv_force(shp, row_id = 1)
    Waterfall plot

    Factor/character variables are kept as they are, even if the underlying XGBoost model required them to be integer encoded.

    Force plot

    Explaining the model as a whole

    We have decomposed 2000 predictions, not just one. This allows us to study variable importance at a global model level by studying average absolute SHAP values as a bar plot or by looking at beeswarm plots of SHAP values.

    # Three types of variable importance plots
    sv_importance(shp)
    sv_importance(shp, kind = "bar")
    sv_importance(shp, kind = "both", alpha = 0.2, width = 0.2)
    Beeswarm plot
    Bar plot
    Beeswarm plot overlaid with bar plot

    A scatterplot of SHAP values of a feature like color against its observed values gives a great impression of the feature effect on the response. Vertical scatter gives additional information on interaction effects. shapviz offers a heuristic to pick another feature for the color scale, namely the one with potentially the strongest interaction.

    sv_dependence(shp, v = "color", "auto")
    Dependence plot with automatic interaction colorization
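How can a package pick the coloring feature automatically? One simple correlation-based heuristic (an assumption here, not necessarily the exact one shapviz uses): bin the observations by the feature on the x axis, correlate its SHAP values with each candidate feature within every bin, and average the squared correlations. A Python sketch on simulated data where the SHAP values of color depend on carat:

```python
import random


def pearson(a, b):
    """Pearson correlation; 0 if one of the inputs is constant."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    sab = sum((u - ma) * (w - mb) for u, w in zip(a, b))
    saa = sum((u - ma) ** 2 for u in a)
    sbb = sum((w - mb) ** 2 for w in b)
    return 0.0 if saa == 0 or sbb == 0 else sab / (saa * sbb) ** 0.5


def interaction_scores(v_values, shap_v, candidates, n_bins=4):
    """Size-weighted mean of squared within-bin correlations between the
    SHAP values of v and each candidate feature, sorted decreasingly."""
    n = len(v_values)
    order = sorted(range(n), key=lambda i: v_values[i])
    size = -(-n // n_bins)  # ceiling division
    bins = [order[i:i + size] for i in range(0, n, size)]
    scores = {}
    for name, values in candidates.items():
        total = 0.0
        for b in bins:
            if len(b) > 2:
                r = pearson([shap_v[i] for i in b], [values[i] for i in b])
                total += len(b) * r ** 2
        scores[name] = total / n
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


# Simulated data: the effect of "color" gets stronger with "carat"
rng = random.Random(42)
n = 400
color = [rng.random() for _ in range(n)]
carat = [rng.random() for _ in range(n)]
cut = [rng.random() for _ in range(n)]
shap_color = [(c - 0.5) * (1 + 2 * k) for c, k in zip(color, carat)]

scores = interaction_scores(color, shap_color, {"carat": carat, "cut": cut})
print(scores)  # carat should clearly come first
```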

    Summary

    • The “shapviz” package has a single purpose: making SHAP plots.
    • Its interface is optimized for existing SHAP crunching packages and can easily be used in future packages as well.
    • All plots are highly customizable. Furthermore, they are all written with ggplot and allow corresponding modifications.

    The complete R script can be found here.

    References

    Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30 (2017).

  • Let the flashlight shine with plotly

    There are different R packages devoted to model agnostic interpretability, DALEX and iml being among the best known. In 2019, I added flashlight for a couple of reasons:

    1. Its explainers work with case weights.
    2. Multiple explainers can be combined to a multi-explainer.
    3. Stratified calculation is possible.

    Since almost all plots in flashlight are constructed with ggplot, it is super easy to turn them into interactive plotly objects: just add a simple ggplotly() to the end of the call.

    However… it is not straightforward to show interactive plots in a blog! Thus, we show only screenshots of the resulting plots here and refer to the complete HTML report here: https://mayer79.github.io/flashlight_plotly/flashlight_plotly.html

    We will use a sweet dataset with more than 20’000 houses to model house prices by a set of derived features such as the logarithmic living area. The location will be represented by the postal code.

    Data preparation

    We first load the data and prepare some of the columns for modeling. Furthermore, we specify the set of features and the response.

    library(dplyr)
    library(flashlight)
    library(plotly)
    library(ranger)
    library(lme4)
    library(moderndive)
    library(splitTools)
    library(MetricsWeighted)
    
    set.seed(4933)
    
    data("house_prices")
    
    prep <- house_prices %>% 
      mutate(
        log_price = log(price),
        log_sqft_living = log(sqft_living),
        log_sqft_lot = log(sqft_lot),
        log_sqft_basement = log1p(sqft_basement),
        year = as.numeric(format(date, '%Y')),
        age = year - yr_built
      )
    
    x <- c(
      "year", "age", "log_sqft_living", "log_sqft_lot", 
      "bedrooms", "bathrooms", "log_sqft_basement", 
      "condition", "waterfront", "zipcode"
    )
    
    y <- "log_price"
    
    head(prep[c(y, x)])
    
    ## # A tibble: 6 x 11
    ##   log_price  year   age log_sqft_living log_sqft_lot bedrooms bathrooms
    ##       <dbl> <dbl> <dbl>           <dbl>        <dbl>    <int>     <dbl>
    ## 1      12.3  2014    59            7.07         8.64        3      1   
    ## 2      13.2  2014    63            7.85         8.89        3      2.25
    ## 3      12.1  2015    82            6.65         9.21        2      1   
    ## 4      13.3  2014    49            7.58         8.52        4      3   
    ## 5      13.1  2015    28            7.43         9.00        3      2   
    ## 6      14.0  2014    13            8.60        11.5         4      4.5 
    ## # ... with 4 more variables: log_sqft_basement <dbl>, condition <fct>,
    ## #   waterfront <lgl>, zipcode <fct>

    Train / test split

    Then, we split the dataset into 80% training and 20% test rows, stratified on the (binned) response log_price.

    idx <- partition(prep[[y]], c(train = 0.8, test = 0.2), type = "stratified")
    
    train <- prep[idx$train, ]
    test <- prep[idx$test, ]

    Models

    We fit two models:

    1. A linear mixed model with random postal code effect.
    2. A random forest with 500 trees.
    # Mixed-effects model
    fit_lmer <- lmer(
      update(reformulate(x, "log_price"), . ~ . - zipcode + (1 | zipcode)),
      data = train
    )
    
    # Random forest
    fit_rf <- ranger(
      reformulate(x, "log_price"),
      always.split.variables = "zipcode",
      data = train
    )
    cat("R-squared OOB:", fit_rf$r.squared)
    ## R-squared OOB: 0.8463311

    Model inspection

    Now, we are ready to inspect our two models regarding performance, variable importance, and effects.

    Set up explainers

    First, we pack all model-dependent information into flashlights (the explainer objects) and combine them into a multiflashlight. As evaluation dataset, we pass the test data. This ensures that interpretability tools using the response (e.g., performance measures and permutation importance) are not biased by overfitting.

    fl_lmer <- flashlight(model = fit_lmer, label = "LMER")
    fl_rf <- flashlight(
      model = fit_rf,
      label = "RF",
      predict_function = function(mod, X) predict(mod, X)$predictions
    )
    fls <- multiflashlight(
      list(fl_lmer, fl_rf),
      y = "log_price",
      data = test,
      metrics = list(RMSE = rmse, `R-squared` = r_squared)
    )

    Model performance

    Let’s evaluate model RMSE and R-squared on the hold-out dataset. Here, the mixed-effects model performs a tiny little bit better than the random forest:

    (light_performance(fls) %>%
      plot(fill = "darkred") +
        labs(title = "Model performance", x = element_blank())) %>%
      ggplotly()
    Model performance (png)

    Permutation importance

    Next, we inspect the variable strength based on permutation importance. It shows by how much the RMSE increases when a variable is shuffled before prediction. The results are quite similar between the two models.

    (light_importance(fls, v = x) %>%
        plot(fill = "darkred") +
        labs(title = "Permutation importance", y = "Drop in RMSE")) %>%
      ggplotly()
    Variable importance (png)
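Permutation importance needs nothing but predictions and a metric. A bare-bones Python sketch of the idea on simulated data (function names are hypothetical, and this is not flashlight's implementation): shuffle one column at a time and record how much the RMSE increases compared to the unshuffled data.

```python
import random


def rmse(y, pred):
    return (sum((a - b) ** 2 for a, b in zip(y, pred)) / len(y)) ** 0.5


def permutation_importance(predict, X, y, features, n_repeats=5, seed=1):
    """Average increase in RMSE when the column of each feature is shuffled."""
    rng = random.Random(seed)
    base = rmse(y, predict(X))
    result = {}
    for j, name in enumerate(features):
        increases = []
        for _ in range(n_repeats):
            col = [row[j] for row in X]
            rng.shuffle(col)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, col)]
            increases.append(rmse(y, predict(X_perm)) - base)
        result[name] = sum(increases) / n_repeats
    return result


# Simulated hold-out data: strong effect of x1, weak effect of x2, none of x3
rng = random.Random(0)
X = [[rng.random(), rng.random(), rng.random()] for _ in range(200)]
y = [3 * a + 0.5 * b + rng.gauss(0, 0.1) for a, b, _ in X]


def predict(X):
    return [3 * a + 0.5 * b for a, b, _ in X]


print(permutation_importance(predict, X, y, ["x1", "x2", "x3"]))
```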

    ICE plot

    To get an impression of the effect of the living area, we select 200 observations and profile their predictions with increasing (log) living area, keeping everything else fixed (Ceteris Paribus). These ICE (individual conditional expectation) plots are vertically centered in order to highlight potential interaction effects. If all curves coincide, there are no interaction effects and we can say that the effect of the feature is modelled in an additive way (no surprise for the additive linear mixed-effects model).

    (light_ice(fls, v = "log_sqft_living", n_max = 200, center = "middle") %>%
        plot(alpha = 0.05, color = "darkred") +
        labs(title = "Centered ICE plot", y = "log_price (shifted)")) %>%
      ggplotly()
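The mechanics of a centered ICE plot fit in a few lines. In this Python sketch (hypothetical helper names), each curve is shifted to zero at the middle grid value, which is an assumption about what center = "middle" does. With an additive model, all centered curves coincide; with an interaction they fan out:

```python
def ice_curves(predict, X, j, grid):
    """One curve per row: predictions with feature j set to each grid value,
    all other features held fixed (ceteris paribus)."""
    return [predict([row[:j] + [g] + row[j + 1:] for g in grid]) for row in X]


def center_middle(curves):
    """Shift every curve so that it is zero at the middle grid point."""
    mid = len(curves[0]) // 2
    return [[v - c[mid] for v in c] for c in curves]


X = [[1.0, 0.0], [2.0, 1.0], [3.0, 2.0]]
grid = [0.0, 1.0, 2.0, 3.0, 4.0]


def additive(X):
    return [2 * a + b for a, b in X]


def interaction(X):
    return [a * b for a, b in X]


print(center_middle(ice_curves(additive, X, 0, grid)))     # three identical rows
print(center_middle(ice_curves(interaction, X, 0, grid)))  # slopes 0, 1, 2
```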

    Partial dependence plots

    Averaging many uncentered ICE curves provides the famous partial dependence plot, introduced in Friedman’s seminal paper on gradient boosting machines (2001).

    (light_profile(fls, v = "log_sqft_living", n_bins = 21) %>%
        plot(rotate_x = FALSE) +
        labs(title = "Partial dependence plot", y = y) +
        scale_colour_viridis_d(begin = 0.2, end = 0.8)) %>%
      ggplotly()
    Partial dependence plots (png)
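As stated above, a partial dependence curve is just the average of the uncentered ICE curves. A Python sketch with a hypothetical helper, using the toy model f(a, b) = a * b, for which the partial dependence of a is a times the mean of b:

```python
def partial_dependence(predict, X, j, grid):
    """Average prediction over all rows of X with feature j set to each grid value."""
    pd = []
    for g in grid:
        preds = predict([row[:j] + [g] + row[j + 1:] for row in X])
        pd.append(sum(preds) / len(preds))
    return pd


def model(X):
    return [a * b for a, b in X]


X = [[1.0, 0.0], [2.0, 1.0], [3.0, 5.0]]  # mean of the second feature is 2
print(partial_dependence(model, X, 0, [0.0, 1.0, 2.0, 3.0]))  # [0.0, 2.0, 4.0, 6.0]
```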

    Multiple effects visualized together

    The last figure extends the partial dependence plot with three additional curves, all evaluated on the hold-out dataset:

    • Average observed values
    • Average predictions
    • ALE plot (“accumulated local effects”, an alternative to partial dependence plots with relaxed Ceteris Paribus assumption)
    (light_effects(fls, v = "log_sqft_living", n_bins = 21) %>%
        plot(use = "all")  +
        labs(title = "Different effect estimates", y = y) +
        scale_colour_viridis_d(begin = 0.2, end = 0.8)) %>%
      ggplotly()
    Multiple effects together (png)
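ALE avoids evaluating the model at unrealistic feature combinations: each observation is only moved within its own bin, and the average local differences are accumulated over the bins. A simplified Python sketch (equally spaced bins and no final centering, both simplifications assumed here; flashlight's details may differ):

```python
def ale(predict, X, j, n_bins=5):
    """Uncentered accumulated local effects of feature j, evaluated at the bin edges."""
    values = [row[j] for row in X]
    lo, hi = min(values), max(values)
    edges = [lo + (hi - lo) * k / n_bins for k in range(n_bins)] + [hi]
    curve = [0.0]
    for k in range(1, n_bins + 1):
        left, right = edges[k - 1], edges[k]
        # Rows falling into bin k; the first bin also contains its left edge
        in_bin = [r for r in X if left < r[j] <= right or (k == 1 and r[j] == left)]
        effect = 0.0
        if in_bin:
            upper = predict([r[:j] + [right] + r[j + 1:] for r in in_bin])
            lower = predict([r[:j] + [left] + r[j + 1:] for r in in_bin])
            effect = sum(u - l for u, l in zip(upper, lower)) / len(in_bin)
        curve.append(curve[-1] + effect)  # accumulate the local effects
    return edges, curve


# Strongly correlated features; the model is linear with slope 2 in the first one
X = [[i / 10, i / 10 + (i % 3) / 10] for i in range(31)]


def model(X):
    return [2 * a + b for a, b in X]


edges, curve = ale(model, X, 0)
print([round(e, 2) for e in edges])
print([round(c, 2) for c in curve])  # rises by 2 per unit of the first feature
```

For this additive model, ALE and partial dependence agree up to a shift; they can differ when features are correlated and the model is not additive.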

    Conclusion

    Combining flashlight with plotly works well and provides nice, interactive plots. Using rmarkdown, an analysis like this looks quite neat when shipped as an HTML report like this one: https://mayer79.github.io/flashlight_plotly/flashlight_plotly.html

    The rmarkdown script can be found here on github.

  • DuckDB: Quacking SQL

    Lost in Translation between R and Python 8

    This is the next article in our series “Lost in Translation between R and Python”. The aim of this series is to provide high-quality R and Python 3 code to achieve some non-trivial tasks. If you want to learn R, check out the R tab below. Similarly, if you want to learn Python, the Python tab will be your friend.

    DuckDB

    DuckDB is a fantastic in-process SQL database management system written completely in C++. Check its official documentation and other blogposts like this to get a feeling of its superpowers. It is getting better and better!

    Some of the highlights:

    • Easy installation in R and Python, made possible via language bindings.
    • Multiprocessing and fast.
    • Can work with data bigger than RAM.
    • Can fire SQL queries on R and Pandas tables.
    • Can fire SQL queries on (multiple!) csv and/or Parquet files.
    • Quacks Apache Arrow.

    Installation

    DuckDB is super easy to install:

    • R: install.packages("duckdb")
    • Python: pip install duckdb

    Additional packages required to run the code of this post are indicated in the code.

    A first query

    Let’s start by loading a dataset, initializing DuckDB and running a simple query.

    The dataset we use here contains information on over 20,000 sold houses in King County. Along with the sale price, different features describe the size and location of the properties. The dataset is available on OpenML.org with ID 42092.

    library(OpenML)
    library(duckdb)
    library(tidyverse)
    
    # Load data
    df <- getOMLDataSet(data.id = 42092)$data
    
    # Initialize duckdb, register df and materialize first query
    con = dbConnect(duckdb())
    duckdb_register(con, name = "df", df = df)
    con %>% 
      dbSendQuery("SELECT * FROM df limit 5") %>% 
      dbFetch()
    import duckdb
    import pandas as pd
    from sklearn.datasets import fetch_openml
    
    # Load data
    df = fetch_openml(data_id=42092, as_frame=True)["frame"]
    
    # Initialize duckdb, register df and fire first query
    # If out-of-RAM: duckdb.connect("py.duckdb", config={"temp_directory": "a_directory"})
    con = duckdb.connect()
    con.register("df", df)
    con.execute("SELECT * FROM df limit 5").fetchdf()
    Result of first query (from R)

    Average price per grade

    If you like SQL, then you can do your data preprocessing and simple analyses with DuckDB. Here, we calculate the average house price per online grade (the higher the grade, the better the house).

    query <- 
      "
      SELECT AVG(price) avg_price, grade 
      FROM df 
      GROUP BY grade
      ORDER BY grade
      "
    avg <- con %>% 
      dbSendQuery(query) %>% 
      dbFetch()
    
    avg
    
    # Average price per grade
    query = """
      SELECT AVG(price) avg_price, grade 
      FROM df 
      GROUP BY grade
      ORDER BY grade
      """
    avg = con.execute(query).fetchdf()
    avg
    R output

    Highlight: queries to files

    The last query will be applied directly to files on disk. To demonstrate this fantastic feature, we first save “df” as a parquet file and “avg” as a csv file.

    library(arrow)  # provides write_parquet()
    write_parquet(df, "housing.parquet")
    write.csv(avg, "housing_avg.csv", row.names = FALSE)
    
    # Save df and avg to different file types
    df.to_parquet("housing.parquet")  # pyarrow=7
    avg.to_csv("housing_avg.csv", index=False)

    Let’s load some columns of the “housing.parquet” data, but only rows with grades having an average price of more than one million USD. Agreed, that query does not make too much sense, but I hope you get the idea…😃

    # "Complex" query
    query2 <- "
      SELECT price, sqft_living, A.grade, avg_price
      FROM 'housing.parquet' A
      LEFT JOIN 'housing_avg.csv' B
      ON A.grade = B.grade
      WHERE B.avg_price > 1000000
      "
    
    expensive_grades <- con %>% 
      dbSendQuery(query2) %>% 
      dbFetch()
    
    head(expensive_grades)
    
    # dbDisconnect(con)
    # Complex query
    query2 = """
      SELECT price, sqft_living, A.grade, avg_price
      FROM 'housing.parquet' A
      LEFT JOIN 'housing_avg.csv' B
      ON A.grade = B.grade
      WHERE B.avg_price > 1000000
      """
    expensive_grades = con.execute(query2).fetchdf()
    expensive_grades
    
    # con.close()
    R output
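A side note on query2: because its WHERE clause filters on a column of B, rows of A without a match (where avg_price is NULL) are dropped, so the LEFT JOIN returns the same rows as an INNER JOIN would. Here this makes no difference, since every grade in “housing.parquet” also occurs in “housing_avg.csv”. The toy check below uses Python's built-in sqlite3 as a stand-in for DuckDB, purely so that it runs without extra packages; the NULL semantics shown are standard SQL. Table contents are made up.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE A (price REAL, grade INTEGER)")
con.execute("CREATE TABLE B (grade INTEGER, avg_price REAL)")
con.executemany("INSERT INTO A VALUES (?, ?)",
                [(250000, 7), (1500000, 12), (2200000, 13), (300000, 5)])
# Toy averages; grade 5 has no matching row in B
con.executemany("INSERT INTO B VALUES (?, ?)",
                [(7, 400000), (12, 1300000), (13, 2500000)])

query = """
  SELECT price, A.grade, avg_price
  FROM A LEFT JOIN B ON A.grade = B.grade
  {where}
  ORDER BY A.grade
"""
all_rows = con.execute(query.format(where="")).fetchall()
filtered = con.execute(query.format(where="WHERE B.avg_price > 1000000")).fetchall()
con.close()

print(all_rows)  # grade 5 appears with avg_price None (NULL)
print(filtered)  # NULL > 1000000 is not true, so grade 5 is gone as well
```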

    Last words

    • DuckDB is cool!
    • If you have strong SQL skills but do not know R or Python so well, this is a great way to get used to those programming languages.
    • If you are unfamiliar with SQL but like R and/or Python, you can use DuckDB for a while and end up being an SQL addict.
    • If your analysis involves combining many large files during preprocessing, then you can try the trick shown in the last example of this post.

    The Python notebook and R code can be found at:

  • Avoid loops in R! Really?

    It must have been around the year 2000 when I wrote my first snippet of S-PLUS/R code. One thing I learned back then:

    Loops are slow. Replace them with

    1. vectorized calculations or
    2. if vectorization is not possible, use sapply() et al.

    Since then, the R core team and the community have invested tons of time to improve R and also to make it faster. There are things like Rcpp and parallel computing to speed up loops.

    But what relatively few R users know, even today: loops are not that slow anymore. We want to demonstrate this using two examples.

    Example 1: sqrt()

    We use three ways to calculate the square root of a vector of random numbers:

    1. Vectorized calculation. This will be the way to go because it is internally optimized in C.
    2. A loop. This must be super slow for large vectors.
    3. vapply() (as a safe alternative to sapply()).

    The three approaches are then compared via bench::mark() regarding their speed for different vector lengths n. The results are compared first in terms of absolute median times and second (using an independent run) on a relative scale (1 corresponds to the vectorized approach).

    library(tidyverse)
    library(bench)
    
    # Calculate square root for each element in loop
    sqrt_loop <- function(x) {
      out <- numeric(length(x))
      for (i in seq_along(x)) {
        out[i] <- sqrt(x[i])
      }
      out
    }
    
    # Example
    sqrt_loop(1:4) # 1.000000 1.414214 1.732051 2.000000
    
    # Compare its performance with two alternatives
    sqrt_benchmark <- function(n) {
      x <- rexp(n)
      mark(
        vectorized = sqrt(x),
        loop = sqrt_loop(x),
        vapply = vapply(x, sqrt, FUN.VALUE = 0.0),
        # relative = TRUE
      )
    }
    
    # Combine results of multiple benchmarks and plot results
    multiple_benchmarks <- function(one_bench, N) {
      res <- vector("list", length(N))
      for (i in seq_along(N)) {
        res[[i]] <- one_bench(N[i]) %>% 
          mutate(n = N[i], expression = names(expression))
      }
      
      ggplot(bind_rows(res), aes(n, median, color = expression)) +
        geom_point(size = 3) +
        geom_line(size = 1) +
        scale_x_log10() +
        ggtitle(deparse1(substitute(one_bench))) +
        theme(legend.position = c(0.8, 0.15))
    }
    
    # Apply simulation
    multiple_benchmarks(sqrt_benchmark, N = 10^seq(3, 6, 0.25))

    Absolute timings

    Absolute median times on the “sqrt()” task

    Relative timings (using a second run)

    Relative median times of a separate run on the “sqrt()” task

    We see:

    • Run times increase quite linearly with vector size.
    • Vectorization is more than ten times faster than the naive loop.
    • Most strikingly, vapply() is much slower than the naive loop. Would you have thought this?
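The benchmark design itself (repeated timings, the median per variant, and relative times obtained by dividing by the fastest median, which is what bench::mark(relative = TRUE) reports) translates directly to other languages. A hedged Python analogue using timeit; since plain Python has no vectorized sqrt() without NumPy, it compares an explicit loop, a list comprehension and map(), and its timings say nothing about R:

```python
import math
import statistics
import timeit


def sqrt_loop(x):
    out = [0.0] * len(x)
    for i in range(len(x)):
        out[i] = math.sqrt(x[i])
    return out


def benchmark(n, repeat=7, number=5):
    """Median time per call of each variant, plus times relative to the fastest."""
    x = [float(i) for i in range(1, n + 1)]
    candidates = {
        "comprehension": lambda: [math.sqrt(v) for v in x],
        "loop": lambda: sqrt_loop(x),
        "map": lambda: list(map(math.sqrt, x)),
    }
    medians = {
        name: statistics.median(timeit.repeat(fn, repeat=repeat, number=number)) / number
        for name, fn in candidates.items()
    }
    fastest = min(medians.values())
    relative = {name: t / fastest for name, t in medians.items()}
    return medians, relative


medians, relative = benchmark(10_000)
for name in medians:
    print(f"{name:>13}: median {medians[name] * 1e3:.3f} ms, relative {relative[name]:.2f}")
```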

    Example 2: paste()

    For the second example, we use a less simple function, namely

    paste("Number", prettyNum(x, digits = 5))

    What will our three approaches (vectorized, naive loop, vapply) show on this task?

    pretty_paste <- function(x) {
      paste("Number", prettyNum(x, digits = 5))
    }
    
    # Example
    pretty_paste(pi) # "Number 3.1416"
    
    # Again, call pretty_paste() for each element in a loop
    paste_loop <- function(x) {
      out <- character(length(x))
      for (i in seq_along(x)) {
        out[i] <- pretty_paste(x[i])
      }
      out
    }
    
    # Compare its performance with two alternatives
    paste_benchmark <- function(n) {
      x <- rexp(n)
      mark(
        vectorized = pretty_paste(x),
        loop = paste_loop(x),
        vapply = vapply(x, pretty_paste, FUN.VALUE = ""),
        # relative = TRUE
      )
    }
    
    multiple_benchmarks(paste_benchmark, N = 10^seq(3, 5, 0.25))

    Absolute timings

    Absolute median times on the “paste()” task

    Relative timings (using a second run)

    Relative median times of a separate run on the “paste()” task

    We see:

    • In contrast to the first example, vapply() is now as fast as the naive loop.
    • The time advantage of the vectorized approach is much less impressive: in median, the loop takes only 50% longer.

    Conclusion

    1. Vectorization is fast and easy to read. If available, use this. No surprise.
    2. If you use vapply/sapply/lapply, do it for the style, not for the speed. In some cases, the loop will be faster. And, depending on the situation and the audience, a loop might actually be even easier to read.

    The code can be found on github.

    The runs have been made on a Windows 11 system with a four core Intel(R) Core(TM) i7-8650U CPU @ 1.90GHz processor.