1{ lib
2, fetchPypi
3, buildPythonPackage
4, numpy
5, tensorflow-probability
6, chex
7, dm-haiku
8, pytestCheckHook
9, jaxlib }:
10
11buildPythonPackage rec {
12 pname = "distrax";
13 version = "0.1.2";
14
15 src = fetchPypi {
16 inherit pname version;
17 sha256 = "sha256-b/+rxjdowNMuhUBhRCuN45z/iUbj1hN1qCSQqqAtZIw=";
18 };
19
20 buildInputs = [
21 chex
22 jaxlib
23 numpy
24 tensorflow-probability
25 ];
26
27 checkInputs = [
28 dm-haiku
29 pytestCheckHook
30 ];
31
32 pythonImportsCheck = [
33 "distrax"
34 ];
35
36 disabledTestPaths = [
37 # TypeErrors
38 "distrax/_src/bijectors/tfp_compatible_bijector_test.py"
39 "distrax/_src/distributions/distribution_from_tfp_test.py"
40 "distrax/_src/distributions/laplace_test.py"
41 "distrax/_src/distributions/multinomial_test.py"
42 "distrax/_src/distributions/mvn_diag_plus_low_rank_test.py"
43 "distrax/_src/distributions/mvn_kl_test.py"
44 "distrax/_src/distributions/straight_through_test.py"
45 "distrax/_src/distributions/tfp_compatible_distribution_test.py"
46 "distrax/_src/distributions/transformed_test.py"
47 "distrax/_src/distributions/uniform_test.py"
48 "distrax/_src/utils/transformations_test.py"
49 ];
50
51 meta = with lib; {
52 description = "Probability distributions in JAX";
53 homepage = "https://github.com/deepmind/distrax";
54 license = licenses.asl20;
55 maintainers = with maintainers; [ onny ];
56 };
57}