at 25.11-pre 2.7 kB view raw
# Nixpkgs package expression for NumPyro, a probabilistic programming library
# built on NumPy/JAX (see meta.description below).
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # tests
  dm-haiku,
  flax,
  funsor,
  graphviz,
  optax,
  pyro-api,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
  tensorflow-probability,
}:

buildPythonPackage rec {
  pname = "numpyro";
  version = "0.18.0";
  pyproject = true;

  # Upstream tags releases with the bare version string (no "v" prefix).
  src = fetchFromGitHub {
    owner = "pyro-ppl";
    repo = "numpyro";
    tag = version;
    hash = "sha256-0X/ta2yfzjf3JnZYdUAzQmXvbsDpwFCJe/bArMSWQgU=";
  };

  build-system = [ setuptools ];

  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  nativeCheckInputs = [
    dm-haiku
    flax
    funsor
    graphviz
    optax
    pyro-api
    pytest-xdist
    pytestCheckHook
    scikit-learn
    tensorflow-probability
  ];

  pythonImportsCheck = [ "numpyro" ];

  # `pytestFlags` replaces the deprecated `pytestFlagsArray`. Each element is
  # passed to pytest as exactly one argument, so options and their values are
  # written as a single token.
  pytestFlags = [
    # Test memory consumption grows significantly with the number of parallel
    # processes (reaches ~200GB with 80 jobs), so cap pytest-xdist's workers.
    "--maxprocesses=8"

    # A few tests fail with:
    # UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
    # Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program.
    # You can double-check how many devices are available in your system using `jax.local_device_count()`.
    "-Wignore::UserWarning"
  ];

  disabledTests = [
    # AssertionError, assert GLOBAL["count"] == 4 (assert 5 == 4)
    "test_mcmc_parallel_chain"

    # AssertionError due to tolerance issues
    "test_bijective_transforms"
    "test_cpu"
    "test_entropy_categorical"
    "test_gaussian_model"

    # > with pytest.warns(UserWarning, match="Hessian of log posterior"):
    # E Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
    # E Emitted warnings: [].
    "test_laplace_approximation_warning"

    # ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2)
    "test_chain"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # AssertionError: Not equal to tolerance rtol=0.06, atol=0
    "test_functional_map"
  ];

  disabledTestPaths = [
    # Require internet access
    "test/test_example_utils.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}