# nixpkgs derivation for SHAP ("SHapley Additive exPlanations"), a library for
# explaining the output of machine learning models.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pytestCheckHook,
  pythonOlder,
  writeText,
  catboost,
  cloudpickle,
  ipython,
  lightgbm,
  lime,
  matplotlib,
  numba,
  numpy,
  opencv4,
  pandas,
  pyspark,
  pytest-mpl,
  scikit-learn,
  scipy,
  sentencepiece,
  setuptools,
  setuptools-scm,
  slicer,
  tqdm,
  transformers,
  xgboost,
}:

buildPythonPackage rec {
  pname = "shap";
  version = "0.46.0";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    # The repository was transferred from slundberg/shap to the shap
    # organisation; the tarball contents (and therefore the hash) are the same.
    owner = "shap";
    repo = "shap";
    tag = "v${version}";
    hash = "sha256-qW36/Xw5oaYKmaMfE5euzkED9CKkjl2O55aO0OpCkfI=";
  };

  # Upstream's pyproject.toml pins "numpy>=2.0"; relax it to an unversioned
  # requirement so the numpy provided by nixpkgs is accepted.
  # NOTE(review): presumably needed because the package set may ship numpy 1.x — confirm.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "numpy>=2.0" "numpy"
  '';

  build-system = [
    # numpy is needed at build time for the compiled extension module.
    numpy
    setuptools
    setuptools-scm
  ];

  dependencies = [
    cloudpickle
    numba
    numpy
    pandas
    scikit-learn
    scipy
    slicer
    tqdm
  ];

  optional-dependencies = {
    plots = [
      matplotlib
      ipython
    ];
    others = [ lime ];
  };

  preCheck =
    let
      # This pytest hook mocks out, and catches, attempts to access the network.
      # Tests that try to access the network raise NetworkAccessDeniedError; the
      # hook catches it and reports the test as skipped and tagged xfail.
      # `urllib.request` is imported explicitly: importing the `urllib` package
      # alone does not load its `request` submodule.
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib.request

        import requests
        import transformers

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
          raise NetworkAccessDeniedError

        requests.head = deny_network_access
        requests.get = deny_network_access
        urllib.request.urlopen = deny_network_access
        urllib.request.Request = deny_network_access
        transformers.AutoTokenizer.from_pretrained = deny_network_access

        def pytest_runtest_makereport(item, call):
          tr = orig_pytest_runtest_makereport(item, call)
          if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
            tr.outcome = 'skipped'
            tr.wasxfail = "reason: Requires network access."
          return tr
      '';
    in
    ''
      export HOME=$TMPDIR
      # When the local source copy is imported, the compiled extension is not
      # found; remove it so the installed package is tested instead.
      rm -r shap

      # Add a pytest hook that skips tests which access the network.
      # These tests are reported as "expected fail" (xfail).
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  nativeCheckInputs = [
    ipython
    matplotlib
    pytest-mpl
    pytestCheckHook
    # Optional dependencies, which only serve to enable more tests:
    catboost
    lightgbm
    opencv4
    pyspark
    sentencepiece
    #torch # we already skip all its tests due to slowness, adding it does nothing
    transformers
    xgboost
  ];

  # Test startup hangs with 0.43.0 and the Hydra build ends with a timeout.
  # The check-phase settings around this line are kept so tests can be
  # re-enabled by flipping this flag.
  doCheck = false;

  disabledTestPaths = [
    # The resulting plots look sane but do not match the baseline pixel-perfectly,
    # likely due to a matplotlib version mismatch, a different backend, or missing fonts.
    "tests/plots/test_summary.py" # FIXME: enable
  ];

  disabledTests = [
    # Same reason as tests/plots/test_summary.py above.
    "test_random_force_plot_negative_sign"
    "test_random_force_plot_positive_sign"
    "test_random_summary_layered_violin_with_data2"
    "test_random_summary_violin_with_data2"
    "test_simple_bar_with_cohorts_dict"
  ];

  pythonImportsCheck = [ "shap" ];

  meta = {
    description = "Unified approach to explain the output of any machine learning model";
    homepage = "https://github.com/shap/shap";
    changelog = "https://github.com/shap/shap/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      evax
      natsukium
    ];
  };
}