nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
# numpyro: library for probabilistic programming with NumPy, built on JAX.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # tests
  dm-haiku,
  equinox,
  flax,
  funsor,
  graphviz,
  optax,
  pyro-api,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
  tensorflow-probability,
}:

buildPythonPackage rec {
  pname = "numpyro";
  version = "0.19.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pyro-ppl";
    repo = "numpyro";
    tag = version;
    hash = "sha256-3kzaINsz1Mjk97ERQsQIYIBz7CVmXtVDn0edJFMHQWs=";
  };

  build-system = [ setuptools ];

  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  nativeCheckInputs = [
    dm-haiku
    equinox
    flax
    funsor
    graphviz
    optax
    pyro-api
    pytest-xdist
    pytestCheckHook
    scikit-learn
    tensorflow-probability
  ];

  pythonImportsCheck = [ "numpyro" ];

  pytestFlags = [
    # Test memory consumption grows significantly with the number of parallel
    # processes (reaching ~200GB with 80 jobs), so cap the xdist worker count.
    "--maxprocesses=8"

    # A few tests fail with:
    # UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
    # Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program.
    # You can double-check how many devices are available in your system using `jax.local_device_count()`.
    "-Wignore::UserWarning"

    # FutureWarning: In the future `np.object` will be defined as the corresponding NumPy scalar.
    "-Wignore::FutureWarning"
  ];

  disabledTests = [
    # ValueError: Found unexpected Arrays on value of type <class 'list'> in static attribute 'layers'
    # of Pytree type '<class 'test_module.test_random_nnx_module_mcmc_sequence_params.<locals>.MLP'>'.
    # This is an error starting from Flax version 0.12.0.
    "test_random_nnx_module_mcmc_sequence_params"

    # AssertionError, assert GLOBAL["count"] == 4 (assert 5 == 4)
    "test_mcmc_parallel_chain"

    # AssertionError due to tolerance issues
    "test_bijective_transforms"
    "test_cpu"
    "test_entropy_categorical"
    "test_gaussian_model"

    # > with pytest.warns(UserWarning, match="Hessian of log posterior"):
    # E Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
    # E Emitted warnings: [].
    "test_laplace_approximation_warning"

    # ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2)
    "test_chain"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # AssertionError: Not equal to tolerance rtol=0.06, atol=0
    "test_functional_map"
  ];

  disabledTestPaths = [
    # Require internet access
    "test/test_example_utils.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    # Derive the release URL from the fetched tag so it stays correct if the
    # tag format ever diverges from `version`.
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}