# Nix package expression for torchrl, built with buildPythonPackage from the
# GitHub sources of pytorch/rl at tag v${version}.
#
# The function arguments are grouped by the role they play in the derivation
# below: build-system, dependencies, optional-dependencies (extras) and
# test-only inputs.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  ninja,
  setuptools,
  which,

  # dependencies
  cloudpickle,
  numpy,
  packaging,
  tensordict,
  torch,

  # optional-dependencies
  ale-py,
  gym,
  pygame,
  torchsnapshot,
  gymnasium,
  mujoco,
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  moviepy,
  git,
  hydra-core,
  tensorboard,
  wandb,

  # checks
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:

buildPythonPackage rec {
  pname = "torchrl";
  version = "0.5.0";
  pyproject = true;

  # Sources are taken from the upstream git tag matching `version`.
  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    rev = "refs/tags/v${version}";
    hash = "sha256-uDpOdOuHTqKFKspHOpl84kD9adEKZjvO2GnYuL27H5c=";
  };

  build-system = [
    ninja
    setuptools
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    torch
  ];

  # Extras, keyed by name. Some of them are reused by nativeCheckInputs below.
  optional-dependencies = {
    atari = [
      ale-py
      gym
      pygame
    ];
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    rendering = [ moviepy ];
    utils = [
      git
      hydra-core
      tensorboard
      tqdm
      wandb
    ];
  };

  # torchrl needs to create a folder to store datasets
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  pythonImportsCheck = [ "torchrl" ];

  preCheck = ''
    # We have to delete the source because otherwise it is used instead of
    # the installed package.
    rm -rf torchrl

    # Point XDG_RUNTIME_DIR at a fresh temporary directory for the tests.
    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  # Test-only inputs, plus the atari, gym-continuous and rendering extras.
  nativeCheckInputs =
    [
      h5py
      gymnasium
      imageio
      pytest-rerunfailures
      pytestCheckHook
      pyyaml
      scipy
      torchvision
    ]
    ++ optional-dependencies.atari
    ++ optional-dependencies.gym-continuous
    ++ optional-dependencies.rendering;

  disabledTests = [
    # torchrl is incompatible with gymnasium>=1.0
    # https://github.com/pytorch/rl/discussions/2483
    "test_resetting_strategies"
    "test_torchrl_to_gym"

    # mujoco.FatalError: an OpenGL platform library has not been loaded into
    # this process, this most likely means that a valid OpenGL context has not
    # been created before mjr_makeContext was called
    "test_vecenvs_env"

    # ValueError: Can't write images with one color channel.
    "test_log_video"

    # Those tests require the ALE environments (provided by unpackaged shimmy)
    "test_collector_env_reset"
    "test_gym"
    "test_gym_fake_td"
    "test_recorder"
    "test_recorder_load"
    "test_rollout"
    "test_parallel_trans_env_check"
    "test_serial_trans_env_check"
    "test_single_trans_env_check"
    "test_td_creation_from_spec"
    "test_trans_parallel_env_check"
    "test_trans_serial_env_check"
    "test_transform_env"

    # Non-deterministic
    "test_distributed_collector_updatepolicy"
    "test_timeit"

    # On a 24 threads system
    # assert torch.get_num_threads() == max(1, init_threads - 3)
    # AssertionError: assert 23 == 21
    "test_auto_num_threads"

    # Flaky (hangs indefinitely on some CPUs)
    "test_gae_multidim"
    "test_gae_param_as_tensor"
  ];

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}