{"id":939,"date":"2022-08-12T18:37:13","date_gmt":"2022-08-12T16:37:13","guid":{"rendered":"https:\/\/lorentzen.ch\/?p=939"},"modified":"2022-08-12T18:37:13","modified_gmt":"2022-08-12T16:37:13","slug":"kernel-shap","status":"publish","type":"post","link":"https:\/\/lorentzen.ch\/index.php\/2022\/08\/12\/kernel-shap\/","title":{"rendered":"Kernel SHAP"},"content":{"rendered":"\n<p>Our <a href=\"https:\/\/lorentzen.ch\/index.php\/2022\/07\/11\/shapviz-goes-h2o\/\">last post<\/a>s were on SHAP, one of the major ways to shed light on black-box Machine Learning models. SHAP values decompose predictions in a fair way into additive contributions from each feature. Decomposing many predictions and then analyzing the SHAP values gives a relatively quick and informative picture of the fitted model at hand.<\/p>\n\n\n\n<p>In their 2017 paper on SHAP, Scott Lundberg and Su-In Lee presented Kernel SHAP, an algorithm to calculate SHAP values for any model with numeric predictions. Compared to Monte-Carlo sampling (e.g. implemented in R package <a href=\"https:\/\/CRAN.R-project.org\/package=fastshap\">&#8220;fastshap&#8221;<\/a>), Kernel SHAP is much more efficient. <\/p>\n\n\n\n<p><strong>I had one problem with Kernel SHAP: I never really understood how it works!<\/strong><\/p>\n\n\n\n<p>Then I found <a href=\"https:\/\/proceedings.mlr.press\/v130\/covert21a.html\">this article<\/a> by Covert and Lee (2021). The article not only explains all the details of Kernel SHAP, it also offers a version that iterates until convergence. As a by-product, standard errors of the SHAP values can be calculated on the fly.<\/p>\n\n\n\n<p>This article motivated me to implement the &#8220;kernelshap&#8221; package in R, complementing <a href=\"https:\/\/CRAN.R-project.org\/package=shapr\">&#8220;shapr&#8221;<\/a>, which uses a different logic. 
<\/p>\n\n\n\n<figure class=\"wp-block-image size-full is-resized\"><img loading=\"lazy\" decoding=\"async\" src=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/logo_kernelshap.png\" alt=\"\" class=\"wp-image-940\" width=\"271\" height=\"314\" srcset=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/logo_kernelshap.png 518w, https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/logo_kernelshap-259x300.png 259w\" sizes=\"auto, (max-width: 271px) 100vw, 271px\" \/><figcaption>The new &#8220;kernelshap&#8221; package in R<\/figcaption><\/figure>\n\n\n\n<ul class=\"wp-block-list\"><li>Bleeding edge version 0.1.1 on GitHub: https:\/\/github.com\/mayer79\/kernelshap<\/li><li>Slower version 0.1.0 on CRAN: <a href=\"https:\/\/CRAN.R-project.org\/package=kernelshap\"><samp>https:\/\/CRAN.R-project.org\/package=kernelshap<\/samp><\/a><\/li><\/ul>\n\n\n\n<p>The interface is quite simple: you need to pass three things to its main function <code>kernelshap()<\/code>:<\/p>\n\n\n\n<ul class=\"wp-block-list\"><li><code>X<\/code>: matrix\/data.frame\/tibble\/data.table of observations to explain. Each column is a feature.<\/li><li><code>pred_fun<\/code>: function that takes an object like <code>X<\/code> and provides one number per row.<\/li><li><code>bg_X<\/code>: matrix\/data.frame\/tibble\/data.table representing the background dataset used to calculate marginal expectations. Typically, between 100 and 200 rows. <\/li><\/ul>\n\n\n\n<h3 class=\"wp-block-heading\">Example<\/h3>\n\n\n\n<p>We will use Keras to build a deep learning model with 631 parameters on diamonds data. 
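Before diving into the deep learning example, here is a minimal sketch of the three-argument interface just described, using a plain linear model on iris. This is purely illustrative: the positional argument order follows version 0.1.x, and the component name `S` for the SHAP matrix is an assumption based on the package's print output.

```r
library(kernelshap)

# Illustrative stand-in for any model with numeric predictions
fit <- lm(Sepal.Length ~ ., data = iris)

# X: rows to explain (features only); bg_X: background data;
# pred_fun: returns one number per row of X
ks <- kernelshap(
  X = iris[1:20, -1],
  pred_fun = function(X) predict(fit, X),
  bg_X = iris[, -1]
)
head(ks$S)  # SHAP matrix: one column per feature
```

For each row, the SHAP values in `ks$S` add up to the difference between that row's prediction and the baseline (the average prediction over `bg_X`).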
Then we decompose 500 predictions with <code>kernelshap()<\/code> and visualize them with <a href=\"https:\/\/CRAN.R-project.org\/package=shapviz\">&#8220;shapviz&#8221;<\/a>.<\/p>\n\n\n\n<p>We will fit a Gamma regression with log link to the four &#8220;C&#8221; features:<\/p>\n\n\n\n<ul class=\"wp-block-list\"><li>carat<\/li><li>color<\/li><li>clarity<\/li><li>cut<\/li><\/ul>\n\n\n<div class=\"wp-block-ub-tabbed-content wp-block-ub-tabbed-content-holder wp-block-ub-tabbed-content-horizontal-holder-mobile wp-block-ub-tabbed-content-horizontal-holder-tablet\" id=\"ub-tabbed-content-e40d45f9-dac6-466d-9e6b-1a4367768997\" style=\"\">\n\t\t\t<div class=\"wp-block-ub-tabbed-content-tab-holder horizontal-tab-width-mobile horizontal-tab-width-tablet\">\n\t\t\t\t<div role=\"tablist\" class=\"wp-block-ub-tabbed-content-tabs-title wp-block-ub-tabbed-content-tabs-title-mobile-horizontal-tab wp-block-ub-tabbed-content-tabs-title-tablet-horizontal-tab\" style=\"justify-content: flex-start; \"><div role=\"tab\" id=\"ub-tabbed-content-e40d45f9-dac6-466d-9e6b-1a4367768997-tab-0\" aria-controls=\"ub-tabbed-content-e40d45f9-dac6-466d-9e6b-1a4367768997-panel-0\" aria-selected=\"true\" class=\"wp-block-ub-tabbed-content-tab-title-wrap active\" style=\"--ub-tabbed-title-background-color: #6d6d6d; --ub-tabbed-active-title-color: inherit; --ub-tabbed-active-title-background-color: #6d6d6d; text-align: center; \" tabindex=\"-1\">\n\t\t\t\t<div class=\"wp-block-ub-tabbed-content-tab-title\">R<\/div>\n\t\t\t<\/div><\/div>\n\t\t\t<\/div>\n\t\t\t<div class=\"wp-block-ub-tabbed-content-tabs-content\" style=\"\"><div role=\"tabpanel\" class=\"wp-block-ub-tabbed-content-tab-content-wrap active\" id=\"ub-tabbed-content-e40d45f9-dac6-466d-9e6b-1a4367768997-panel-0\" aria-labelledby=\"ub-tabbed-content-e40d45f9-dac6-466d-9e6b-1a4367768997-tab-0\" tabindex=\"0\">\n\n<div class=\"wp-block-codemirror-blocks-code-block code-block\"><pre class=\"CodeMirror\" 
data-setting='{\"showPanel\":true,\"languageLabel\":\"language\",\"fullScreenButton\":true,\"copyButton\":true,\"mode\":\"r\",\"mime\":\"text\/x-rsrc\",\"theme\":\"material\",\"lineNumbers\":false,\"styleActiveLine\":false,\"lineWrapping\":false,\"readOnly\":true,\"fileName\":\"\",\"language\":\"R\",\"maxHeight\":\"400px\",\"modeName\":\"r\"}'>library(tidyverse)\nlibrary(keras)\n\n# Response and covariates\ny &lt;- as.numeric(diamonds$price)\nX &lt;- scale(data.matrix(diamonds[c(\"carat\", \"color\", \"cut\", \"clarity\")]))\n\n# Input layer: we have 4 covariates\ninput &lt;- layer_input(shape = 4)\n\n# Two hidden layers with contracting number of nodes\noutput &lt;- input %&gt;%\n  layer_dense(units = 30, activation = \"tanh\") %&gt;% \n  layer_dense(units = 15, activation = \"tanh\") %&gt;% \n  layer_dense(units = 1, activation = k_exp)\n\n# Create and compile model\nnn &lt;- keras_model(inputs = input, outputs = output)\nsummary(nn)\n\n# Gamma regression loss\nloss_gamma &lt;- function(y_true, y_pred) {\n  -k_log(y_true \/ y_pred) + y_true \/ y_pred\n}\n\nnn %&gt;% \n  compile(\n    optimizer = optimizer_adam(learning_rate = 0.001),\n    loss = loss_gamma\n  )\n\n# Callbacks\ncb &lt;- list(\n  callback_early_stopping(patience = 20),\n  callback_reduce_lr_on_plateau(patience = 5)\n)\n\n# Fit model\nhistory &lt;- nn %&gt;% \n  fit(\n    x = X,\n    y = y,\n    epochs = 100,\n    batch_size = 400, \n    validation_split = 0.2,\n    callbacks = cb\n  )\n\nhistory$metrics[c(\"loss\", \"val_loss\")] %&gt;% \n  data.frame() %&gt;% \n  mutate(epoch = row_number()) %&gt;% \n  filter(epoch &gt;= 3) %&gt;% \n  pivot_longer(cols = c(\"loss\", \"val_loss\")) %&gt;% \nggplot(aes(x = epoch, y = value, group = name, color = name)) +\n  geom_line(size = 1.4)<\/pre><\/div>\n\n<\/div><\/div>\n\t\t<\/div>\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"683\" height=\"569\" 
src=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/Rplot.png\" alt=\"\" class=\"wp-image-941\" srcset=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/Rplot.png 683w, https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/Rplot-300x250.png 300w\" sizes=\"auto, (max-width: 683px) 100vw, 683px\" \/><\/figure>\n\n\n\n<h3 class=\"wp-block-heading\">Interpretation via Kernel SHAP<\/h3>\n\n\n\n<p>In order to peek into the fitted model, we apply the Kernel SHAP algorithm to decompose 500 randomly selected diamond predictions. We use the same subset as the background dataset required by the algorithm. <\/p>\n\n\n\n<p>Afterwards, we will study<\/p>\n\n\n\n<ul class=\"wp-block-list\"><li>Some SHAP values and their standard errors<\/li><li>One waterfall plot<\/li><li>A beeswarm summary plot to get a rough picture of variable importance and the direction of the feature effects<\/li><li>A SHAP dependence plot for carat<\/li><\/ul>\n\n\n<div class=\"wp-block-ub-tabbed-content wp-block-ub-tabbed-content-holder wp-block-ub-tabbed-content-horizontal-holder-mobile wp-block-ub-tabbed-content-horizontal-holder-tablet\" id=\"ub-tabbed-content-8602845d-735d-4333-8b30-1e733787a6db\" style=\"\">\n\t\t\t<div class=\"wp-block-ub-tabbed-content-tab-holder horizontal-tab-width-mobile horizontal-tab-width-tablet\">\n\t\t\t\t<div role=\"tablist\" class=\"wp-block-ub-tabbed-content-tabs-title wp-block-ub-tabbed-content-tabs-title-mobile-horizontal-tab wp-block-ub-tabbed-content-tabs-title-tablet-horizontal-tab\" style=\"justify-content: flex-start; \"><div role=\"tab\" id=\"ub-tabbed-content-8602845d-735d-4333-8b30-1e733787a6db-tab-0\" aria-controls=\"ub-tabbed-content-8602845d-735d-4333-8b30-1e733787a6db-panel-0\" aria-selected=\"true\" class=\"wp-block-ub-tabbed-content-tab-title-wrap active\" style=\"--ub-tabbed-title-background-color: #6d6d6d; --ub-tabbed-active-title-color: inherit; --ub-tabbed-active-title-background-color: #6d6d6d; text-align: center; \" 
tabindex=\"-1\">\n\t\t\t\t<div class=\"wp-block-ub-tabbed-content-tab-title\">R<\/div>\n\t\t\t<\/div><\/div>\n\t\t\t<\/div>\n\t\t\t<div class=\"wp-block-ub-tabbed-content-tabs-content\" style=\"\"><div role=\"tabpanel\" class=\"wp-block-ub-tabbed-content-tab-content-wrap active\" id=\"ub-tabbed-content-8602845d-735d-4333-8b30-1e733787a6db-panel-0\" aria-labelledby=\"ub-tabbed-content-8602845d-735d-4333-8b30-1e733787a6db-tab-0\" tabindex=\"0\">\n\n<div class=\"wp-block-codemirror-blocks-code-block code-block\"><pre class=\"CodeMirror\" data-setting='{\"showPanel\":true,\"languageLabel\":\"language\",\"fullScreenButton\":true,\"copyButton\":true,\"mode\":\"r\",\"mime\":\"text\/x-rsrc\",\"theme\":\"material\",\"lineNumbers\":false,\"styleActiveLine\":false,\"lineWrapping\":false,\"readOnly\":true,\"fileName\":\"\",\"language\":\"R\",\"maxHeight\":\"400px\",\"modeName\":\"r\"}'># Interpretation on 500 randomly selected diamonds\nlibrary(kernelshap)\nlibrary(shapviz)\n\nset.seed(1)\nind &lt;- sample(nrow(X), 500)\n\ndia_small &lt;- X[ind, ]\n\n# 77 seconds\nsystem.time(\n  ks &lt;- kernelshap(\n    dia_small, \n    pred_fun = function(X) as.numeric(predict(nn, X, batch_size = nrow(X))), \n    bg_X = dia_small\n  )\n)\nks\n\n# Output\n# 'kernelshap' object representing \n# - SHAP matrix of dimension 500 x 4 \n# - feature data.frame\/matrix of dimension 500 x 4 \n# - baseline value of 3744.153\n# \n# SHAP values of first 2 observations:\n#         carat     color       cut   clarity\n# [1,] -110.738 -240.2758  5.254733 -720.3610\n# [2,] 2379.065  263.3112 56.413680  452.3044\n# \n# Corresponding standard errors:\n#         carat      color       cut  clarity\n# [1,] 2.064393 0.05113337 0.1374942 2.150754\n# [2,] 2.614281 0.84934844 0.9373701 0.827563\n\nsv &lt;- shapviz(ks, X = diamonds[ind, c(\"carat\", \"color\", \"cut\", \"clarity\")])\nsv_waterfall(sv, 1)\nsv_importance(sv, \"both\")\nsv_dependence(sv, \"carat\", \"auto\")<\/pre><\/div>\n\n<\/div><\/div>\n\t\t<\/div>\n\n\n<p>Note the small standard errors of the 
SHAP values of the first two diamonds. They are only approximate because the background data is just a sample from an unknown population. Still, they give a good impression of the stability of the results.<\/p>\n\n\n\n<p>The waterfall plot shows a diamond whose mediocre clarity and color pull down its predicted value. Note that, even though the model works with scaled numeric feature values, the plot shows the original feature values.<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"683\" height=\"569\" src=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/waterfall.png\" alt=\"\" class=\"wp-image-943\" srcset=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/waterfall.png 683w, https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/waterfall-300x250.png 300w\" sizes=\"auto, (max-width: 683px) 100vw, 683px\" \/><figcaption>SHAP waterfall plot of one diamond. Note its bad clarity.<\/figcaption><\/figure>\n\n\n\n<p>The SHAP summary plot shows that &#8220;carat&#8221; is, unsurprisingly, the most important variable and that high carat means high value. &#8220;cut&#8221; is not very important, except if it is extremely bad.<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"683\" height=\"569\" src=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/imp.png\" alt=\"\" class=\"wp-image-944\" srcset=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/imp.png 683w, https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/imp-300x250.png 300w\" sizes=\"auto, (max-width: 683px) 100vw, 683px\" \/><figcaption>SHAP summary plot with bars representing average absolute values as measure of importance.<\/figcaption><\/figure>\n\n\n\n<p>Our last plot is a SHAP dependence plot for &#8220;carat&#8221;: the effect makes sense, and we can spot some interaction with color. 
For worse colors (H-J), the effect of carat is a bit less strong than for the very white diamonds.<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"683\" height=\"569\" src=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/dep.png\" alt=\"\" class=\"wp-image-945\" srcset=\"https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/dep.png 683w, https:\/\/lorentzen.ch\/wp-content\/uploads\/2022\/08\/dep-300x250.png 300w\" sizes=\"auto, (max-width: 683px) 100vw, 683px\" \/><figcaption>Dependence plot for &#8220;carat&#8221;<\/figcaption><\/figure>\n\n\n\n<h3 class=\"wp-block-heading\">Short wrap-up<\/h3>\n\n\n\n<ul class=\"wp-block-list\"><li>Standard Kernel SHAP in R, yeahhhhh \ud83d\ude42<\/li><li>The GitHub version is relatively fast, so you can even decompose 500 observations of a deep learning model within 1-2 minutes.<\/li><\/ul>\n\n\n\n<p>The complete R script can be found <a href=\"https:\/\/github.com\/lorentzenchr\/notebooks\/blob\/master\/blogposts\/2022-08-12%20kernelshap.R\">here<\/a>.<\/p>\n\n\n\n<p><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Standard Kernel SHAP has arrived in R. 
We show how well it plays together with deep learning in Keras<\/p>\n","protected":false},"author":2,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[16,17,9],"tags":[5],"class_list":["post-939","post","type-post","status-publish","format-standard","hentry","category-machine-learning","category-programming","category-statistics","tag-r"],"featured_image_src":null,"author_info":{"display_name":"Michael Mayer","author_link":"https:\/\/lorentzen.ch\/index.php\/author\/michael\/"},"_links":{"self":[{"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/posts\/939","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/users\/2"}],"replies":[{"embeddable":true,"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/comments?post=939"}],"version-history":[{"count":5,"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/posts\/939\/revisions"}],"predecessor-version":[{"id":949,"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/posts\/939\/revisions\/949"}],"wp:attachment":[{"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/media?parent=939"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/categories?post=939"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/lorentzen.ch\/index.php\/wp-json\/wp\/v2\/tags?post=939"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}