# Package expression for NumPyro, built from the upstream GitHub sources
# with setuptools and checked with pytest.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build backend
  setuptools,

  # runtime dependencies
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # test-only dependencies
  dm-haiku,
  flax,
  funsor,
  graphviz,
  optax,
  pyro-api,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
  tensorflow-probability,
}:

let
  pname = "numpyro";
  version = "0.18.0";

  # The Git tag that is fetched is the plain version string.
  src = fetchFromGitHub {
    owner = "pyro-ppl";
    repo = "numpyro";
    tag = version;
    hash = "sha256-0X/ta2yfzjf3JnZYdUAzQmXvbsDpwFCJe/bArMSWQgU=";
  };

  # Tests skipped on every platform, grouped by the failure they produce.
  commonDisabledTests = [
    # Fails with: AssertionError, assert GLOBAL["count"] == 4 (assert 5 == 4)
    "test_mcmc_parallel_chain"

    # Numerical tolerance is exceeded (AssertionError)
    "test_bijective_transforms"
    "test_cpu"
    "test_entropy_categorical"
    "test_gaussian_model"

    # The test expects a warning that is never emitted:
    #   with pytest.warns(UserWarning, match="Hessian of log posterior"):
    #   Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
    #   Emitted warnings: [].
    "test_laplace_approximation_warning"

    # Fails with: ValueError: compiling computation that requires 2 logical
    # devices, but only 1 XLA devices are available (num_replicas=2)
    "test_chain"
  ];

  # Tests skipped only when the host platform is Darwin.
  darwinDisabledTests = [
    # Fails with: AssertionError: Not equal to tolerance rtol=0.06, atol=0
    "test_functional_map"
  ];
in
buildPythonPackage {
  inherit pname version src;
  pyproject = true;

  build-system = [ setuptools ];

  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  nativeCheckInputs = [
    dm-haiku
    flax
    funsor
    graphviz
    optax
    pyro-api
    pytest-xdist
    pytestCheckHook
    scikit-learn
    tensorflow-probability
  ];

  pythonImportsCheck = [ "numpyro" ];

  pytestFlagsArray = [
    # Cap the number of workers: the memory used by the test suite grows
    # significantly with the number of parallel processes (about 200GB was
    # reached with 80 jobs).
    "--maxprocesses=8"

    # Silence UserWarning, because a few tests otherwise fail with:
    #   UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
    #   Chains will be drawn sequentially. If you are running MCMC in CPU, consider using
    #   `numpyro.set_host_device_count(2)` at the beginning of your program.
    #   You can double-check how many devices are available in your system using
    #   `jax.local_device_count()`.
    "-W"
    "ignore::UserWarning"
  ];

  disabledTests = commonDisabledTests ++ lib.optionals stdenv.hostPlatform.isDarwin darwinDisabledTests;

  disabledTestPaths = [
    # Needs internet access
    "test/test_example_utils.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.fab ];
  };
}