# Nix expression that builds the `torchrl` Python package from the
# pytorch/rl GitHub repository. All arguments are supplied by the caller
# (presumably via `callPackage` from the Python package set -- confirm).
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  ninja,
  setuptools,
  wheel,
  which,
  cloudpickle,
  numpy,
  torch,
  ale-py,
  gym,
  pygame,
  torchsnapshot,
  gymnasium,
  mujoco,
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  moviepy,
  git,
  hydra-core,
  tensorboard,
  wandb,
  packaging,
  tensordict,
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
  stdenv,
}:

buildPythonPackage rec {
  pname = "torchrl";
  version = "0.4.0";
  pyproject = true;

  disabled = pythonOlder "3.8";

  # The source is fetched from the release tag matching `version`.
  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    rev = "refs/tags/v${version}";
    hash = "sha256-8wSyyErqveP9zZS/UGvWVBYyylu9BuA447GEjXIzBIk=";
  };

  # Tools needed while building the package.
  build-system = [
    ninja
    setuptools
    wheel
    which
  ];

  # Runtime dependencies.
  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    torch
  ];

  # Optional feature groups. They are exposed through `passthru`, and the
  # `atari`, `gym-continuous` and `rendering` groups are reused below as
  # test inputs.
  passthru.optional-dependencies = {
    atari = [
      ale-py
      gym
      pygame
    ];
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    rendering = [ moviepy ];
    utils = [
      git
      hydra-core
      tensorboard
      tqdm
      wandb
    ];
  };

  # torchrl needs to create a folder to store datasets
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  pythonImportsCheck = [ "torchrl" ];

  # We have to delete the source because otherwise it is used instead of the installed package.
  # XDG_RUNTIME_DIR is additionally pointed at a fresh temporary directory.
  preCheck = ''
    rm -rf torchrl

    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  nativeCheckInputs =
    [
      gymnasium
      imageio
      pytest-rerunfailures
      pytestCheckHook
      pyyaml
      scipy
      torchvision
    ]
    ++ passthru.optional-dependencies.atari
    ++ passthru.optional-dependencies.gym-continuous
    ++ passthru.optional-dependencies.rendering;

  disabledTests = [
    # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
    "test_vecenvs_env"

    # ValueError: Can't write images with one color channel.
    "test_log_video"

    # These tests require the ALE environments (provided by unpackaged shimmy)
    "test_collector_env_reset"
    "test_gym"
    "test_gym_fake_td"
    "test_recorder"
    "test_recorder_load"
    "test_rollout"
    "test_parallel_trans_env_check"
    "test_serial_trans_env_check"
    "test_single_trans_env_check"
    "test_td_creation_from_spec"
    "test_trans_parallel_env_check"
    "test_trans_serial_env_check"
    "test_transform_env"

    # non-deterministic
    "test_distributed_collector_updatepolicy"
    "test_timeit"
  ];

  # `lib` is referenced explicitly instead of through a set-wide `with lib;`,
  # so every name used in this set has a visible origin.
  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    # ~3k tests fail with: RuntimeError: internal error
    broken = stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64;
  };
}