# nixpkgs package expression for pytensor (https://github.com/pymc-devs/pytensor).
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  cython,
  versioneer,
  cons,
  etuples,
  filelock,
  logical-unification,
  minikanren,
  numpy,
  scipy,
  typing-extensions,
  jax,
  jaxlib,
  numba,
  # NOTE(review): numba-scipy is accepted as an argument but is not referenced
  # anywhere in this expression (it is not part of nativeCheckInputs). Either
  # add it back to the test inputs or drop the argument — confirm which is
  # intended. Kept for now so existing `.override` calls do not break.
  numba-scipy,
  pytest-mock,
  pytestCheckHook,
  pythonOlder,
  tensorflow-probability,
}:

buildPythonPackage rec {
  pname = "pytensor";
  version = "2.20.0";
  pyproject = true;

  disabled = pythonOlder "3.10";

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "pytensor";
    # Upstream release tags carry a "rel-" prefix.
    rev = "refs/tags/rel-${version}";
    hash = "sha256-bvkOMer+zYSsiU4a147eUEZjjUeTVpb9f/hepMZZ3sE=";
  };

  # pyproject.toml pins versioneer to exactly 0.28; relax the pin so the
  # versioneer provided by the package set is accepted. --replace-fail (rather
  # than the deprecated --replace) aborts the build if the pinned string is no
  # longer present, so an upstream change cannot turn this into a silent no-op.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "versioneer[toml]==0.28" "versioneer[toml]"
  '';

  build-system = [
    cython
    versioneer
  ];

  dependencies = [
    cons
    etuples
    filelock
    logical-unification
    minikanren
    numpy
    scipy
    typing-extensions
  ];

  nativeCheckInputs = [
    jax
    jaxlib
    numba
    pytest-mock
    pytestCheckHook
    tensorflow-probability
  ];

  # Provide a writable HOME for the rest of the build (including the import
  # check and tests). Presumably pytensor writes its compilation cache under
  # $HOME — confirm.
  preBuild = ''
    export HOME=$(mktemp -d)
  '';

  pythonImportsCheck = [ "pytensor" ];

  disabledTests = [
    # benchmarks (require pytest-benchmark):
    "test_elemwise_speed"
    "test_fused_elemwise_benchmark"
    "test_logsumexp_benchmark"
    "test_scan_multiple_output"
    "test_vector_taps_benchmark"
  ];

  disabledTestPaths = [
    # Skip the most compute-intensive test suites.
    "tests/scan/"
    "tests/tensor/"
    "tests/sparse/sandbox/"
  ];

  meta = {
    description = "Python library to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays";
    mainProgram = "pytensor-cache";
    homepage = "https://github.com/pymc-devs/pytensor";
    # Link the release notes for this exact version rather than the index page.
    changelog = "https://github.com/pymc-devs/pytensor/releases/tag/rel-${version}";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      bcdarwin
      ferrine
    ];
  };
}