{ lib
, buildPythonPackage
, pythonOlder
, fetchPypi
, jax
, jaxlib
, multipledispatch
, numpy
, tqdm
, funsor
, pytestCheckHook
# TODO: uncomment when tensorflow-probability gets fixed.
# , tensorflow-probability
}:

# NumPyro, built from its PyPI source distribution.
buildPythonPackage rec {
  pname = "numpyro";
  version = "0.13.2";
  # Built with the classic setuptools builder rather than a pyproject build.
  format = "setuptools";

  # The package is marked unavailable for Python interpreters older than 3.9.
  disabled = pythonOlder "3.9";

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-Um8LFVGAlMeOaN9uMwycHJzqEnTaxp8FYXIk+m2VTug=";
  };

  # Runtime dependencies, propagated to consumers of this package.
  propagatedBuildInputs = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  # Inputs used only while running the test suite.
  nativeCheckInputs = [
    funsor
    pytestCheckHook
    # TODO: uncomment when tensorflow-probability gets fixed.
    # tensorflow-probability
  ];

  # Smoke test: the top-level module must import after installation.
  pythonImportsCheck = [
    "numpyro"
  ];

  # Individual tests skipped by name via pytestCheckHook.
  disabledTests = [
    # Numerical tolerance failures (AssertionError)
    "test_beta_binomial_log_prob"
    "test_collapse_beta"
    "test_cpu"
    "test_gamma_poisson"
    "test_gof"
    "test_hpdi"
    "test_kl_dirichlet_dirichlet"
    "test_kl_univariate"
    "test_mean_var"
    # These tests try to download datasets
    "data_load"
    "test_jsb_chorales"
    # Fails with "RuntimeWarning: overflow encountered in cast"
    "test_zero_inflated_logits_probs_agree"
    # Fails with "NameError: unbound axis name: _provenance"
    "test_model_transformation"
  ];

  # TODO: remove when tensorflow-probability gets fixed.
  # NOTE(review): presumably this module needs tensorflow-probability, which is
  # currently left out of nativeCheckInputs — confirm before re-enabling.
  disabledTestPaths = [
    "test/test_distributions.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.fab ];
  };
}