Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
at r-updates 3.1 kB view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  chex,
  distrax,
  dm-env,
  jax,
  jaxlib,
  numpy,
  tensorflow-probability,

  # tests
  dm-haiku,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

# rlax: library of reinforcement learning building blocks in JAX, built from
# the upstream google-deepmind/rlax git tag `v${version}` with a few
# already-merged upstream commits applied on top.
buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.7";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "rlax";
    tag = "v${version}";
    hash = "sha256-w5vhXBMUlcqlLTKA58QgQ4pxyGs3etxJLIFUVPhE7H8=";
  };

  # TODO: remove these patches at the next release (already on master)
  patches = [
    (fetchpatch {
      # Follow chex API change (https://github.com/google-deepmind/chex/pull/52)
      name = "replace-deprecated-chex-assertions";
      url = "https://github.com/google-deepmind/rlax/commit/30e7913a1102667137654d6e652a6c4b9e9ba1f4.patch";
      hash = "sha256-OPnuTKEtwZ28hzR1660v3DcktxTYjhR1xYvFbQvOhgs=";
    })
    (fetchpatch {
      name = "remove-deprecation-warning";
      url = "https://github.com/google-deepmind/rlax/commit/dea6eb479ffc32156aefe73015387a762c6b4562.patch";
      hash = "sha256-htDyDRJW0eQx7AmrS3Fl7Lbh2VAmoYiDgHSePsQUaWs=";
    })
    (fetchpatch {
      name = "fix-deprecation-warnings";
      url = "https://github.com/google-deepmind/rlax/commit/605e0ef8ad8f9a06e88d4aabbb7d50e086d0cf3a.patch";
      hash = "sha256-GZ/nGMXne6Lv6yDm/29NVTWxLBVSzaPYKAfQOLHY4UI=";
    })
    # https://github.com/google-deepmind/rlax/pull/135
    (fetchpatch {
      name = "fix-jax-0.6.0-compat";
      url = "https://github.com/google-deepmind/rlax/commit/461b4cf9b4239d6b1b83aad6e5946f68d8402b93.patch";
      hash = "sha256-uPMpm4IcoBWJwnyuIRjQEfo0F9HIW/lrwecxGW/Yw38=";
    })
  ];

  build-system = [
    setuptools
  ];

  dependencies = [
    absl-py
    chex
    distrax
    dm-env
    jax
    jaxlib
    numpy
    tensorflow-probability
  ];

  nativeCheckInputs = [
    dm-haiku
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "rlax" ];

  disabledTests = [
    # AssertionError: Array(2, dtype=int32) != 0
    "test_categorical_sample__with_device"
    "test_categorical_sample__with_jit"
    "test_categorical_sample__without_device"
    "test_categorical_sample__without_jit"

    # RuntimeError: Attempted to set 4 devices, but 1 CPUs already available:
    # ensure that `set_n_cpu_devices` is executed before any JAX operation.
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    # Same google-deepmind owner as `src`, the patch URLs and `changelog`.
    homepage = "https://github.com/google-deepmind/rlax";
    changelog = "https://github.com/google-deepmind/rlax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}