{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  chex,
  distrax,
  dm-env,
  jax,
  jaxlib,
  numpy,
  tensorflow-probability,

  # tests
  dm-haiku,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

# Python package definition for rlax, built from the tagged GitHub release
# with a handful of upstream commits applied on top as patches.
buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.7";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "rlax";
    tag = "v${version}";
    hash = "sha256-w5vhXBMUlcqlLTKA58QgQ4pxyGs3etxJLIFUVPhE7H8=";
  };

  # TODO: remove these patches at the next release (already on master)
  # Each entry is a single upstream commit fetched as a `.patch` file; the
  # `name` attributes summarize what the commit addresses.
  patches = [
    (fetchpatch {
      # Follow chex API change (https://github.com/google-deepmind/chex/pull/52)
      name = "replace-deprecated-chex-assertions";
      url = "https://github.com/google-deepmind/rlax/commit/30e7913a1102667137654d6e652a6c4b9e9ba1f4.patch";
      hash = "sha256-OPnuTKEtwZ28hzR1660v3DcktxTYjhR1xYvFbQvOhgs=";
    })
    (fetchpatch {
      name = "remove-deprecation-warning";
      url = "https://github.com/google-deepmind/rlax/commit/dea6eb479ffc32156aefe73015387a762c6b4562.patch";
      hash = "sha256-htDyDRJW0eQx7AmrS3Fl7Lbh2VAmoYiDgHSePsQUaWs=";
    })
    (fetchpatch {
      name = "fix-deprecation-warnings";
      url = "https://github.com/google-deepmind/rlax/commit/605e0ef8ad8f9a06e88d4aabbb7d50e086d0cf3a.patch";
      hash = "sha256-GZ/nGMXne6Lv6yDm/29NVTWxLBVSzaPYKAfQOLHY4UI=";
    })
    # https://github.com/google-deepmind/rlax/pull/135
    (fetchpatch {
      name = "fix-jax-0.6.0-compat";
      url = "https://github.com/google-deepmind/rlax/commit/461b4cf9b4239d6b1b83aad6e5946f68d8402b93.patch";
      hash = "sha256-uPMpm4IcoBWJwnyuIRjQEfo0F9HIW/lrwecxGW/Yw38=";
    })
  ];

  build-system = [
    setuptools
  ];

  dependencies = [
    absl-py
    chex
    distrax
    dm-env
    jax
    jaxlib
    numpy
    tensorflow-probability
  ];

  nativeCheckInputs = [
    dm-haiku
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "rlax" ];

  # Tests skipped during the check phase, grouped by the failure they hit.
  disabledTests = [
    # AssertionError: Array(2, dtype=int32) != 0
    "test_categorical_sample__with_device"
    "test_categorical_sample__with_jit"
    "test_categorical_sample__without_device"
    "test_categorical_sample__without_jit"

    # RuntimeError: Attempted to set 4 devices, but 1 CPUs already available:
    # ensure that `set_n_cpu_devices` is executed before any JAX operation.
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    # Same GitHub organization as `src.owner`, the patch URLs and the changelog.
    homepage = "https://github.com/google-deepmind/rlax";
    changelog = "https://github.com/google-deepmind/rlax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}