1{ lib
2, buildPythonPackage
3, fetchFromGitHub
4, writeText
5, isPy27
6, pytestCheckHook
7, pytest-mpl
8, numpy
9, scipy
10, scikit-learn
11, pandas
12, transformers
13, opencv4
14, lightgbm
15, catboost
16, pyspark
17, sentencepiece
18, tqdm
19, slicer
20, numba
21, matplotlib
22, nose
23, lime
24, cloudpickle
25, ipython
26}:
27
28buildPythonPackage rec {
29 pname = "shap";
30 version = "0.41.0";
31 disabled = isPy27;
32
33 src = fetchFromGitHub {
34 owner = "slundberg";
35 repo = pname;
36 rev = "refs/tags/v${version}";
37 sha256 = "sha256-rYVWQ3VRvIObSQPwDRsxhTOGOKNkYkLtiHzVwoB3iJ0=";
38 };
39
40 propagatedBuildInputs = [
41 numpy
42 scipy
43 scikit-learn
44 pandas
45 tqdm
46 slicer
47 numba
48 cloudpickle
49 ];
50
51 passthru.optional-dependencies = {
52 plots = [ matplotlib ipython ];
53 others = [ lime ];
54 };
55
56 preCheck = let
57 # This pytest hook mocks and catches attempts at accessing the network
58 # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
59 conftestSkipNetworkErrors = writeText "conftest.py" ''
60 from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
61 import urllib, requests
62
63 class NetworkAccessDeniedError(RuntimeError): pass
64 def deny_network_access(*a, **kw):
65 raise NetworkAccessDeniedError
66
67 requests.head = deny_network_access
68 requests.get = deny_network_access
69 urllib.request.urlopen = deny_network_access
70 urllib.request.Request = deny_network_access
71
72 def pytest_runtest_makereport(item, call):
73 tr = orig_pytest_runtest_makereport(item, call)
74 if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
75 tr.outcome = 'skipped'
76 tr.wasxfail = "reason: Requires network access."
77 return tr
78 '';
79 in ''
80 export HOME=$TMPDIR
81 # when importing the local copy the extension is not found
82 rm -r shap
83
84 # coverage testing is a waste considering how much we have to skip
85 substituteInPlace pytest.ini \
86 --replace "--cov=shap --cov-report=term-missing" ""
87
88 # Add pytest hook skipping tests that access network.
89 # These tests are marked as "Expected fail" (xfail)
90 cat ${conftestSkipNetworkErrors} >> tests/conftest.py
91 '';
92 checkInputs = [
93 pytestCheckHook
94 pytest-mpl
95 matplotlib
96 nose
97 ipython
98 # optional dependencies, which only serve to enable more tests:
99 opencv4
100 #pytorch # we already skip all its tests due to slowness, adding it does nothing
101 transformers
102 #xgboost # numerically unstable? xgboost tests randomly fails pending on nixpkgs revision
103 lightgbm
104 catboost
105 pyspark
106 sentencepiece
107 ];
108 disabledTestPaths = [
109 # takes forever without GPU acceleration
110 "tests/explainers/test_deep.py"
111 "tests/explainers/test_gradient.py"
112 # requires GPU. We skip here instead of having pytest repeatedly check for GPU
113 "tests/explainers/test_gpu_tree.py"
114 # The resulting plots look sane, but does not match pixel-perfectly with the baseline.
115 # Likely due to a matplotlib version mismatch, different backend, or due to missing fonts.
116 "tests/plots/test_summary.py" # FIXME: enable
117 # 100% of the tests in these paths require network
118 "tests/explainers/test_explainer.py"
119 "tests/explainers/test_exact.py"
120 "tests/explainers/test_partition.py"
121 "tests/maskers/test_fixed_composite.py"
122 "tests/maskers/test_text.py"
123 "tests/models/test_teacher_forcing_logits.py"
124 "tests/models/test_text_generation.py"
125 ];
126 disabledTests = [
127 # unstable. A xgboost-enabled test. possibly related: https://github.com/slundberg/shap/issues/2480
128 "test_provided_background_tree_path_dependent"
129 ];
130
131 #pytestFlagsArray = ["-x" "-W" "ignore"]; # uncomment this to debug
132
133 pythonImportCheck = [
134 "shap"
135 "shap.explainers"
136 "shap.explainers.other"
137 "shap.plots"
138 "shap.plots.colors"
139 "shap.benchmark"
140 "shap.maskers"
141 "shap.utils"
142 "shap.actions"
143 "shap.models"
144 ];
145
146 meta = with lib; {
147 description = "A unified approach to explain the output of any machine learning model";
148 homepage = "https://github.com/slundberg/shap";
149 license = licenses.mit;
150 maintainers = with maintainers; [ evax ];
151 platforms = platforms.unix;
152 };
153}