{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchPypi,
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,
  funsor,
  pytestCheckHook,
  # TODO: add `tensorflow-probability,` back here once that package builds again.
}:

# NumPyro: probabilistic programming on top of NumPy/JAX, built from the PyPI sdist.
buildPythonPackage rec {
  pname = "numpyro";
  version = "0.13.2";
  format = "setuptools";

  # The derivation is marked disabled for interpreters older than Python 3.9.
  disabled = pythonOlder "3.9";

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-Um8LFVGAlMeOaN9uMwycHJzqEnTaxp8FYXIk+m2VTug=";
  };

  # Runtime dependencies, propagated to anything that depends on numpyro.
  propagatedBuildInputs = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  # Dependencies needed only while running the test suite.
  nativeCheckInputs = [
    funsor
    pytestCheckHook
    # TODO: add tensorflow-probability back here once that package builds again.
    # tensorflow-probability
  ];

  # Smoke test: the top-level module must import after installation.
  pythonImportsCheck = [
    "numpyro"
  ];

  disabledTests = [
    # Fail with AssertionError because of numerical tolerance issues.
    "test_beta_binomial_log_prob"
    "test_collapse_beta"
    "test_cpu"
    "test_gamma_poisson"
    "test_gof"
    "test_hpdi"
    "test_kl_dirichlet_dirichlet"
    "test_kl_univariate"
    "test_mean_var"
    # Attempt to download data, which the sandbox forbids.
    "data_load"
    "test_jsb_chorales"
    # RuntimeWarning: overflow encountered in cast
    "test_zero_inflated_logits_probs_agree"
    # NameError: unbound axis name: _provenance
    "test_model_transformation"
  ];

  # TODO: drop this exclusion once tensorflow-probability is re-enabled above.
  disabledTestPaths = [
    "test/test_distributions.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}