# distrax: probability distributions in JAX, built from the PyPI source release.
{ lib
, fetchPypi
, buildPythonPackage
, numpy
, tensorflow-probability
, chex
, dm-haiku
, pytestCheckHook
, jaxlib
}:

buildPythonPackage rec {
  pname = "distrax";
  version = "0.1.4";

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-klXT5wfnWUGMrf5sQhYqz7Foc/Ou5y4GIFgtTff1ZFQ=";
  };

  # Runtime Python dependencies. These must be propagated (not plain
  # buildInputs) so they are on the Python path for anything that imports
  # distrax, including pythonImportsCheck and the pytest check phase.
  propagatedBuildInputs = [
    chex
    jaxlib
    numpy
    tensorflow-probability
  ];

  # Dependencies needed only to run the test suite.
  nativeCheckInputs = [
    dm-haiku
    pytestCheckHook
  ];

  pythonImportsCheck = [
    "distrax"
  ];

  disabledTestPaths = [
    # These test modules fail with TypeErrors.
    "distrax/_src/bijectors/tfp_compatible_bijector_test.py"
    "distrax/_src/distributions/distribution_from_tfp_test.py"
    "distrax/_src/distributions/laplace_test.py"
    "distrax/_src/distributions/multinomial_test.py"
    "distrax/_src/distributions/mvn_diag_plus_low_rank_test.py"
    "distrax/_src/distributions/mvn_kl_test.py"
    "distrax/_src/distributions/straight_through_test.py"
    "distrax/_src/distributions/tfp_compatible_distribution_test.py"
    "distrax/_src/distributions/transformed_test.py"
    "distrax/_src/distributions/uniform_test.py"
    "distrax/_src/utils/transformations_test.py"
  ];

  meta = with lib; {
    description = "Probability distributions in JAX";
    homepage = "https://github.com/deepmind/distrax";
    license = licenses.asl20;
    maintainers = with maintainers; [ onny ];
    # Broken on all platforms (starting 2022-07-27)
    broken = true;
  };
}