nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at python-updates 76 lines 1.5 kB view raw
# Package expression for the `chex` Python library (utilities for writing
# reliable JAX code), built with flit-core from the upstream GitHub tag.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  jax,
  jaxlib,
  numpy,
  toolz,
  typing-extensions,

  # tests
  cloudpickle,
  dm-tree,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "chex";
  version = "0.1.91";
  pyproject = true;

  src = fetchFromGitHub {
    # Canonical upstream org is `google-deepmind` (the issue and changelog
    # links below point there). The old `deepmind` name only redirects.
    # Only the URL changes; the fetched tree and its hash are the same.
    owner = "google-deepmind";
    repo = "chex";
    tag = "v${version}";
    hash = "sha256-lJ9+kvG7dRtfDVgvkcJ9/jtnX0lMfxY4mmZ290y/74U=";
  };

  build-system = [
    flit-core
  ];

  # Drop upstream's version bound on typing_extensions so the version shipped
  # in this package set is accepted.
  pythonRelaxDeps = [
    "typing_extensions"
  ];

  dependencies = [
    absl-py
    jax
    jaxlib
    numpy
    toolz
    typing-extensions
  ];

  # Smoke test: the built package must be importable as `chex`.
  pythonImportsCheck = [ "chex" ];

  nativeCheckInputs = [
    cloudpickle
    dm-tree
    pytestCheckHook
  ];

  disabledTests = [
    # Jax 0.8.2 incompatibility (reported at https://github.com/google-deepmind/chex/issues/422)
    # AssertionError: AssertionError not raised
    "test_assert_tree_is_on_device"
    # AssertionError: "\[Chex\]\ [\s\S]*sharded arrays are disallowed" does not match ...
    "test_assert_tree_is_on_host"
    # AssertionError: [Chex] Assertion assert_tree_is_sharded failed: ...
    "test_assert_tree_is_sharded"
  ];

  meta = {
    description = "Library of utilities for helping to write reliable JAX code";
    homepage = "https://github.com/google-deepmind/chex";
    # Reuse the fetched tag so the changelog link cannot drift from `src`.
    changelog = "https://github.com/google-deepmind/chex/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}