nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
# Nix expression packaging the `chex` Python library.
#
# The function takes its inputs (helpers, build backend, runtime and test
# dependencies) as arguments and returns the result of `buildPythonPackage`.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  jax,
  jaxlib,
  numpy,
  toolz,
  typing-extensions,

  # tests
  cloudpickle,
  dm-tree,
  pytestCheckHook,
}:

# `rec` lets `version` be referenced from `src.tag` and `meta.changelog` below.
buildPythonPackage rec {
  pname = "chex";
  version = "0.1.91";
  pyproject = true;

  src = fetchFromGitHub {
    # The changelog and issue-tracker links in this file use the
    # `google-deepmind` organisation; use the same owner here instead of
    # relying on a redirect from the former `deepmind` name.
    owner = "google-deepmind";
    repo = "chex";
    tag = "v${version}";
    hash = "sha256-lJ9+kvG7dRtfDVgvkcJ9/jtnX0lMfxY4mmZ290y/74U=";
  };

  build-system = [
    flit-core
  ];

  # Relax the version constraint upstream declares on typing_extensions.
  pythonRelaxDeps = [
    "typing_extensions"
  ];

  dependencies = [
    absl-py
    jax
    jaxlib
    numpy
    toolz
    typing-extensions
  ];

  # Smoke test: the top-level module must be importable after the build.
  pythonImportsCheck = [ "chex" ];

  nativeCheckInputs = [
    cloudpickle
    dm-tree
    pytestCheckHook
  ];

  disabledTests = [
    # Jax 0.8.2 incompatibility (reported at https://github.com/google-deepmind/chex/issues/422)
    # AssertionError: AssertionError not raised
    "test_assert_tree_is_on_device"
    # AssertionError: "\[Chex\]\ [\s\S]*sharded arrays are disallowed" does not match ...
    "test_assert_tree_is_on_host"
    # AssertionError: [Chex] Assertion assert_tree_is_sharded failed: ...
    "test_assert_tree_is_sharded"
  ];

  meta = {
    description = "Library of utilities for helping to write reliable JAX code";
    # Kept in sync with `src.owner` and `changelog`.
    homepage = "https://github.com/google-deepmind/chex";
    changelog = "https://github.com/google-deepmind/chex/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}