# nixpkgs derivation for blackjax, built from the upstream GitHub release tag.
{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  pytest-xdist,
  pytestCheckHook,
  setuptools-scm,
  fastprogress,
  jax,
  jaxlib,
  jaxopt,
  optax,
  typing-extensions,
}:

let
  # Platform on which some numerical tests are known to fail (see the
  # exclusions below). Computed once so both exclusion lists use the same
  # check. hostPlatform is used because the bare stdenv.isLinux /
  # stdenv.isAarch64 aliases are deprecated in nixpkgs.
  isLinuxAarch64 = stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64;
in
buildPythonPackage rec {
  pname = "blackjax";
  version = "1.2.1";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "blackjax-devs";
    repo = "blackjax";
    rev = "refs/tags/${version}";
    hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak=";
  };

  build-system = [ setuptools-scm ];

  dependencies = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [
    pytestCheckHook
    pytest-xdist
  ];

  disabledTestPaths =
    [
      # benchmark suite, excluded from the regular check phase
      "tests/test_benchmarks.py"
    ]
    ++ lib.optionals isLinuxAarch64 [
      # Assertion errors on numerical values
      "tests/mcmc/test_integrators.py"
    ];

  disabledTests =
    [
      # too slow
      "test_adaptive_tempered_smc"
    ]
    ++ lib.optionals isLinuxAarch64 [
      # Numerical test (AssertionError)
      # https://github.com/blackjax-devs/blackjax/issues/668
      "test_chees_adaptation"
    ];

  pythonImportsCheck = [ "blackjax" ];

  meta = {
    homepage = "https://blackjax-devs.github.io/blackjax";
    description = "Sampling library designed for ease of use, speed and modularity";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}