# Nix expression for pytensor: a Python library to define, optimize, and
# efficiently evaluate mathematical expressions over multi-dimensional arrays.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  nix-update-script,

  # PEP 517 build-system
  cython,
  versioneer,

  # runtime Python dependencies
  cons,
  etuples,
  filelock,
  logical-unification,
  minikanren,
  numpy,
  scipy,

  # test-only inputs
  jax,
  jaxlib,
  numba,
  pytestCheckHook,
  pytest-mock,
  tensorflow-probability,
}:

let
  # Benchmark tests: they need the pytest-benchmark fixture, which is not
  # among the check inputs below.
  benchmarkTests = [
    "test_elemwise_speed"
    "test_fused_elemwise_benchmark"
    "test_logsumexp_benchmark"
    "test_minimal_random_function_call_benchmark"
    "test_scan_multiple_output"
    "test_vector_taps_benchmark"
  ];

  # Failures that have already been reported upstream.
  knownUpstreamFailures = [
    # https://github.com/pymc-devs/pytensor/issues/980
    "test_choose_signature"
  ];

  # On Darwin these fail with:
  #   pytensor.link.c.exceptions.CompileError: Compilation failed (return status=1)
  # NOTE(review): pytestCheckHook turns disabledTests into a `pytest -k`
  # keyword expression, i.e. substring matches, so short entries such as
  # "add", "multiply" or "direct" deselect every test whose name contains them.
  darwinCompileFailures = [
    "OpFromGraph"
    "add"
    "cls_ofg1"
    "direct"
    "multiply"
    "test_AddDS"
    "test_AddSD"
    "test_AddSS"
    "test_MulDS"
    "test_MulSD"
    "test_MulSS"
    "test_NoOutputFromInplace"
    "test_OpFromGraph"
    "test_adv_sub1_sparse_grad"
    "test_binary"
    "test_borrow_input"
    "test_borrow_output"
    "test_cache_race_condition"
    "test_check_for_aliased_inputs"
    "test_clinker_literal_cache"
    "test_csm_grad"
    "test_csm_unsorted"
    "test_csr_dense_grad"
    "test_debugprint"
    "test_ellipsis_einsum"
    "test_empty_elemwise"
    "test_flatten"
    "test_fprop"
    "test_get_item_list_grad"
    "test_grad"
    "test_infer_shape"
    "test_jax_pad"
    "test_kron"
    "test_masked_input"
    "test_max"
    "test_modes"
    "test_mul_s_v_grad"
    "test_multiple_outputs"
    "test_not_inplace"
    "test_numba_pad"
    "test_optimizations_preserved"
    "test_overided_function"
    "test_potential_output_aliasing_induced_by_updates"
    "test_profiling"
    "test_rebuild_strict"
    "test_runtime_broadcast_c"
    "test_scan_err1"
    "test_scan_err2"
    "test_shared"
    "test_structured_add_s_v_grad"
    "test_structureddot_csc_grad"
    "test_structureddot_csr_grad"
    "test_sum"
    "test_swap_SharedVariable_with_given"
    "test_test_value_op"
    "test_unary"
    "test_unbroadcast"
    "test_update_equiv"
    "test_update_same"
  ];
in
buildPythonPackage rec {
  pname = "pytensor";
  version = "2.26.3";
  pyproject = true;

  # Upstream tags releases as `rel-<version>`.
  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "pytensor";
    rev = "refs/tags/rel-${version}";
    hash = "sha256-RhicZSVkaDtIngIOvzyEQ+VMZwdV45wDk7e7bThTIh8=";
  };

  # Drop the upstream version constraint on scipy so the nixpkgs scipy is accepted.
  pythonRelaxDeps = [ "scipy" ];

  build-system = [
    cython
    versioneer
  ];

  dependencies = [
    cons
    etuples
    filelock
    logical-unification
    minikanren
    numpy
    scipy
  ];

  nativeCheckInputs = [
    jax
    jaxlib
    numba
    pytestCheckHook
    pytest-mock
    tensorflow-probability
  ];

  # Point HOME at a fresh writable directory; the sandbox HOME is not
  # writable. stdenv runs all phases in one shell, so the export stays in
  # effect for the later import and test phases as well.
  # NOTE(review): presumably required for pytensor's on-disk compile cache — confirm.
  preBuild = ''
    export HOME=$(mktemp -d)
  '';

  # Delete the in-tree sources so the tests import the installed package
  # instead of the copy in the working directory.
  preCheck = ''
    rm -rf pytensor
  '';

  pythonImportsCheck = [ "pytensor" ];

  disabledTests =
    benchmarkTests
    ++ knownUpstreamFailures
    ++ lib.optionals stdenv.hostPlatform.isDarwin darwinCompileFailures;

  # Skip the most compute-intensive parts of the test suite.
  disabledTestPaths = [
    "tests/scan/"
    "tests/tensor/"
    "tests/sparse/sandbox/"
  ];

  passthru = {
    # Only tags of the form `rel-<version>` are treated as releases.
    updateScript = nix-update-script {
      extraArgs = [
        "--version-regex"
        "rel-(.+)"
      ];
    };
  };

  meta = {
    description = "Python library to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays";
    mainProgram = "pytensor-cache";
    homepage = "https://github.com/pymc-devs/pytensor";
    changelog = "https://github.com/pymc-devs/pytensor/releases/tag/rel-${version}";
    license = lib.licenses.bsd3;
    maintainers = [
      lib.maintainers.bcdarwin
      lib.maintainers.ferrine
    ];
  };
}