{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  chex,
  jaxlib,
  numpy,
  tensorflow-probability,

  # tests
  dm-haiku,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "distrax";
  version = "0.1.5";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "distrax";
    rev = "refs/tags/v${version}";
    hash = "sha256-A1aCL/I89Blg9sNmIWQru4QJteUTN6+bhgrEJPmCrM0=";
  };

  # `pyproject = true` builds through the PEP 517 hooks, so the build backend
  # has to be provided explicitly; upstream's packaging is setuptools-based.
  build-system = [ setuptools ];

  # Runtime Python dependencies. These must be propagated (`dependencies`,
  # not plain `buildInputs`) so that environments and packages depending on
  # distrax also receive them.
  dependencies = [
    chex
    jaxlib
    numpy
    tensorflow-probability
  ];

  nativeCheckInputs = [
    dm-haiku
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "distrax" ];

  disabledTests = [
    # AssertionError on numerical values
    # Reported upstream in https://github.com/google-deepmind/distrax/issues/267
    "test_method_with_input_unnormalized_probs__with_device"
    "test_method_with_input_unnormalized_probs__with_jit"
    "test_method_with_input_unnormalized_probs__without_device"
    "test_method_with_input_unnormalized_probs__without_jit"
    "test_method_with_value_1d"
    "test_nested_distributions__with_device"
    "test_nested_distributions__without_device"
    "test_nested_distributions__with_jit"
    "test_nested_distributions__without_jit"
    "test_stability__with_device"
    "test_stability__with_jit"
    "test_stability__without_device"
    "test_stability__without_jit"
    "test_von_mises_sample_gradient"
    "test_von_mises_sample_moments"
  ];

  disabledTestPaths = [
    # These test modules fail with TypeErrors
    "distrax/_src/bijectors/tfp_compatible_bijector_test.py"
    "distrax/_src/distributions/distribution_from_tfp_test.py"
    "distrax/_src/distributions/laplace_test.py"
    "distrax/_src/distributions/multinomial_test.py"
    "distrax/_src/distributions/mvn_diag_plus_low_rank_test.py"
    "distrax/_src/distributions/mvn_kl_test.py"
    "distrax/_src/distributions/straight_through_test.py"
    "distrax/_src/distributions/tfp_compatible_distribution_test.py"
    "distrax/_src/distributions/transformed_test.py"
    "distrax/_src/distributions/uniform_test.py"
    "distrax/_src/utils/transformations_test.py"
  ];

  meta = {
    description = "Probability distributions in JAX";
    # Keep in sync with the `src.owner` used above.
    homepage = "https://github.com/google-deepmind/distrax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
    # Several tests fail with:
    # AssertionError: [Chex] Assertion assert_type failed: Error in type compatibility check
    broken = true;
  };
}