# Package definition for TorchRL (pname "torchrl"), built from the
# pytorch/rl GitHub repository with a pyproject-based build.
#
# The function takes its inputs from the surrounding package set: build
# tooling, runtime dependencies, the packages backing each optional extra,
# and the packages needed only by the test suite.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  ninja,
  setuptools,
  which,

  # dependencies
  cloudpickle,
  numpy,
  packaging,
  tensordict,
  torch,

  # optional-dependencies
  ale-py,
  gym,
  pygame,
  torchsnapshot,
  gymnasium,
  mujoco,
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  moviepy,
  git,
  hydra-core,
  tensorboard,
  wandb,

  # tests
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:

buildPythonPackage rec {
  pname = "torchrl";
  version = "0.8.0";
  pyproject = true;

  # Sources are fetched from the release tag matching `version`.
  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-icT+QeA2FNhZjwD0ykui4aq5WswDv2i1QRh7dNlA4Cg=";
  };

  build-system = [
    ninja
    setuptools
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    torch
  ];

  # One attribute per extra; each lists the packages that extra pulls in.
  optional-dependencies = {
    atari = [
      ale-py
      gym
      pygame
    ];
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    rendering = [ moviepy ];
    utils = [
      # NOTE(review): depending on the package set this is called from, `git`
      # may resolve to the git command-line tool rather than a Python module;
      # confirm this is the intended input.
      git
      hydra-core
      tensorboard
      tqdm
      wandb
    ];
  };

  # torchrl needs a directory in which it can store datasets, so point it at
  # a fresh temporary one.
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  pythonImportsCheck = [ "torchrl" ];

  # Remove the source tree before running the tests: if it is left in place
  # it gets imported instead of the installed package.
  preCheck = ''
    rm -rf torchrl

    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  # Test-only inputs, followed by the packages of the extras exercised by the
  # test suite.
  nativeCheckInputs = [
    h5py
    gymnasium
    imageio
    pytest-rerunfailures
    pytestCheckHook
    pyyaml
    scipy
    torchvision
  ]
  ++ optional-dependencies.atari
  ++ optional-dependencies.gym-continuous
  ++ optional-dependencies.rendering;

  # Tests skipped during the check phase. They are grouped by the reason they
  # fail and concatenated below; the last group only applies on aarch64-linux.
  disabledTests =
    let
      inherit (stdenv.hostPlatform) isLinux isAarch64;

      # torchrl is incompatible with gymnasium>=1.0
      # https://github.com/pytorch/rl/discussions/2483
      gymnasiumV1Failures = [
        "test_resetting_strategies"
        "test_torchrl_to_gym"
        "test_vecenvs_nan"
      ];

      # gym.error.VersionNotFound: Environment version `v5` for environment
      # `HalfCheetah` doesn't exist.
      missingEnvVersionFailures = [
        "test_collector_run"
        "test_transform_inverse"
      ];

      # OSError: Unable to synchronously create file (unable to truncate a
      # file which is already open)
      fileCreationFailures = [
        "test_multi_env"
        "test_simple_env"
      ];

      # ImportWarning: Ignoring non-library in plugin directory:
      # /nix/store/cy8vwf1dacp3xfwnp9v6a1sz8bic8ylx-python3.12-mujoco-3.3.2/lib/python3.12/site-packages/mujoco/plugin/libmujoco.so.3.3.2
      mujocoPluginWarningFailures = [
        "test_auto_register"
        "test_info_dict_reader"
      ];

      # mujoco.FatalError: an OpenGL platform library has not been loaded into
      # this process, this most likely means that a valid OpenGL context has
      # not been created before mjr_makeContext was called
      openGLFailures = [
        "test_vecenvs_env"
      ];

      # ValueError: Can't write images with one color channel.
      videoLoggingFailures = [
        "test_log_video"
      ];

      # These tests require the ALE environments (provided by shimmy, which
      # is not packaged)
      aleDependentTests = [
        "test_collector_env_reset"
        "test_gym"
        "test_gym_fake_td"
        "test_recorder"
        "test_recorder_load"
        "test_rollout"
        "test_parallel_trans_env_check"
        "test_serial_trans_env_check"
        "test_single_trans_env_check"
        "test_td_creation_from_spec"
        "test_trans_parallel_env_check"
        "test_trans_serial_env_check"
        "test_transform_env"
      ];

      # Non-deterministic
      nondeterministicTests = [
        "test_distributed_collector_updatepolicy"
        "test_timeit"
      ];

      # On a 24 threads system:
      #   assert torch.get_num_threads() == max(1, init_threads - 3)
      #   AssertionError: assert 23 == 21
      threadCountTests = [
        "test_auto_num_threads"
      ];

      # Flaky (hangs indefinitely on some CPUs)
      hangingTests = [
        "test_gae_multidim"
        "test_gae_param_as_tensor"
      ];

      # Flaky
      # AssertionError: assert tensor([51.]) == ((5 * 11) + 2)
      aarch64LinuxFlakyTests = [
        "test_vecnorm_parallel_auto"
      ];
    in
    gymnasiumV1Failures
    ++ missingEnvVersionFailures
    ++ fileCreationFailures
    ++ mujocoPluginWarningFailures
    ++ openGLFailures
    ++ videoLoggingFailures
    ++ aleDependentTests
    ++ nondeterministicTests
    ++ threadCountTests
    ++ hangingTests
    ++ lib.optionals (isLinux && isAarch64) aarch64LinuxFlakyTests;

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}