# Nix expression for the `rlax` Python package ("Library of reinforcement
# learning building blocks in JAX"), built from its PyPI source release.
# All arguments are supplied by callPackage from the Python package set.
{ lib
, fetchPypi
, buildPythonPackage
, chex
, jaxlib
, tensorflow-probability
, optax
, dm-haiku
, bsuite
, frozendict
, pytestCheckHook
, dm-env
, distrax }:

buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.6";
  # State the build format explicitly (setuptools / setup.py flow) instead of
  # relying on buildPythonPackage's implicit default.
  format = "setuptools";

  src = fetchPypi {
    inherit pname version;
    hash = "sha256-C3nFOv/zxvAoz6WZ0RAZffzEbxIx/XrGabO4QPxrik8=";
  };

  # Runtime Python dependencies. These must be *propagated*: with plain
  # buildInputs they are only visible while this derivation builds, so
  # `import rlax` would fail in any environment or package that depends on it.
  propagatedBuildInputs = [
    chex
    jaxlib
    distrax
    tensorflow-probability
  ];

  # Extra packages needed only by the test suite run via pytestCheckHook.
  nativeCheckInputs = [
    bsuite
    dm-env
    dm-haiku
    frozendict
    optax
    pytestCheckHook
  ];

  # Smoke test: the built package must be importable.
  pythonImportsCheck = [
    "rlax"
  ];

  # Tests excluded from the pytest run. Per the original note they raise
  # RuntimeErrors; the underlying cause is not recorded here.
  disabledTests = [
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    homepage = "https://github.com/deepmind/rlax";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.onny ];
  };
}