{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pytestCheckHook,
  pythonOlder,
  writeText,
  catboost,
  cloudpickle,
  ipython,
  lightgbm,
  lime,
  matplotlib,
  nose,
  numba,
  numpy,
  oldest-supported-numpy,
  opencv4,
  packaging,
  pandas,
  pyspark,
  pytest-mpl,
  scikit-learn,
  scipy,
  sentencepiece,
  setuptools,
  setuptools-scm,
  slicer,
  tqdm,
  transformers,
  xgboost,
}:

# Python package `shap`: a model-agnostic approach to explaining the output of
# machine learning models, built from the upstream GitHub release tag.
buildPythonPackage rec {
  pname = "shap";
  version = "0.45.1";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "slundberg";
    repo = "shap";
    rev = "refs/tags/v${version}";
    hash = "sha256-REMAubT9WRe0exfhO4UCLt3FFQHq4HApHnI6i2F/V1o=";
  };

  # Build backend requirements (setuptools + setuptools-scm for versioning,
  # oldest-supported-numpy for compiling the C extension).
  nativeBuildInputs = [
    oldest-supported-numpy
    setuptools
    setuptools-scm
  ];

  # Runtime dependencies. `packaging` is a declared runtime requirement
  # upstream (used for version comparisons); list it explicitly rather than
  # relying on another input to bring it in transitively.
  propagatedBuildInputs = [
    cloudpickle
    numba
    numpy
    packaging
    pandas
    scikit-learn
    scipy
    slicer
    tqdm
  ];

  passthru.optional-dependencies = {
    plots = [
      matplotlib
      ipython
    ];
    others = [ lime ];
  };

  preCheck =
    let
      # This pytest hook mocks and catches attempts at accessing the network.
      # Tests that try to access the network raise, get caught, are marked as
      # skipped and tagged as xfailed.
      # `urllib.request` is imported explicitly: `import urllib` alone does not
      # guarantee the submodule is loaded before its attributes are patched.
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib.request
        import requests, transformers

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
            raise NetworkAccessDeniedError

        requests.head = deny_network_access
        requests.get = deny_network_access
        urllib.request.urlopen = deny_network_access
        urllib.request.Request = deny_network_access
        transformers.AutoTokenizer.from_pretrained = deny_network_access

        def pytest_runtest_makereport(item, call):
            tr = orig_pytest_runtest_makereport(item, call)
            if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
                tr.outcome = 'skipped'
                tr.wasxfail = "reason: Requires network access."
            return tr
      '';
    in
    ''
      export HOME=$TMPDIR
      # when importing the local copy the extension is not found
      rm -r shap

      # Add pytest hook skipping tests that access network.
      # These tests are marked as "Expected fail" (xfail)
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  nativeCheckInputs = [
    ipython
    matplotlib
    nose
    pytest-mpl
    pytestCheckHook
    # optional dependencies, which only serve to enable more tests:
    catboost
    lightgbm
    opencv4
    pyspark
    sentencepiece
    #torch # we already skip all its tests due to slowness, adding it does nothing
    transformers
    xgboost
  ];

  # Test startup hangs with 0.43.0 and Hydra ends with a timeout.
  # NOTE(review): while this is false, the check-phase settings in this file
  # (preCheck, nativeCheckInputs, disabledTests*) are kept for when tests are
  # re-enabled.
  doCheck = false;

  disabledTestPaths = [
    # The resulting plots look sane, but do not match the baseline pixel-perfectly.
    # Likely due to a matplotlib version mismatch, different backend, or due to missing fonts.
    "tests/plots/test_summary.py" # FIXME: enable
  ];

  disabledTests = [
    # Same reason as for test_summary.py above
    "test_random_force_plot_negative_sign"
    "test_random_force_plot_positive_sign"
    "test_random_summary_layered_violin_with_data2"
    "test_random_summary_violin_with_data2"
    "test_simple_bar_with_cohorts_dict"
  ];

  pythonImportsCheck = [ "shap" ];

  meta = with lib; {
    description = "Unified approach to explain the output of any machine learning model";
    homepage = "https://github.com/slundberg/shap";
    changelog = "https://github.com/slundberg/shap/releases/tag/v${version}";
    license = licenses.mit;
    maintainers = with maintainers; [
      evax
      natsukium
    ];
  };
}