# Package expression for `blackjax`, built with `buildPythonPackage` from the
# upstream GitHub repository (blackjax-devs/blackjax) at the tag matching
# `version`.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools-scm,

  # dependencies
  fastprogress,
  jax,
  jaxlib,
  jaxopt,
  optax,
  typing-extensions,

  # checks
  pytestCheckHook,
  pytest-xdist,
}:

let
  inherit (stdenv) hostPlatform;

  # Individual tests that fail with an AssertionError on numerical values.
  numericallyFailingTests = [
    "test_barker"
    "test_mclmc"
    "test_mcse4"
    "test_normal_univariate"
    "test_nuts__with_device"
    "test_nuts__with_jit"
    "test_nuts__without_device"
    "test_nuts__without_jit"
    "test_smc_waste_free__with_jit"
  ];
in
buildPythonPackage rec {
  pname = "blackjax";
  version = "1.2.5";
  pyproject = true;

  # The upstream release tag is the bare version string (no "v" prefix).
  src = fetchFromGitHub {
    owner = "blackjax-devs";
    repo = "blackjax";
    tag = version;
    hash = "sha256-2GTjKjLIWFaluTjdWdUF9Iim973y81xv715xspghRZI=";
  };

  build-system = [ setuptools-scm ];

  dependencies = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [
    pytestCheckHook
    pytest-xdist
  ];

  # Silence "DeprecationWarning: JAXopt is no longer maintained" during the
  # test run.
  # NOTE(review): newer nixpkgs revisions appear to prefer `pytestFlags` over
  # `pytestFlagsArray` -- confirm against the targeted nixpkgs before switching.
  pytestFlagsArray = [
    "-W"
    "ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    "tests/test_benchmarks.py"

    # Whole module fails with assertion errors on numerical values
    "tests/mcmc/test_integrators.py"
  ];

  disabledTests =
    [
      # too slow
      "test_adaptive_tempered_smc"
    ]
    ++ numericallyFailingTests
    # Numerical test (AssertionError), skipped on aarch64-linux only
    # https://github.com/blackjax-devs/blackjax/issues/668
    ++ lib.optional (hostPlatform.isAarch64 && hostPlatform.isLinux) "test_chees_adaptation";

  pythonImportsCheck = [ "blackjax" ];

  meta = {
    description = "Sampling library designed for ease of use, speed and modularity";
    homepage = "https://blackjax-devs.github.io/blackjax";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}