# nixpkgs derivation for the `pytensor` Python package, built from the
# upstream GitHub release tag `rel-<version>`.
{ lib
, stdenv
, buildPythonPackage
, fetchFromGitHub
, pythonOlder
# build-time tooling
, cython
, versioneer
# runtime dependencies
, cons
, etuples
, filelock
, logical-unification
, minikanren
, numpy
, scipy
, typing-extensions
# test-only dependencies
, jax
, jaxlib
, numba
, numba-scipy
, pytest-mock
, pytestCheckHook
# tensorflow-probability is an optional test dependency, but Tensorflow was
# broken at the time of writing (2023-10-04), so it is not provided.
# , tensorflow-probability
}:

buildPythonPackage rec {
  pname = "pytensor";
  version = "2.17.3";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "pytensor";
    # Upstream tags its releases as `rel-<version>`.
    rev = "refs/tags/rel-${version}";
    hash = "sha256-FufPCFzSjG8BrHes7t3XsdovX9gqUBG0gMDGKvkRkSA=";
  };

  # pyproject.toml pins versioneer to exactly 0.28; drop the pin so the
  # versioneer provided by the package set is accepted.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace "versioneer[toml]==0.28" "versioneer[toml]"
  '';

  nativeBuildInputs = [
    cython
    versioneer
  ];

  propagatedBuildInputs = [
    cons
    etuples
    filelock
    logical-unification
    minikanren
    numpy
    scipy
    typing-extensions
  ];

  nativeCheckInputs = [
    jax
    jaxlib
    numba
    numba-scipy
    pytest-mock
    pytestCheckHook
    # tensorflow-probability is intentionally omitted: Tensorflow was broken
    # at the time of writing (2023-10-04).
    # tensorflow-probability
  ];

  # Point HOME at a writable temporary directory for the rest of the build.
  # NOTE(review): presumably pytensor writes a cache/compile directory under
  # $HOME during import and tests — confirm.
  preBuild = ''
    export HOME=$(mktemp -d)
  '';

  pythonImportsCheck = [
    "pytensor"
  ];

  disabledTests = [
    # Benchmarks (they require pytest-benchmark):
    "test_elemwise_speed"
    "test_fused_elemwise_benchmark"
    "test_logsumexp_benchmark"
    "test_scan_multiple_output"
    "test_vector_taps_benchmark"
    # Needs tensorflow-probability, which is not provided (see above).
    "test_tfp_ops"
  ];

  disabledTestPaths = [
    # Skip the most compute-intensive test suites.
    "tests/scan/"
    "tests/tensor/"
    "tests/sparse/sandbox/"
  ];

  meta = {
    description = "Python library to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays";
    homepage = "https://github.com/pymc-devs/pytensor";
    # Link to the notes for this exact release rather than the generic page.
    changelog = "https://github.com/pymc-devs/pytensor/releases/tag/rel-${version}";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ bcdarwin ];
    # Marked broken on aarch64-linux.
    broken = stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64;
  };
}