# Snapshot taken at revision 23.05-pre (1.2 kB raw file).
# Package expression for the `rlax` Python library, version 0.1.4, built from
# its PyPI source distribution.
#
# The function arguments are the build helpers (`lib`, `fetchPypi`,
# `buildPythonPackage`, `pytestCheckHook`) and the Python packages this
# derivation depends on at runtime or for its test suite.
{ lib
, fetchPypi
, buildPythonPackage
, chex
, jaxlib
, tensorflow-probability
, optax
, dm-haiku
, bsuite
, frozendict
, pytestCheckHook
, dm-env
, distrax
}:

buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.4";

  src = fetchPypi {
    inherit pname version;
    # SRI-format hash; `hash` is the attribute meant for SRI strings
    # (`sha256` is the legacy spelling).
    hash = "sha256-a4qyJ5W9fs4TSTQQZS/NptlcSr2Nhw0pvnk+sGEsbyY=";
  };

  # Runtime Python dependencies. They must be propagated: with plain
  # `buildInputs` they are only visible while building, so a Python
  # environment that includes rlax would not get them and `import rlax`
  # would fail there.
  propagatedBuildInputs = [
    chex
    jaxlib
    distrax
    tensorflow-probability
  ];

  # Needed only to run the test suite. `pytestCheckHook` is a setup hook, so
  # these go into the native (build-platform) check inputs.
  nativeCheckInputs = [
    bsuite
    dm-env
    dm-haiku
    frozendict
    optax
    pytestCheckHook
  ];

  # Smoke test: the top-level module must be importable after installation.
  pythonImportsCheck = [
    "rlax"
  ];

  disabledTests = [
    # These tests fail with RuntimeErrors.
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = with lib; {
    description = "Library of reinforcement learning building blocks in JAX";
    homepage = "https://github.com/deepmind/rlax";
    license = licenses.asl20;
    maintainers = with maintainers; [ onny ];
  };
}