# Nix package expression for NumPyro, built from the PyPI source archive.
#
# The expression is a function: it receives its dependencies as arguments
# and returns the result of calling buildPythonPackage.
{ lib
, buildPythonPackage
, fetchPypi
, jax
, jaxlib
, multipledispatch
, numpy
, pytestCheckHook
, pythonOlder
, tensorflow-probability
, tqdm
}:

# `rec` lets `src` and `meta.changelog` refer to `pname` / `version` below.
buildPythonPackage rec {
  pname = "numpyro";
  version = "0.11.0";
  format = "setuptools";

  # Marked as disabled when `pythonOlder "3.7"` holds.
  disabled = pythonOlder "3.7";

  # Source archive is fetched from PyPI by pname/version and pinned by hash.
  src = fetchPypi {
    inherit pname version;
    hash = "sha256-01fdGgFZ+G1FwjNwitM6PT1TQx0FtLvs4dBorkFoqo4=";
  };

  # Dependencies propagated to anything that depends on this package.
  propagatedBuildInputs = [
    jax
    jaxlib
    numpy
    multipledispatch
    tqdm
  ];

  # Inputs supplied only for the check phase.
  nativeCheckInputs = [
    tensorflow-probability
    pytestCheckHook
  ];

  # Modules that must be importable from the built package.
  pythonImportsCheck = [
    "numpyro"
  ];

  # Tests excluded from the pytest run, grouped by the reason they fail.
  disabledTests = [
    # AssertionError due to tolerance issues
    "test_beta_binomial_log_prob"
    "test_collapse_beta"
    "test_cpu"
    "test_gamma_poisson"
    "test_gof"
    "test_hpdi"
    "test_kl_dirichlet_dirichlet"
    "test_kl_univariate"
    "test_mean_var"
    # Tests want to download data
    "data_load"
    "test_jsb_chorales"
    # RuntimeWarning: overflow encountered in cast
    "test_zero_inflated_logits_probs_agree"
    # NameError: unbound axis name: _provenance
    "test_model_transformation"
  ];

  # Explicit `lib.` references instead of a blanket `with lib;`, so every
  # name used in `meta` has a visible origin.
  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}