# Snapshot of this nixpkgs file at revision 25.11-pre (4.5 kB).
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  ninja,
  setuptools,
  which,

  # dependencies
  cloudpickle,
  numpy,
  packaging,
  tensordict,
  torch,

  # optional-dependencies
  ale-py,
  gym,
  pygame,
  torchsnapshot,
  gymnasium,
  mujoco,
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  moviepy,
  git,
  hydra-core,
  tensorboard,
  wandb,

  # tests
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:

# `rec` lets nativeCheckInputs refer to `optional-dependencies` and meta refer to `src`.
buildPythonPackage rec {
  pname = "torchrl";
  version = "0.8.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-icT+QeA2FNhZjwD0ykui4aq5WswDv2i1QRh7dNlA4Cg=";
  };

  build-system = [
    ninja
    setuptools
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    torch
  ];

  optional-dependencies = {
    atari = [
      ale-py
      gym
      pygame
    ];
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    rendering = [ moviepy ];
    utils = [
      git
      hydra-core
      tensorboard
      tqdm
      wandb
    ];
  };

  # torchrl needs to create a folder to store datasets
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  pythonImportsCheck = [ "torchrl" ];

  preCheck = ''
    # We have to delete the source because otherwise it is used instead of the installed package.
    rm -rf torchrl

    # NOTE(review): presumably needed by the rendering/GUI test dependencies — confirm.
    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  # gymnasium is provided by optional-dependencies.gym-continuous, so it is not repeated here.
  nativeCheckInputs =
    [
      h5py
      imageio
      pytest-rerunfailures
      pytestCheckHook
      pyyaml
      scipy
      torchvision
    ]
    ++ optional-dependencies.atari
    ++ optional-dependencies.gym-continuous
    ++ optional-dependencies.rendering;

  disabledTests =
    [
      # torchrl is incompatible with gymnasium>=1.0
      # https://github.com/pytorch/rl/discussions/2483
      "test_resetting_strategies"
      "test_torchrl_to_gym"
      "test_vecenvs_nan"

      # gym.error.VersionNotFound: Environment version `v5` for environment `HalfCheetah` doesn't exist.
      "test_collector_run"
      "test_transform_inverse"

      # OSError: Unable to synchronously create file (unable to truncate a file which is already open)
      "test_multi_env"
      "test_simple_env"

      # ImportWarning: Ignoring non-library in plugin directory:
      # /nix/store/cy8vwf1dacp3xfwnp9v6a1sz8bic8ylx-python3.12-mujoco-3.3.2/lib/python3.12/site-packages/mujoco/plugin/libmujoco.so.3.3.2
      "test_auto_register"
      "test_info_dict_reader"

      # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
      "test_vecenvs_env"

      # ValueError: Can't write images with one color channel.
      "test_log_video"

      # These tests require the ALE environments (provided by the unpackaged shimmy)
      "test_collector_env_reset"
      "test_gym"
      "test_gym_fake_td"
      "test_recorder"
      "test_recorder_load"
      "test_rollout"
      "test_parallel_trans_env_check"
      "test_serial_trans_env_check"
      "test_single_trans_env_check"
      "test_td_creation_from_spec"
      "test_trans_parallel_env_check"
      "test_trans_serial_env_check"
      "test_transform_env"

      # Nondeterministic
      "test_distributed_collector_updatepolicy"
      "test_timeit"

      # On a 24 threads system
      # assert torch.get_num_threads() == max(1, init_threads - 3)
      # AssertionError: assert 23 == 21
      "test_auto_num_threads"

      # Flaky (hangs indefinitely on some CPUs)
      "test_gae_multidim"
      "test_gae_param_as_tensor"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
      # Flaky
      # AssertionError: assert tensor([51.]) == ((5 * 11) + 2)
      "test_vecnorm_parallel_auto"
    ];

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    # Derived from the fetched tag so the URL cannot drift from `src`.
    changelog = "https://github.com/pytorch/rl/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}