# nixpkgs derivation for BlackJAX, a Python sampling library whose runtime
# dependencies are the JAX stack (jax, jaxlib, jaxopt, optax) listed below.
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, pytestCheckHook
, setuptools-scm
, fastprogress
, jax
, jaxlib
, jaxopt
, optax
, typing-extensions
}:

buildPythonPackage rec {
  pname = "blackjax";
  version = "1.0.0";
  pyproject = true;

  # NOTE(review): raised from 3.8 because the jax/jaxlib versions this release
  # targets require Python >= 3.9 — confirm against `requires-python` in the
  # tagged upstream pyproject.toml.
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "blackjax-devs";
    # Spelled out instead of `pname` so the fetch source never changes
    # implicitly if the attribute's pname is ever adjusted.
    repo = "blackjax";
    rev = "refs/tags/${version}";
    hash = "sha256-hqOKSHyZ/BmOu6MJLeecD3H1BbLbZqywmlBzn3xjQRk=";
  };

  nativeBuildInputs = [ setuptools-scm ];

  # fetchFromGitHub yields a source tree without .git metadata, so
  # setuptools-scm cannot infer the version; pin it to the release above.
  env.SETUPTOOLS_SCM_PRETEND_VERSION = version;

  propagatedBuildInputs = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [ pytestCheckHook ];

  disabledTestPaths = [
    # Benchmark suite; presumably too slow / timing-sensitive for the build
    # sandbox — TODO confirm.
    "tests/test_benchmarks.py"
  ];

  disabledTests = [
    # too slow
    "test_adaptive_tempered_smc"
  ];

  pythonImportsCheck = [ "blackjax" ];

  meta = {
    homepage = "https://blackjax-devs.github.io/blackjax";
    description = "Sampling library designed for ease of use, speed and modularity";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}