{ lib
, buildPythonPackage
, fetchFromGitHub
, cython
, versioneer
, cons
, etuples
, filelock
, logical-unification
, minikanren
, numpy
, scipy
, typing-extensions
, jax
, jaxlib
, numba
, numba-scipy
, pytest-mock
, pytestCheckHook
, pythonOlder
, tensorflow-probability
, stdenv
}:

# nixpkgs derivation for pytensor (pymc-devs), built from the upstream
# GitHub release tag via the pyproject build path.
buildPythonPackage rec {
  pname = "pytensor";
  version = "2.18.1";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "pytensor";
    # Upstream release tags are named `rel-<version>`.
    rev = "refs/tags/rel-${version}";
    hash = "sha256-8bt6ps5bwT+Atr6JgQMxe234bL/ZriYlURUdX0sC1kk=";
  };

  # pyproject.toml pins `versioneer[toml]==0.28`; relax the pin so the
  # versioneer provided by nixpkgs is accepted regardless of its version.
  # NOTE(review): `--replace` does not fail when the pattern is absent, so a
  # future upstream change to this pin would go unnoticed; consider
  # `--replace-fail` once the targeted nixpkgs supports it.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace "versioneer[toml]==0.28" "versioneer[toml]"
  '';

  # Tools needed only while building (Cython extension sources, versioneer).
  nativeBuildInputs = [
    cython
    versioneer
  ];

  # Runtime Python dependencies, propagated to anything that depends on pytensor.
  propagatedBuildInputs = [
    cons
    etuples
    filelock
    logical-unification
    minikanren
    numpy
    scipy
    typing-extensions
  ];

  # Test-only dependencies. jax/numba presumably exercise pytensor's
  # alternative backends in the test suite — confirm against upstream tests.
  nativeCheckInputs = [
    jax
    jaxlib
    numba
    numba-scipy
    pytest-mock
    pytestCheckHook
    tensorflow-probability
  ];

  # The sandbox HOME is not writable; point it at a temp dir. Exporting it in
  # preBuild (rather than preCheck) keeps it set for every later phase, since
  # all phases run in the same shell — including pythonImportsCheck even when
  # the test phase is skipped. Presumably pytensor writes its compile cache
  # under HOME on import — TODO confirm.
  preBuild = ''
    export HOME=$(mktemp -d)
  '';

  pythonImportsCheck = [
    "pytensor"
  ];

  disabledTests = [
    # benchmarks (require pytest-benchmark):
    "test_elemwise_speed"
    "test_fused_elemwise_benchmark"
    "test_logsumexp_benchmark"
    "test_scan_multiple_output"
    "test_vector_taps_benchmark"
  ];

  disabledTestPaths = [
    # Don't run the most compute-intense tests
    "tests/scan/"
    "tests/tensor/"
    "tests/sparse/sandbox/"
  ];

  meta = {
    description = "Python library to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays";
    homepage = "https://github.com/pymc-devs/pytensor";
    # Link the notes for the exact release being packaged, not the release list.
    changelog = "https://github.com/pymc-devs/pytensor/releases/tag/rel-${version}";
    license = lib.licenses.bsd3;
    maintainers = [ lib.maintainers.bcdarwin ];
    # Marked broken on aarch64-linux. `stdenv.hostPlatform.*` replaces the
    # deprecated `stdenv.isLinux` / `stdenv.isAarch64` aliases (same values).
    broken = stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64;
  };
}