# nixpkgs mirror (for testing)
# Source: github.com/NixOS/nixpkgs
# Language: Nix
# Nix expression for the `pytensor` Python package
# (https://github.com/pymc-devs/pytensor), built from the upstream release tag.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  cython,
  versioneer,

  # dependencies
  cons,
  etuples,
  filelock,
  logical-unification,
  minikanren,
  numba,
  numpy,
  scipy,

  # tests
  jax,
  jaxlib,
  pytest-benchmark,
  pytest-mock,
  pytestCheckHook,
  tensorflow-probability,
  writableTmpDirAsHomeHook,

  nix-update-script,
}:

buildPythonPackage (finalAttrs: {
  pname = "pytensor";
  version = "2.37.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "pytensor";
    # Upstream release tags are prefixed with "rel-".
    tag = "rel-${finalAttrs.version}";
    # Rewrite versioneer's `git_refnames` in pytensor/_version.py so that it names
    # the fetched release tag. Presumably this lets versioneer report the release
    # version when building from the fetched sources without git metadata — confirm.
    postFetch = ''
      sed -i 's/git_refnames = "[^"]*"/git_refnames = " (tag: ${finalAttrs.src.tag})"/' $out/pytensor/_version.py
    '';
    hash = "sha256-N6TYK/qMux/a0ktQyCGYHZg3g8S6To8g1FXnILk9HIw=";
  };

  build-system = [
    setuptools
    cython
    versioneer
  ];

  dependencies = [
    cons
    etuples
    filelock
    logical-unification
    minikanren
    numba
    numpy
    scipy
    # Deliberately listed as a runtime dependency too, not only in build-system.
    setuptools
  ];

  # numba is not repeated here: it is already listed in `dependencies`,
  # which are also available when the checks run.
  nativeCheckInputs = [
    jax
    jaxlib
    pytest-benchmark
    pytest-mock
    pytestCheckHook
    tensorflow-probability
    writableTmpDirAsHomeHook
  ];

  pytestFlags = [ "--benchmark-disable" ];

  pythonImportsCheck = [ "pytensor" ];

  # Ensure that the installed package is used instead of the source files from the current workdir
  preCheck = ''
    rm -rf pytensor
  '';

  disabledTests = [
    # TypeError: jax_funcified_fgraph() takes 2 positional arguments but 3 were given
    "test_jax_Reshape_shape_graph_input"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Numerical assertion error
    # tests.unittest_tools.WrongValue: WrongValue
    "test_op_sd"
    "test_op_ss"

    # pytensor.link.c.exceptions.CompileError: Compilation failed (return status=1)
    # NOTE(review): disabledTests entries end up in pytest's `-k` expression, which
    # matches substrings of test names. Short entries such as "add", "direct" and
    # "multiply" therefore likely deselect many more tests than the failing ones.
    # Consider narrowing them.
    "OpFromGraph"
    "add"
    "cls_ofg1"
    "direct"
    "multiply"
    "test_AddDS"
    "test_AddSD"
    "test_AddSS"
    "test_MulDS"
    "test_MulSD"
    "test_MulSS"
    "test_NoOutputFromInplace"
    "test_OpFromGraph"
    "test_adv_sub1_sparse_grad"
    "test_alloc"
    "test_binary"
    "test_borrow_input"
    "test_borrow_output"
    "test_cache_race_condition"
    "test_check_for_aliased_inputs"
    "test_clinker_literal_cache"
    "test_csm_grad"
    "test_csm_unsorted"
    "test_csr_dense_grad"
    "test_debugprint"
    "test_ellipsis_einsum"
    "test_empty_elemwise"
    "test_flatten"
    "test_fprop"
    "test_get_item_list_grad"
    "test_grad"
    "test_infer_shape"
    "test_jax_pad"
    "test_kron"
    "test_masked_input"
    "test_max"
    "test_modes"
    "test_mul_s_v_grad"
    "test_multiple_outputs"
    "test_nnet"
    "test_not_inplace"
    "test_numba_Cholesky_grad"
    "test_numba_pad"
    "test_optimizations_preserved"
    "test_overided_function"
    "test_potential_output_aliasing_induced_by_updates"
    "test_profiling"
    "test_rebuild_strict"
    "test_runtime_broadcast_c"
    "test_scan_err1"
    "test_scan_err2"
    "test_shared"
    "test_size_implied_by_broadcasted_parameters"
    "test_solve_triangular_grad"
    "test_structured_add_s_v_grad"
    "test_structureddot_csc_grad"
    "test_structureddot_csr_grad"
    "test_sum"
    "test_swap_SharedVariable_with_given"
    "test_test_value_op"
    "test_unary"
    "test_unbroadcast"
    "test_update_equiv"
    "test_update_same"
  ];

  disabledTestPaths = [
    # Don't run the most compute-intensive tests
    "tests/scan/"
    "tests/tensor/"
  ];

  # Release tags look like "rel-<version>"; only the part after the prefix is the version.
  passthru.updateScript = nix-update-script {
    extraArgs = [
      "--version-regex"
      "rel-(.+)"
    ];
  };

  meta = {
    description = "Python library to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays";
    mainProgram = "pytensor-cache";
    homepage = "https://github.com/pymc-devs/pytensor";
    changelog = "https://github.com/pymc-devs/pytensor/releases/tag/${finalAttrs.src.tag}";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      bcdarwin
    ];
  };
})