# nixpkgs expression for pytensor (pymc-devs), built from the GitHub release tag.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  cython,
  versioneer,

  # dependencies
  cons,
  etuples,
  filelock,
  logical-unification,
  minikanren,
  numpy,
  scipy,

  # tests
  jax,
  jaxlib,
  numba,
  pytest-benchmark,
  pytest-mock,
  pytestCheckHook,
  tensorflow-probability,
  writableTmpDirAsHomeHook,

  nix-update-script,
}:

buildPythonPackage rec {
  pname = "pytensor";
  version = "2.30.3";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "pytensor";
    # Upstream release tags carry a "rel-" prefix.
    tag = "rel-${version}";
    # The fetched tree has no git metadata, so pin versioneer's `git_refnames`
    # in pytensor/_version.py to the release tag. The version can then be
    # recovered from that string at build/run time.
    postFetch = ''
      sed -i 's/git_refnames = "[^"]*"/git_refnames = " (tag: ${src.tag})"/' $out/pytensor/_version.py
    '';
    hash = "sha256-Iyiuvt86pfz8MmpwgDecKJFVOw+fKpEaA9m1MBA9Yxs=";
  };

  build-system = [
    cython
    versioneer
  ];

  dependencies = [
    cons
    etuples
    filelock
    logical-unification
    minikanren
    numpy
    scipy
  ];

  nativeCheckInputs = [
    jax
    jaxlib
    numba
    pytest-benchmark
    pytest-mock
    pytestCheckHook
    tensorflow-probability
    writableTmpDirAsHomeHook
  ];

  pythonImportsCheck = [ "pytensor" ];

  # Delete the in-tree sources before testing so the test suite imports the
  # installed package rather than ./pytensor from the build directory.
  preCheck = ''
    rm -rf pytensor
  '';

  # On Darwin these tests fail with
  # pytensor.link.c.exceptions.CompileError: Compilation failed (return status=1);
  # no tests are disabled on other platforms.
  disabledTests =
    if stdenv.hostPlatform.isDarwin then
      [
        "OpFromGraph"
        "add"
        "cls_ofg1"
        "direct"
        "multiply"
        "test_AddDS"
        "test_AddSD"
        "test_AddSS"
        "test_MulDS"
        "test_MulSD"
        "test_MulSS"
        "test_NoOutputFromInplace"
        "test_OpFromGraph"
        "test_adv_sub1_sparse_grad"
        "test_alloc"
        "test_binary"
        "test_borrow_input"
        "test_borrow_output"
        "test_cache_race_condition"
        "test_check_for_aliased_inputs"
        "test_clinker_literal_cache"
        "test_csm_grad"
        "test_csm_unsorted"
        "test_csr_dense_grad"
        "test_debugprint"
        "test_ellipsis_einsum"
        "test_empty_elemwise"
        "test_flatten"
        "test_fprop"
        "test_get_item_list_grad"
        "test_grad"
        "test_infer_shape"
        "test_jax_pad"
        "test_kron"
        "test_masked_input"
        "test_max"
        "test_modes"
        "test_mul_s_v_grad"
        "test_multiple_outputs"
        "test_not_inplace"
        "test_numba_Cholesky_grad"
        "test_numba_pad"
        "test_optimizations_preserved"
        "test_overided_function"
        "test_potential_output_aliasing_induced_by_updates"
        "test_profiling"
        "test_rebuild_strict"
        "test_runtime_broadcast_c"
        "test_scan_err1"
        "test_scan_err2"
        "test_shared"
        "test_solve_triangular_grad"
        "test_structured_add_s_v_grad"
        "test_structureddot_csc_grad"
        "test_structureddot_csr_grad"
        "test_sum"
        "test_swap_SharedVariable_with_given"
        "test_test_value_op"
        "test_unary"
        "test_unbroadcast"
        "test_update_equiv"
        "test_update_same"
      ]
    else
      [ ];

  # Skip the heaviest parts of the suite to keep check time reasonable.
  disabledTestPaths = [
    "tests/scan/"
    "tests/tensor/"
    "tests/sparse/sandbox/"
  ];

  passthru = {
    # Tags look like "rel-X.Y.Z"; only the captured part is the version.
    updateScript = nix-update-script {
      extraArgs = [
        "--version-regex"
        "rel-(.+)"
      ];
    };
  };

  meta = {
    description = "Python library to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays";
    homepage = "https://github.com/pymc-devs/pytensor";
    changelog = "https://github.com/pymc-devs/pytensor/releases/tag/${src.tag}";
    license = lib.licenses.bsd3;
    mainProgram = "pytensor-cache";
    maintainers = [
      lib.maintainers.bcdarwin
      lib.maintainers.ferrine
    ];
  };
}