{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools-scm,

  # dependencies
  fastprogress,
  jax,
  jaxlib,
  jaxopt,
  optax,
  typing-extensions,

  # checks
  pytestCheckHook,
  pytest-xdist,
}:

let
  # Both skip lists below carry extra entries for Linux on aarch64 (numerical
  # assertion failures there); evaluate the platform predicate only once so the
  # two lists cannot drift apart.
  onAarch64Linux = stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64;
in
buildPythonPackage rec {
  pname = "blackjax";
  version = "1.2.4";
  pyproject = true;

  # Upstream sources at the release tag matching `version`.
  src = fetchFromGitHub {
    owner = "blackjax-devs";
    repo = "blackjax";
    rev = "refs/tags/${version}";
    hash = "sha256-qaQBbRAKExRHr4Uhm5/Q1Ydon6ePsjG2PWbwSdR9QZM=";
  };

  build-system = [ setuptools-scm ];

  dependencies = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  # Tests run through pytest; pytest-xdist is made available for parallel runs.
  nativeCheckInputs = [
    pytestCheckHook
    pytest-xdist
  ];

  disabledTestPaths = [
    # Benchmark module. NOTE(review): no reason was recorded for skipping it;
    # presumably too slow / not meaningful in the build sandbox — TODO confirm.
    "tests/test_benchmarks.py"
  ]
  ++ lib.optionals onAarch64Linux [
    # Assertion errors on numerical values
    "tests/mcmc/test_integrators.py"
  ];

  disabledTests = [
    # too slow
    "test_adaptive_tempered_smc"
  ]
  ++ lib.optionals onAarch64Linux [
    # Numerical test (AssertionError)
    # https://github.com/blackjax-devs/blackjax/issues/668
    "test_chees_adaptation"
  ];

  # Smoke test: the built package must at least be importable.
  pythonImportsCheck = [ "blackjax" ];

  meta = {
    description = "Sampling library designed for ease of use, speed and modularity";
    homepage = "https://blackjax-devs.github.io/blackjax";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}