{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchPypi,

  # build system
  setuptools,

  # runtime dependencies
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # test-only inputs
  funsor,
  pytestCheckHook,
  # TODO: add back once tensorflow-probability builds again.
  # tensorflow-probability,
}:

let
  # Checks that fail with an AssertionError because of numerical tolerance issues.
  toleranceFailures = [
    "test_beta_binomial_log_prob"
    "test_collapse_beta"
    "test_cpu"
    "test_gamma_poisson"
    "test_gof"
    "test_hpdi"
    "test_kl_dirichlet_dirichlet"
    "test_kl_univariate"
    "test_mean_var"
  ];

  # Tests that try to download datasets, which is not possible during the build.
  downloadingTests = [
    "data_load"
    "test_jsb_chorales"
  ];
in
buildPythonPackage rec {
  pname = "numpyro";
  version = "0.15.1";
  pyproject = true;

  # The package is marked unavailable for interpreters older than Python 3.9.
  disabled = pythonOlder "3.9";

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-HnX6sYRdEpbCMDXHsk1l/h60630ZwmED3SUioLA3wrU=";
  };

  build-system = [ setuptools ];

  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  pythonImportsCheck = [ "numpyro" ];

  nativeCheckInputs = [
    funsor
    pytestCheckHook
    # TODO: add back once tensorflow-probability builds again.
    # tensorflow-probability
  ];

  # Order matches the grouping above: tolerance failures, downloads, then one-off breakages.
  disabledTests =
    toleranceFailures
    ++ downloadingTests
    ++ [
      # RuntimeWarning: overflow encountered in cast
      "test_zero_inflated_logits_probs_agree"
      # NameError: unbound axis name: _provenance
      "test_model_transformation"
    ];

  # This test module is skipped wholesale while tensorflow-probability is
  # unavailable; drop this once that dependency is restored above.
  disabledTestPaths = [ "test/test_distributions.py" ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}