{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, fetchpatch
, pytestCheckHook
, fastprogress
, jax
, jaxlib
, jaxopt
, optax
, typing-extensions
}:

# BlackJAX: a sampling library built on JAX, packaged from the upstream
# GitHub release tag matching `version`.
buildPythonPackage rec {
  pname = "blackjax";
  version = "0.9.6";
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "blackjax-devs";
    # Spelled out rather than `repo = pname;` so that overriding `pname`
    # does not silently redirect the source fetch to another repository.
    repo = "blackjax";
    rev = "refs/tags/${version}";
    hash = "sha256-EieDu9SJxi2cp1bHlxX4vvFZeDGMGIm24GoR8nSyjvE=";
  };

  patches = [
    # Backport of upstream commit 1aaa6f6 ("fix-lbfgs-args").
    # Drop this once a release that already contains the commit is packaged.
    (fetchpatch {
      name = "fix-lbfgs-args";
      url = "https://github.com/blackjax-devs/blackjax/commit/1aaa6f64bbcb0557b658604b2daba826e260cbc6.patch";
      hash = "sha256-XyjorXPH5Ap35Tv1/lTeTWamjplJF29SsvOq59ypftE=";
    })
  ];

  # Runtime dependencies, propagated to anything that depends on blackjax.
  propagatedBuildInputs = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [ pytestCheckHook ];

  # The benchmark suite is excluded from the check phase.
  disabledTestPaths = [ "tests/test_benchmarks.py" ];

  disabledTests = [
    # too slow
    "test_adaptive_tempered_smc"
  ];

  pythonImportsCheck = [
    "blackjax"
  ];

  meta = {
    homepage = "https://blackjax-devs.github.io/blackjax";
    description = "Sampling library designed for ease of use, speed and modularity";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}