1{
2 lib,
3 stdenv,
4 fetchFromGitHub,
5 pythonOlder,
6 buildPythonPackage,
7 pytestCheckHook,
8 setuptools,
9 matplotlib,
10 numpy,
11 packaging,
12 torch,
13 tqdm,
14 flask,
15 flask-compress,
16}:
17
18buildPythonPackage rec {
19 pname = "captum";
20 version = "0.7.0";
21 pyproject = true;
22
23 build-system = [ setuptools ];
24
25 src = fetchFromGitHub {
26 owner = "pytorch";
27 repo = "captum";
28 rev = "refs/tags/v${version}";
29 hash = "sha256-1VOvPqxn6CNnmv7M8fl7JrqRfJQUH2tnXRCUqKnl7i0=";
30 };
31
32 dependencies = [
33 matplotlib
34 numpy
35 packaging
36 torch
37 tqdm
38 ];
39
40 pythonImportsCheck = [ "captum" ];
41
42 nativeCheckInputs = [
43 pytestCheckHook
44 flask
45 flask-compress
46 ];
47
48 disabledTestPaths =
49 [
50 # These tests requires `parametrized` module (https://pypi.org/project/parametrized/) which seem to be unavailable on Nix.
51 "tests/attr/test_dataloader_attr.py"
52 "tests/attr/test_interpretable_input.py"
53 "tests/attr/test_llm_attr.py"
54 "tests/influence/_core/test_dataloader.py"
55 "tests/influence/_core/test_tracin_aggregate_influence.py"
56 "tests/influence/_core/test_tracin_intermediate_quantities.py"
57 "tests/influence/_core/test_tracin_k_most_influential.py"
58 "tests/influence/_core/test_tracin_regression.py"
59 "tests/influence/_core/test_tracin_self_influence.py"
60 "tests/influence/_core/test_tracin_show_progress.py"
61 "tests/influence/_core/test_tracin_validation.py"
62 "tests/influence/_core/test_tracin_xor.py"
63 "tests/insights/test_contribution.py"
64 "tests/module/test_binary_concrete_stochastic_gates.py"
65 "tests/module/test_gaussian_stochastic_gates.py"
66 ]
67 ++ lib.optionals stdenv.hostPlatform.isDarwin [
68 # These tests are failing on macOS:
69 # > E AttributeError: module 'torch.distributed' has no attribute 'init_process_group'
70 "tests/attr/test_data_parallel.py"
71 ]
72 ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [
73 # Issue reported upstream at https://github.com/pytorch/captum/issues/1447
74 "tests/concept/test_tcav.py"
75 ];
76
77 disabledTests = [
78 # Failing tests
79 "test_softmax_classification_batch_multi_target"
80 "test_softmax_classification_batch_zero_baseline"
81 ];
82
83 meta = {
84 description = "Model interpretability and understanding for PyTorch";
85 homepage = "https://github.com/pytorch/captum";
86 license = lib.licenses.bsd3;
87 maintainers = with lib.maintainers; [ drupol ];
88 };
89}