nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at python-updates 123 lines 3.2 kB view raw
# Package expression for numpyro, built with buildPythonPackage from the
# pyro-ppl/numpyro GitHub repository.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # tests
  dm-haiku,
  equinox,
  flax,
  funsor,
  graphviz,
  optax,
  pyro-api,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
  tensorflow-probability,
}:

let
  pname = "numpyro";
  version = "0.19.0";

  # Some numerical tests are only skipped on Darwin hosts.
  isDarwin = stdenv.hostPlatform.isDarwin;
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  # Sources are fetched from the upstream tag named after the version.
  src = fetchFromGitHub {
    owner = "pyro-ppl";
    repo = "numpyro";
    tag = version;
    hash = "sha256-3kzaINsz1Mjk97ERQsQIYIBz7CVmXtVDn0edJFMHQWs=";
  };

  build-system = [ setuptools ];

  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  nativeCheckInputs = [
    dm-haiku
    equinox
    flax
    funsor
    graphviz
    optax
    pyro-api
    pytest-xdist
    pytestCheckHook
    scikit-learn
    tensorflow-probability
  ];

  pythonImportsCheck = [ "numpyro" ];

  pytestFlags = [
    # Cap the worker count: the test suite's memory consumption grows
    # significantly with the number of parallel processes (it reaches ~200GB
    # with 80 jobs).
    "--maxprocesses=8"

    # Silence a warning that otherwise makes a few tests fail:
    #   UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
    #   Chains will be drawn sequentially. If you are running MCMC in CPU, consider using
    #   `numpyro.set_host_device_count(2)` at the beginning of your program.
    #   You can double-check how many devices are available in your system using
    #   `jax.local_device_count()`.
    "-Wignore::UserWarning"

    # Likewise for:
    #   FutureWarning: In the future `np.object` will be defined as the corresponding NumPy scalar.
    "-Wignore::FutureWarning"
  ];

  disabledTests = [
    # Fails starting from Flax version 0.12.0 with:
    #   ValueError: Found unexpected Arrays on value of type <class 'list'> in static attribute 'layers'
    #   of Pytree type '<class 'test_module.test_random_nnx_module_mcmc_sequence_params.<locals>.MLP'>'.
    "test_random_nnx_module_mcmc_sequence_param"

    # AssertionError: assert GLOBAL["count"] == 4 (assert 5 == 4)
    "test_mcmc_parallel_chain"

    # AssertionError caused by numerical tolerance issues
    "test_bijective_transforms"
    "test_cpu"
    "test_entropy_categorical"
    "test_gaussian_model"

    # The expected warning is never emitted:
    #   > with pytest.warns(UserWarning, match="Hessian of log posterior"):
    #   E Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
    #   E Emitted warnings: [].
    "test_laplace_approximation_warning"

    # ValueError: compiling computation that requires 2 logical devices,
    # but only 1 XLA devices are available (num_replicas=2)
    "test_chain"
  ]
  ++ lib.optionals isDarwin [
    # AssertionError: Not equal to tolerance rtol=0.06, atol=0
    "test_functional_map"
  ];

  disabledTestPaths = [
    # Needs internet access
    "test/test_example_utils.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.fab ];
  };
}