Merge pull request #275195 from GaetanLepage/torchrl

python311Packages.torchrl: init at 0.2.1

authored by Someone and committed by GitHub b63791c6 39b0be0d

+256 -5
+18 -5
pkgs/development/python-modules/ale-py/default.nix
··· 2 2 , SDL2 3 3 , cmake 4 4 , fetchFromGitHub 5 - , git 5 + , fetchpatch 6 6 , gym 7 7 , importlib-metadata 8 8 , importlib-resources ··· 11 11 , numpy 12 12 , pybind11 13 13 , pytestCheckHook 14 - , python 15 14 , pythonOlder 16 15 , setuptools 17 16 , stdenv ··· 23 22 buildPythonPackage rec { 24 23 pname = "ale-py"; 25 24 version = "0.8.1"; 26 - format = "pyproject"; 25 + pyproject = true; 27 26 28 27 src = fetchFromGitHub { 29 - owner = "mgbellemare"; 28 + owner = "Farama-Foundation"; 30 29 repo = "Arcade-Learning-Environment"; 31 30 rev = "refs/tags/v${version}"; 32 31 hash = "sha256-B2AxhlzvBy1lJ3JttJjImgTjMtEUyZBv+xHU2IC7BVE="; ··· 35 34 patches = [ 36 35 # don't download pybind11, use local pybind11 37 36 ./cmake-pybind11.patch 37 + ./patch-sha-check-in-setup.patch 38 + 39 + # The following two patches add the required `include <cstdint>` for compilation to work with GCC 13. 40 + # See https://github.com/Farama-Foundation/Arcade-Learning-Environment/pull/503 41 + (fetchpatch { 42 + name = "fix-gcc13-compilation-1"; 43 + url = "https://github.com/Farama-Foundation/Arcade-Learning-Environment/commit/ebd64c03cdaa3d8df7da7c62ec3ae5795105e27a.patch"; 44 + hash = "sha256-NMz0hw8USOj88WryHRkMQNWznnP6+5aWovEYNuocQ2c="; 45 + }) 46 + (fetchpatch { 47 + name = "fix-gcc13-compilation-2"; 48 + url = "https://github.com/Farama-Foundation/Arcade-Learning-Environment/commit/4c99c7034f17810f3ff6c27436bfc3b40d08da21.patch"; 49 + hash = "sha256-66/bDCyMr1RsKk63T9GnFZGloLlkdr/bf5WHtWbX6VY="; 50 + }) 38 51 ]; 39 52 40 53 nativeBuildInputs = [ ··· 67 80 substituteInPlace pyproject.toml \ 68 81 --replace 'dynamic = ["version"]' 'version = "${version}"' 69 82 substituteInPlace setup.py \ 70 - --replace 'subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], cwd=here)' 'b"${src.rev}"' 83 + --replace '@sha@' '"${version}"' 71 84 ''; 72 85 73 86 dontUseCmakeConfigure = true;
+17
pkgs/development/python-modules/ale-py/patch-sha-check-in-setup.patch
··· 1 + diff --git a/setup.py b/setup.py 2 + index ff1b1c5..ce40df0 100644 3 + --- a/setup.py 4 + +++ b/setup.py 5 + @@ -141,11 +141,7 @@ def parse_version(version_file): 6 + 7 + version = ci_version 8 + else: 9 + - sha = ( 10 + - subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], cwd=here) 11 + - .decode("ascii") 12 + - .strip() 13 + - ) 14 + + sha = @sha@ 15 + version += f"+{sha}" 16 + 17 + return version
+63
pkgs/development/python-modules/tensordict/default.nix
··· 1 + { lib 2 + , buildPythonPackage 3 + , pythonOlder 4 + , fetchFromGitHub 5 + , setuptools 6 + , torch 7 + , wheel 8 + , which 9 + , cloudpickle 10 + , numpy 11 + , h5py 12 + , pytestCheckHook 13 + }: 14 + 15 + buildPythonPackage rec { 16 + pname = "tensordict"; 17 + version = "0.2.1"; 18 + pyproject = true; 19 + 20 + disabled = pythonOlder "3.8"; 21 + 22 + src = fetchFromGitHub { 23 + owner = "pytorch"; 24 + repo = "tensordict"; 25 + rev = "refs/tags/v${version}"; 26 + hash = "sha256-+Osoz1632F/dEkG/o8RUqCIDok2Qc9Qdak+CCr9m26g="; 27 + }; 28 + 29 + nativeBuildInputs = [ 30 + setuptools 31 + torch 32 + wheel 33 + which 34 + ]; 35 + 36 + propagatedBuildInputs = [ 37 + cloudpickle 38 + numpy 39 + torch 40 + ]; 41 + 42 + pythonImportsCheck = [ 43 + "tensordict" 44 + ]; 45 + 46 + # We have to delete the source because otherwise it is used instead of the installed package. 47 + preCheck = '' 48 + rm -rf tensordict 49 + ''; 50 + 51 + nativeCheckInputs = [ 52 + h5py 53 + pytestCheckHook 54 + ]; 55 + 56 + meta = with lib; { 57 + description = "A pytorch dedicated tensor container"; 58 + changelog = "https://github.com/pytorch/tensordict/releases/tag/v${version}"; 59 + homepage = "https://github.com/pytorch/tensordict"; 60 + license = licenses.mit; 61 + maintainers = with maintainers; [ GaetanLepage ]; 62 + }; 63 + }
+154
pkgs/development/python-modules/torchrl/default.nix
··· 1 + { lib 2 + , buildPythonPackage 3 + , pythonOlder 4 + , fetchFromGitHub 5 + , fetchpatch 6 + , ninja 7 + , setuptools 8 + , wheel 9 + , which 10 + , cloudpickle 11 + , numpy 12 + , torch 13 + , ale-py 14 + , gym 15 + , pygame 16 + , gymnasium 17 + , mujoco 18 + , moviepy 19 + , git 20 + , hydra-core 21 + , tensorboard 22 + , tqdm 23 + , wandb 24 + , packaging 25 + , tensordict 26 + , imageio 27 + , pytest-rerunfailures 28 + , pytestCheckHook 29 + , pyyaml 30 + , scipy 31 + }: 32 + 33 + buildPythonPackage rec { 34 + pname = "torchrl"; 35 + version = "0.2.1"; 36 + pyproject = true; 37 + 38 + disabled = pythonOlder "3.8"; 39 + 40 + src = fetchFromGitHub { 41 + owner = "pytorch"; 42 + repo = "rl"; 43 + rev = "refs/tags/v${version}"; 44 + hash = "sha256-Y3WbSMGXS6fb4RyXk2SAKHT6RencGTZXM3tc65AQx74="; 45 + }; 46 + 47 + patches = [ 48 + (fetchpatch { # https://github.com/pytorch/rl/pull/1828 49 + name = "pyproject.toml-remove-unknown-properties"; 50 + url = "https://github.com/pytorch/rl/commit/c390cf602fc79cb37d5f7bda6e44b5e9546ecda0.patch"; 51 + hash = "sha256-cUBBvKJ8vIHprcGzMojkUxcOrrmNPIoIBfLwHXWkjOc="; 52 + }) 53 + ]; 54 + 55 + nativeBuildInputs = [ 56 + ninja 57 + setuptools 58 + wheel 59 + which 60 + ]; 61 + 62 + propagatedBuildInputs = [ 63 + cloudpickle 64 + numpy 65 + packaging 66 + tensordict 67 + torch 68 + ]; 69 + 70 + passthru.optional-dependencies = { 71 + atari = [ 72 + ale-py 73 + gym 74 + pygame 75 + ]; 76 + gym-continuous = [ 77 + gymnasium 78 + mujoco 79 + ]; 80 + rendering = [ 81 + moviepy 82 + ]; 83 + utils = [ 84 + git 85 + hydra-core 86 + tensorboard 87 + tqdm 88 + wandb 89 + ]; 90 + }; 91 + 92 + # torchrl needs to create a folder to store datasets 93 + preBuild = '' 94 + export D4RL_DATASET_DIR=$(mktemp -d) 95 + ''; 96 + 97 + pythonImportsCheck = [ 98 + "torchrl" 99 + ]; 100 + 101 + # We have to delete the source because otherwise it is used instead of the installed package. 102 + preCheck = '' 103 + rm -rf torchrl 104 + 105 + export XDG_RUNTIME_DIR=$(mktemp -d) 106 + '' 107 + # Otherwise, tochrl will try to use unpackaged torchsnapshot. 108 + # TODO: This should be the default from next release so remove when updating from 0.2.1 109 + + '' 110 + export CKPT_BACKEND="torch" 111 + ''; 112 + 113 + nativeCheckInputs = [ 114 + gymnasium 115 + imageio 116 + pytest-rerunfailures 117 + pytestCheckHook 118 + pyyaml 119 + scipy 120 + ] 121 + ++ passthru.optional-dependencies.atari 122 + ++ passthru.optional-dependencies.gym-continuous 123 + ++ passthru.optional-dependencies.rendering; 124 + 125 + disabledTests = [ 126 + # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called 127 + "test_vecenvs_env" 128 + 129 + # ValueError: Can't write images with one color channel. 130 + "test_log_video" 131 + 132 + # Those tests require the ALE environments (provided by unpackaged shimmy) 133 + "test_collector_env_reset" 134 + "test_gym" 135 + "test_gym_fake_td" 136 + "test_recorder" 137 + "test_recorder_load" 138 + "test_rollout" 139 + "test_parallel_trans_env_check" 140 + "test_serial_trans_env_check" 141 + "test_single_trans_env_check" 142 + "test_td_creation_from_spec" 143 + "test_trans_parallel_env_check" 144 + "test_trans_serial_env_check" 145 + "test_transform_env" 146 + ]; 147 + 148 + meta = with lib; { 149 + description = "A modular, primitive-first, python-first PyTorch library for Reinforcement Learning"; 150 + homepage = "https://github.com/pytorch/rl"; 151 + license = licenses.mit; 152 + maintainers = with maintainers; [ GaetanLepage ]; 153 + }; 154 + }
+4
pkgs/top-level/python-packages.nix
··· 14144 14144 14145 14145 tensorboardx = callPackage ../development/python-modules/tensorboardx { }; 14146 14146 14147 + tensordict = callPackage ../development/python-modules/tensordict { }; 14148 + 14147 14149 tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix { 14148 14150 inherit (pkgs.config) cudaSupport; 14149 14151 }; ··· 14547 14549 torchinfo = callPackage ../development/python-modules/torchinfo { }; 14548 14550 14549 14551 torchlibrosa = callPackage ../development/python-modules/torchlibrosa { }; 14552 + 14553 + torchrl = callPackage ../development/python-modules/torchrl { }; 14550 14554 14551 14555 torchsde = callPackage ../development/python-modules/torchsde { }; 14552 14556