# Nix expression for the `blackjax` Python package, built with
# buildPythonPackage from the upstream GitHub tag matching `version`.
{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  pytest-xdist,
  pytestCheckHook,
  setuptools-scm,
  fastprogress,
  jax,
  jaxlib,
  jaxopt,
  optax,
  typing-extensions,
}:

let
  # Platform on which some tests fail with numerical assertion errors; both
  # disabledTestPaths and disabledTests below key off this single predicate.
  # Uses `stdenv.hostPlatform.*` instead of the deprecated `stdenv.isLinux` /
  # `stdenv.isAarch64` aliases.
  isAarch64Linux = stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64;
in
buildPythonPackage rec {
  pname = "blackjax";
  version = "1.2.1";
  pyproject = true;

  # Requires Python >= 3.9.
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "blackjax-devs";
    repo = "blackjax";
    rev = "refs/tags/${version}";
    hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak=";
  };

  build-system = [ setuptools-scm ];

  dependencies = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [
    pytestCheckHook
    pytest-xdist
  ];

  disabledTestPaths =
    [
      # Benchmarks are excluded from the check phase on every platform.
      "tests/test_benchmarks.py"
    ]
    ++ lib.optionals isAarch64Linux [
      # Assertion errors on numerical values
      "tests/mcmc/test_integrators.py"
    ];

  disabledTests =
    [
      # too slow
      "test_adaptive_tempered_smc"
    ]
    ++ lib.optionals isAarch64Linux [
      # Numerical test (AssertionError)
      # https://github.com/blackjax-devs/blackjax/issues/668
      "test_chees_adaptation"
    ];

  pythonImportsCheck = [ "blackjax" ];

  meta = {
    homepage = "https://blackjax-devs.github.io/blackjax";
    description = "Sampling library designed for ease of use, speed and modularity";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}