Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# Python package definition for rlax, built with buildPythonPackage from a
# tagged GitHub source plus one upstream patch.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  absl-py,
  chex,
  distrax,
  dm-env,
  jax,
  jaxlib,
  numpy,
  tensorflow-probability,
  dm-haiku,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.6";
  format = "setuptools";

  # Source is pinned to the git tag derived from `version` (hence `rec`).
  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "rlax";
    rev = "refs/tags/v${version}";
    hash = "sha256-v2Lbzya+E9d7tlUVlQQa4fuPp2q3E309Qvyt70mcdb0=";
  };

  patches = [
    (fetchpatch {
      # Follow chex API change (https://github.com/google-deepmind/chex/pull/52)
      name = "replace-deprecated-chex-assertions";
      url = "https://github.com/google-deepmind/rlax/commit/30e7913a1102667137654d6e652a6c4b9e9ba1f4.patch";
      hash = "sha256-OPnuTKEtwZ28hzR1660v3DcktxTYjhR1xYvFbQvOhgs=";
    })
  ];

  # Runtime Python dependencies, propagated to anything that depends on rlax.
  propagatedBuildInputs = [
    absl-py
    chex
    distrax
    dm-env
    jax
    jaxlib
    numpy
    tensorflow-probability
  ];

  # Dependencies needed only while running the test suite.
  nativeCheckInputs = [
    dm-haiku
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "rlax" ];

  disabledTests = [
    # RuntimeErrors
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  # `lib` is referenced explicitly rather than via `meta = with lib;`, so no
  # unrelated `lib` attributes leak into the scope of the meta attribute set.
  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    # Uses the same GitHub organisation as `src.owner` and the patch URL.
    homepage = "https://github.com/google-deepmind/rlax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}