# Nix package definition for rlax, a library of reinforcement-learning
# building blocks written in JAX, fetched from the upstream GitHub tag.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  absl-py,
  chex,
  distrax,
  dm-env,
  jax,
  jaxlib,
  numpy,
  tensorflow-probability,
  dm-haiku,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.6";
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "rlax";
    rev = "refs/tags/v${version}";
    hash = "sha256-v2Lbzya+E9d7tlUVlQQa4fuPp2q3E309Qvyt70mcdb0=";
  };

  patches = [
    (fetchpatch {
      # Follow chex API change (https://github.com/google-deepmind/chex/pull/52):
      # upstream commit replacing chex assertions that were deprecated there.
      name = "replace-deprecated-chex-assertions";
      url = "https://github.com/google-deepmind/rlax/commit/30e7913a1102667137654d6e652a6c4b9e9ba1f4.patch";
      hash = "sha256-OPnuTKEtwZ28hzR1660v3DcktxTYjhR1xYvFbQvOhgs=";
    })
  ];

  # Runtime dependencies, propagated to anything that depends on rlax.
  propagatedBuildInputs = [
    absl-py
    chex
    distrax
    dm-env
    jax
    jaxlib
    numpy
    tensorflow-probability
  ];

  # Only needed to run the upstream test suite during the check phase.
  nativeCheckInputs = [
    dm-haiku
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "rlax" ];

  disabledTests = [
    # These tests raise RuntimeError when run in this build; the underlying
    # cause is not recorded here. NOTE(review): re-check on version bumps.
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    # Keep in sync with src.owner above (repository moved to google-deepmind).
    homepage = "https://github.com/google-deepmind/rlax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}