{ lib
, buildPythonPackage
, fetchPypi
, jax
, jaxlib
, multipledispatch
, numpy
, pytestCheckHook
, pythonOlder
, tqdm
}:

buildPythonPackage rec {
  pname = "numpyro";
  version = "0.10.1";
  format = "setuptools";

  # Mark the package as unavailable for interpreters older than Python 3.7.
  disabled = pythonOlder "3.7";

  # Source tarball from PyPI; `hash` must be updated together with `version`.
  src = fetchPypi {
    inherit version pname;
    hash = "sha256-36iW8ByN9D3dQWY68rPi/Erqc0ieZpR06DMpsYOykVA=";
  };

  # Runtime dependencies, propagated to packages that depend on numpyro.
  propagatedBuildInputs = [
    jax
    jaxlib
    numpy
    multipledispatch
    tqdm
  ];

  # NOTE(review): newer nixpkgs releases rename checkInputs to
  # nativeCheckInputs; switch once this expression targets such a release.
  checkInputs = [
    pytestCheckHook
  ];

  # Smoke test: the installed package must be importable.
  pythonImportsCheck = [
    "numpyro"
  ];

  # Test names (pytest -k patterns) skipped during the check phase.
  disabledTests = [
    # AssertionError due to tolerance issues
    "test_beta_binomial_log_prob"
    "test_collapse_beta"
    "test_cpu"
    "test_gamma_poisson"
    "test_gof"
    "test_hpdi"
    "test_kl_univariate"
    "test_mean_var"
    # Tests want to download data, which is unavailable in the sandbox
    "data_load"
    "test_jsb_chorales"
  ];

  meta = with lib; {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    # Upstream release notes for the packaged version.
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = licenses.asl20;
    maintainers = with maintainers; [ fab ];
  };
}