{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # tests
  dm-haiku,
  flax,
  funsor,
  graphviz,
  optax,
  pyro-api,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
  tensorflow-probability,
}:

buildPythonPackage rec {
  pname = "numpyro";
  version = "0.18.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pyro-ppl";
    repo = "numpyro";
    tag = version;
    hash = "sha256-0X/ta2yfzjf3JnZYdUAzQmXvbsDpwFCJe/bArMSWQgU=";
  };

  build-system = [ setuptools ];

  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  nativeCheckInputs = [
    dm-haiku
    flax
    funsor
    graphviz
    optax
    pyro-api
    pytest-xdist
    pytestCheckHook
    scikit-learn
    tensorflow-probability
  ];

  pythonImportsCheck = [ "numpyro" ];

  pytestFlags = [
    # Test memory consumption grows significantly with the number of parallel
    # pytest-xdist processes (reaching ~200GB with 80 jobs), so cap the count.
    "--maxprocesses=8"

    # A few tests fail with:
    # UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
    # Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program.
    # You can double-check how many devices are available in your system using `jax.local_device_count()`.
    "-Wignore::UserWarning"
  ];

  disabledTests = [
    # AssertionError, assert GLOBAL["count"] == 4 (assert 5 == 4)
    "test_mcmc_parallel_chain"

    # AssertionError due to tolerance issues
    "test_bijective_transforms"
    "test_cpu"
    "test_entropy_categorical"
    "test_gaussian_model"

    # > with pytest.warns(UserWarning, match="Hessian of log posterior"):
    # E Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
    # E Emitted warnings: [].
    "test_laplace_approximation_warning"

    # ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2)
    "test_chain"

    # Failing since flax==0.11.0
    # KeyError: "No RngStream named 'dropout' found in Rngs."
    # https://github.com/pyro-ppl/numpyro/issues/2055
    "test_nnx_state_dropout_smoke"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # AssertionError: Not equal to tolerance rtol=0.06, atol=0
    "test_functional_map"
  ];

  disabledTestPaths = [
    # Requires internet access
    "test/test_example_utils.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    # Derived from the fetched tag so the URL always matches the packaged source.
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}