nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at python-updates 89 lines 1.8 kB view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  chex,
  distrax,
  dm-env,
  jax,
  jaxlib,
  numpy,

  # tests
  dm-haiku,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

# rlax: library of reinforcement learning building blocks in JAX.
# Built with flit-core from the tagged upstream release on GitHub.
buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.8";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "rlax";
    tag = "v${version}";
    hash = "sha256-E/zYFd5bfx58FfA3uR7hzRAIs844QzJA8TZTwmwDByk=";
  };

  build-system = [
    flit-core
  ];

  dependencies = [
    absl-py
    chex
    distrax
    dm-env
    jax
    jaxlib
    numpy
  ];

  nativeCheckInputs = [
    dm-haiku
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "rlax" ];

  # Tests excluded from the check phase because they fail with the errors noted below.
  disabledTests = [
    # AssertionError: Array(2, dtype=int32) != 0
    "test_categorical_sample__with_device"
    "test_categorical_sample__with_jit"
    "test_categorical_sample__without_device"
    "test_categorical_sample__without_jit"

    # RuntimeError: Attempted to set 4 devices, but 1 CPUs already available:
    # ensure that `set_n_cpu_devices` is executed before any JAX operation.
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    # Same GitHub organization that `src` fetches from and `changelog` links to.
    homepage = "https://github.com/google-deepmind/rlax";
    changelog = "https://github.com/google-deepmind/rlax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}