1{ lib
2, buildPythonPackage
3, fetchFromGitHub
4, pytestCheckHook
5, pythonOlder
6, writeText
7, catboost
8, cloudpickle
9, ipython
10, lightgbm
11, lime
12, matplotlib
13, nose
14, numba
15, numpy
16, oldest-supported-numpy
17, opencv4
18, pandas
19, pyspark
20, pytest-mpl
21, scikit-learn
22, scipy
23, sentencepiece
24, setuptools
25, slicer
26, tqdm
27, transformers
28, xgboost
29, wheel
30}:
31
buildPythonPackage rec {
  pname = "shap";
  version = "0.43.0";
  # Use the pyproject.toml (PEP 517) build path instead of the legacy setup.py one.
  pyproject = true;

  # Refuse to build for interpreters older than Python 3.8.
  disabled = pythonOlder "3.8";

  # Tagged release tarball from GitHub; the tag is derived from `version`.
  src = fetchFromGitHub {
    owner = "slundberg";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-ylkpXhaLXsQiu6YMC3pUtlicptQmtjITzW+ydinB4ls=";
  };
45
46 nativeBuildInputs = [
47 oldest-supported-numpy
48 setuptools
49 wheel
50 ];
51
52 propagatedBuildInputs = [
53 cloudpickle
54 numba
55 numpy
56 pandas
57 scikit-learn
58 scipy
59 slicer
60 tqdm
61 ];
62
63 passthru.optional-dependencies = {
64 plots = [ matplotlib ipython ];
65 others = [ lime ];
66 };
67
68 preCheck = let
69 # This pytest hook mocks and catches attempts at accessing the network
70 # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
71 conftestSkipNetworkErrors = writeText "conftest.py" ''
72 from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
73 import urllib, requests, transformers
74
75 class NetworkAccessDeniedError(RuntimeError): pass
76 def deny_network_access(*a, **kw):
77 raise NetworkAccessDeniedError
78
79 requests.head = deny_network_access
80 requests.get = deny_network_access
81 urllib.request.urlopen = deny_network_access
82 urllib.request.Request = deny_network_access
83 transformers.AutoTokenizer.from_pretrained = deny_network_access
84
85 def pytest_runtest_makereport(item, call):
86 tr = orig_pytest_runtest_makereport(item, call)
87 if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
88 tr.outcome = 'skipped'
89 tr.wasxfail = "reason: Requires network access."
90 return tr
91 '';
92 in ''
93 export HOME=$TMPDIR
94 # when importing the local copy the extension is not found
95 rm -r shap
96
97 # Add pytest hook skipping tests that access network.
98 # These tests are marked as "Expected fail" (xfail)
99 cat ${conftestSkipNetworkErrors} >> tests/conftest.py
100 '';
101
102 nativeCheckInputs = [
103 ipython
104 matplotlib
105 nose
106 pytest-mpl
107 pytestCheckHook
108 # optional dependencies, which only serve to enable more tests:
109 catboost
110 lightgbm
111 opencv4
112 pyspark
113 sentencepiece
114 #torch # we already skip all its tests due to slowness, adding it does nothing
115 transformers
116 xgboost
117 ];
118
119 disabledTestPaths = [
120 # The resulting plots look sane, but does not match pixel-perfectly with the baseline.
121 # Likely due to a matplotlib version mismatch, different backend, or due to missing fonts.
122 "tests/plots/test_summary.py" # FIXME: enable
123 ];
124
125 disabledTests = [
126 # The same reason as above test_summary.py
127 "test_random_force_plot_negative_sign"
128 "test_random_force_plot_positive_sign"
129 "test_random_summary_layered_violin_with_data2"
130 "test_random_summary_violin_with_data2"
131 "test_simple_bar_with_cohorts_dict"
132 ];
133
134 pythonImportsCheck = [
135 "shap"
136 "shap.explainers"
137 "shap.explainers.other"
138 "shap.plots"
139 "shap.plots.colors"
140 "shap.benchmark"
141 "shap.maskers"
142 "shap.utils"
143 "shap.actions"
144 "shap.models"
145 ];
146
147 meta = with lib; {
148 description = "A unified approach to explain the output of any machine learning model";
149 homepage = "https://github.com/slundberg/shap";
150 changelog = "https://github.com/slundberg/shap/releases/tag/v${version}";
151 license = licenses.mit;
152 maintainers = with maintainers; [ evax natsukium ];
153 };
154}