{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  rustPlatform,

  # build-system
  cargo,
  rustc,

  # dependencies
  arviz,
  pandas,
  pyarrow,
  xarray,

  # tests
  # bridgestan, (not packaged)
  equinox,
  flowjax,
  jax,
  jaxlib,
  numba,
  pytest-timeout,
  pymc,
  pytestCheckHook,
  setuptools,
  writableTmpDirAsHomeHook,
}:

# Python package `nutpie`: a Python wrapper around the Rust crate nuts-rs,
# built from the upstream GitHub release tag.
buildPythonPackage rec {
  pname = "nutpie";
  version = "0.14.3";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "nutpie";
    tag = "v${version}";
    hash = "sha256-l2TEGa9VVJmU4mKZwfUdhiloW6Bh41OqIQzTRvYK3eg=";
  };

  # Vendored Cargo dependencies for the Rust part of the build, fetched from
  # the same source tree as `src`. Bump `hash` whenever `version` changes.
  cargoDeps = rustPlatform.fetchCargoVendor {
    inherit pname version src;
    hash = "sha256-hPKT+YM9s7XZhI3sfnLBfokbGQhwDa9y5Fgg1TItO4M=";
  };

  # Rust toolchain plus the rustPlatform hooks; the wheel is produced via
  # rustPlatform.maturinBuildHook, using the vendored crates from `cargoDeps`.
  build-system = [
    cargo
    rustPlatform.bindgenHook
    rustPlatform.cargoSetupHook
    rustPlatform.maturinBuildHook
    rustc
  ];

  # Do not enforce upstream's version bound on xarray.
  pythonRelaxDeps = [
    "xarray"
  ];

  dependencies = [
    arviz
    pandas
    pyarrow
    xarray
  ];

  pythonImportsCheck = [ "nutpie" ];

  nativeCheckInputs = [
    # bridgestan (not packaged)
    equinox
    flowjax
    jax
    jaxlib
    numba
    pymc
    pytest-timeout
    pytestCheckHook
    setuptools
    # Gives the test run a writable $HOME.
    writableTmpDirAsHomeHook
  ];

  # Only disabled on aarch64-linux, where this test has been observed to be flaky.
  disabledTests = lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # flaky (assert np.float64(0.0017554642626285276) > 0.01)
    "test_normalizing_flow"
  ];

  disabledTestPaths = [
    # Require unpackaged bridgestan
    "tests/test_stan.py"
  ];

  meta = {
    description = "Python wrapper for nuts-rs";
    homepage = "https://github.com/pymc-devs/nutpie";
    # Reuse the fetched tag so the changelog link always matches `src`.
    changelog = "https://github.com/pymc-devs/nutpie/blob/${src.tag}/CHANGELOG.md";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}