Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
at r-updates 5.8 kB view raw
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # inputs for build-system
  ninja,
  setuptools,
  which,

  # inputs for dependencies
  cloudpickle,
  numpy,
  packaging,
  tensordict,
  torch,

  # inputs for optional-dependencies, grouped by extra
  # - atari
  gymnasium,
  # - checkpointing
  torchsnapshot,
  # - gym-continuous
  mujoco,
  # - llm
  accelerate,
  datasets,
  einops,
  immutabledict,
  langdetect,
  nltk,
  playwright,
  protobuf,
  safetensors,
  sentencepiece,
  transformers,
  vllm,
  # - offline-data
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  # - rendering
  moviepy,
  # - utils
  git,
  hydra-core,
  tensorboard,
  wandb,

  # inputs used only by the check phase
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:

let
  # Platform flags consulted when extending the list of disabled tests.
  inherit (stdenv.hostPlatform) isLinux isAarch64;
in
buildPythonPackage rec {
  pname = "torchrl";
  version = "0.9.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-6rU5+J70T0E7+60jihsjwlLls8jJlxKi3nmrL0xm2c0=";
  };

  # Point D4RL_DATASET_DIR at a fresh temporary directory, since torchrl
  # needs to create a folder to store datasets.
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  build-system = [
    ninja
    setuptools
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    torch
  ];

  optional-dependencies = {
    # The atari extra is taken verbatim from gymnasium's own atari extra.
    inherit (gymnasium.optional-dependencies) atari;
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    llm = [
      accelerate
      datasets
      einops
      immutabledict
      langdetect
      nltk
      playwright
      protobuf
      safetensors
      sentencepiece
      transformers
      vllm
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    rendering = [ moviepy ];
    utils = [
      git
      hydra-core
      tensorboard
      tqdm
      wandb
    ];
  };

  pythonImportsCheck = [ "torchrl" ];

  nativeCheckInputs = [
    h5py
    gymnasium
    imageio
    pytest-rerunfailures
    pytestCheckHook
    pyyaml
    scipy
    torchvision
  ]
  # Append the contents of these extras, in this order.
  ++ lib.concatMap (extra: optional-dependencies.${extra}) [
    "atari"
    "gym-continuous"
    "llm"
    "rendering"
  ];

  # The source directory is removed first: if it stays around it is used
  # instead of the installed package.
  preCheck = ''
    rm -rf torchrl

    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  disabledTests = [
    # Network access is required
    "test_create_or_load_dataset"
    "test_from_text_env_tokenizer"
    "test_from_text_env_tokenizer_catframes"
    "test_from_text_rb_slicesampler"
    "test_generate"
    "test_get_dataloader"
    "test_get_scores"
    "test_preproc_data"
    "test_prompt_tensordict_tokenizer"
    "test_reward_model"
    "test_tensordict_tokenizer"
    "test_transform_compose"
    "test_transform_model"
    "test_transform_no_env"
    "test_transform_rb"

    # ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
    "TestRayCollector"

    # torchrl is incompatible with gymnasium>=1.0
    # https://github.com/pytorch/rl/discussions/2483
    "test_resetting_strategies"
    "test_torchrl_to_gym"
    "test_vecenvs_nan"

    # gym.error.VersionNotFound: Environment version `v5` for environment `HalfCheetah` doesn't exist.
    "test_collector_run"
    "test_transform_inverse"

    # OSError: Unable to synchronously create file (unable to truncate a file which is already open)
    "test_multi_env"
    "test_simple_env"

    # ImportWarning: Ignoring non-library in plugin directory:
    # /nix/store/cy8vwf1dacp3xfwnp9v6a1sz8bic8ylx-python3.12-mujoco-3.3.2/lib/python3.12/site-packages/mujoco/plugin/libmujoco.so.3.3.2
    "test_auto_register"
    "test_info_dict_reader"

    # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
    "test_vecenvs_env"

    # ValueError: Can't write images with one color channel.
    "test_log_video"

    # These tests need the ALE environments (provided by shimmy, which is not packaged)
    "test_collector_env_reset"
    "test_gym"
    "test_gym_fake_td"
    "test_recorder"
    "test_recorder_load"
    "test_rollout"
    "test_parallel_trans_env_check"
    "test_serial_trans_env_check"
    "test_single_trans_env_check"
    "test_td_creation_from_spec"
    "test_trans_parallel_env_check"
    "test_trans_serial_env_check"
    "test_transform_env"

    # Non-deterministic
    "test_distributed_collector_updatepolicy"
    "test_timeit"

    # Observed on a 24-thread system:
    # assert torch.get_num_threads() == max(1, init_threads - 3)
    # AssertionError: assert 23 == 21
    "test_auto_num_threads"

    # Flaky (hangs indefinitely on some CPUs)
    "test_gae_multidim"
    "test_gae_param_as_tensor"
  ]
  ++ lib.optionals (isLinux && isAarch64) [
    # Flaky
    # AssertionError: assert tensor([51.]) == ((5 * 11) + 2)
    "test_vecnorm_parallel_auto"
  ];

  disabledTestPaths = [
    # Collecting this directory clashes with the top-level smoke test:
    # ERROR collecting test/smoke_test.py
    # import file mismatch:
    # imported module 'smoke_test' has this __file__ attribute:
    #   /build/source/test/llm/smoke_test.py
    # which is not the same as the test file we want to collect:
    #   /build/source/test/smoke_test.py
    "test/llm"
  ];

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}