{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  etils,
  humanize,
  importlib-resources,
  jax,
  msgpack,
  nest-asyncio,
  numpy,
  protobuf,
  pyyaml,
  simplejson,
  tensorstore,
  typing-extensions,

  # tests
  aiofiles,
  chex,
  google-cloud-logging,
  mock,
  optax,
  portpicker,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "orbax-checkpoint";
  version = "0.11.13";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "orbax";
    tag = "v${version}";
    hash = "sha256-qmq0Kz8wXUFFE4CqsdFwKXAIvysFbv7JomQSrNj1QCc=";
  };

  # The google/orbax repository holds several packages; orbax-checkpoint
  # is built from its `checkpoint/` subdirectory.
  sourceRoot = "${src.name}/checkpoint";

  build-system = [ flit-core ];

  # Do not enforce upstream's version constraint on jax at build time
  # (presumably the jax packaged here falls outside upstream's pin — re-check
  # on every bump whether this is still needed).
  pythonRelaxDeps = [
    "jax"
  ];

  dependencies = [
    absl-py
    etils
    humanize
    importlib-resources
    jax
    msgpack
    nest-asyncio
    numpy
    protobuf
    pyyaml
    simplejson
    tensorstore
    typing-extensions
  ];

  nativeCheckInputs = [
    aiofiles
    chex
    google-cloud-logging
    mock
    optax
    portpicker
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [
    "orbax"
    "orbax.checkpoint"
  ];

  disabledTests = lib.optionals stdenv.hostPlatform.isDarwin [
    # Probably fails because of a filesystem impurity:
    #   self.assertFalse(os.path.exists(dst_dir))
    #   AssertionError: True is not false
    "test_create_snapshot"
  ];

  disabledTestPaths = [
    # absl.flags._exceptions.DuplicateFlagError: The flag 'num_processes' is defined twice.
    # First from multiprocess_test, second from orbax.checkpoint._src.testing.multiprocess_test.
    # Description from first occurrence: Number of processes to use.
    # https://github.com/google/orbax/issues/1580
    "orbax/checkpoint/experimental/emergency/"

    # These tests would need flax, which itself depends on orbax-checkpoint
    # (circular dependency on flax).
    "orbax/checkpoint/_src/metadata/empty_values_test.py"
    "orbax/checkpoint/_src/metadata/tree_rich_types_test.py"
    "orbax/checkpoint/_src/metadata/tree_test.py"
    "orbax/checkpoint/_src/testing/test_tree_utils.py"
    "orbax/checkpoint/_src/tree/parts_of_test.py"
    "orbax/checkpoint/_src/tree/utils_test.py"
    "orbax/checkpoint/single_host_test.py"
    "orbax/checkpoint/transform_utils_test.py"
  ];

  meta = {
    description = "Orbax provides common utility libraries for JAX users";
    homepage = "https://github.com/google/orbax/tree/main/checkpoint";
    # Reuse the fetched tag so the "v" prefix is spelled out in one place only.
    changelog = "https://github.com/google/orbax/blob/${src.tag}/checkpoint/CHANGELOG.md";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}