# Nix package definition for torchrl, built from the pytorch/rl GitHub repository.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  ninja,
  setuptools,
  which,

  # dependencies
  cloudpickle,
  numpy,
  packaging,
  tensordict,
  torch,

  # optional-dependencies
  # atari
  gymnasium,
  # checkpointing
  torchsnapshot,
  # gym-continuous
  mujoco,
  # llm
  accelerate,
  datasets,
  einops,
  immutabledict,
  langdetect,
  nltk,
  playwright,
  protobuf,
  safetensors,
  sentencepiece,
  transformers,
  vllm,
  # offline-data
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  # rendering
  moviepy,
  # utils
  git,
  hydra-core,
  tensorboard,
  wandb,

  # tests
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:

# `rec` is required: `version` is interpolated into `src.tag` and
# `meta.changelog`, and `optional-dependencies` is reused by `nativeCheckInputs`.
buildPythonPackage rec {
  pname = "torchrl";
  version = "0.9.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-6rU5+J70T0E7+60jihsjwlLls8jJlxKi3nmrL0xm2c0=";
  };

  build-system = [
    ninja
    setuptools
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    torch
  ];

  # One attribute per extra; the group names match the comments in the
  # argument list above.
  optional-dependencies = {
    # Reuse gymnasium's own `atari` extra instead of listing its contents here.
    atari = gymnasium.optional-dependencies.atari;
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    llm = [
      accelerate
      datasets
      einops
      immutabledict
      langdetect
      nltk
      playwright
      protobuf
      safetensors
      sentencepiece
      transformers
      vllm
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    rendering = [ moviepy ];
    utils = [
      # NOTE(review): `git` is whatever the caller's package set passes under
      # that name; confirm it is the intended Python dependency and not the
      # git command-line tool.
      git
      hydra-core
      tensorboard
      tqdm
      wandb
    ];
  };

  # torchrl needs to create a folder to store datasets
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  pythonImportsCheck = [ "torchrl" ];

  # We have to delete the source because otherwise it is used instead of the installed package.
  preCheck = ''
    rm -rf torchrl

    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  # Test-only inputs, followed by the atari, gym-continuous, llm and rendering
  # extras defined above.
  nativeCheckInputs = [
    h5py
    gymnasium
    imageio
    pytest-rerunfailures
    pytestCheckHook
    pyyaml
    scipy
    torchvision
  ]
  ++ optional-dependencies.atari
  ++ optional-dependencies.gym-continuous
  ++ optional-dependencies.llm
  ++ optional-dependencies.rendering;

  disabledTests = [
    # Require network
    "test_create_or_load_dataset"
    "test_from_text_env_tokenizer"
    "test_from_text_env_tokenizer_catframes"
    "test_from_text_rb_slicesampler"
    "test_generate"
    "test_get_dataloader"
    "test_get_scores"
    "test_preproc_data"
    "test_prompt_tensordict_tokenizer"
    "test_reward_model"
    "test_tensordict_tokenizer"
    "test_transform_compose"
    "test_transform_model"
    "test_transform_no_env"
    "test_transform_rb"

    # ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
    "TestRayCollector"

    # torchrl is incompatible with gymnasium>=1.0
    # https://github.com/pytorch/rl/discussions/2483
    "test_resetting_strategies"
    "test_torchrl_to_gym"
    "test_vecenvs_nan"

    # gym.error.VersionNotFound: Environment version `v5` for environment `HalfCheetah` doesn't exist.
    "test_collector_run"
    "test_transform_inverse"

    # OSError: Unable to synchronously create file (unable to truncate a file which is already open)
    "test_multi_env"
    "test_simple_env"

    # ImportWarning: Ignoring non-library in plugin directory:
    # /nix/store/cy8vwf1dacp3xfwnp9v6a1sz8bic8ylx-python3.12-mujoco-3.3.2/lib/python3.12/site-packages/mujoco/plugin/libmujoco.so.3.3.2
    "test_auto_register"
    "test_info_dict_reader"

    # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
    "test_vecenvs_env"

    # ValueError: Can't write images with one color channel.
    "test_log_video"

    # Those tests require the ALE environments (provided by unpackaged shimmy)
    "test_collector_env_reset"
    "test_gym"
    "test_gym_fake_td"
    "test_recorder"
    "test_recorder_load"
    "test_rollout"
    "test_parallel_trans_env_check"
    "test_serial_trans_env_check"
    "test_single_trans_env_check"
    "test_td_creation_from_spec"
    "test_trans_parallel_env_check"
    "test_trans_serial_env_check"
    "test_transform_env"

    # Non-deterministic
    "test_distributed_collector_updatepolicy"
    "test_timeit"

    # On a 24 threads system
    # assert torch.get_num_threads() == max(1, init_threads - 3)
    # AssertionError: assert 23 == 21
    "test_auto_num_threads"

    # Flaky (hangs indefinitely on some CPUs)
    "test_gae_multidim"
    "test_gae_param_as_tensor"
  ]
  ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # Flaky
    # AssertionError: assert tensor([51.]) == ((5 * 11) + 2)
    "test_vecnorm_parallel_auto"
  ];

  disabledTestPaths = [
    # ERROR collecting test/smoke_test.py
    # import file mismatch:
    # imported module 'smoke_test' has this __file__ attribute:
    #   /build/source/test/llm/smoke_test.py
    # which is not the same as the test file we want to collect:
    #   /build/source/test/smoke_test.py
    "test/llm"
  ];

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}