# Orbax checkpointing library: the `checkpoint/` subproject of the google/orbax
# monorepo, built with flit-core.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  etils,
  humanize,
  importlib-resources,
  jax,
  jaxlib,
  msgpack,
  nest-asyncio,
  numpy,
  protobuf,
  pyyaml,
  tensorstore,
  typing-extensions,

  # tests
  chex,
  google-cloud-logging,
  mock,
  pytest-xdist,
  pytestCheckHook,
}:

let
  version = "0.6.4";

  # Release tags in the upstream repository are named `v<version>`.
  src = fetchFromGitHub {
    owner = "google";
    repo = "orbax";
    rev = "refs/tags/v${version}";
    hash = "sha256-xd75/AKBFUdA6a8sQnCB2rVbHl/Foy4LTb07jnwrTjA=";
  };
in
buildPythonPackage {
  pname = "orbax-checkpoint";
  inherit version src;
  pyproject = true;

  # The repository hosts several projects; only `checkpoint/` is built here.
  sourceRoot = "${src.name}/checkpoint";

  build-system = [ flit-core ];

  dependencies = [
    absl-py
    etils
    humanize
    importlib-resources
    jax
    jaxlib
    msgpack
    nest-asyncio
    numpy
    protobuf
    pyyaml
    tensorstore
    typing-extensions
  ];

  nativeCheckInputs = [
    chex
    google-cloud-logging
    mock
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [
    "orbax"
    "orbax.checkpoint"
  ];

  # Darwin only. Probably failing because of a filesystem impurity:
  #   self.assertFalse(os.path.exists(dst_dir))
  #   AssertionError: True is not false
  disabledTests = lib.optional stdenv.hostPlatform.isDarwin "test_create_snapshot";

  # These test modules import flax, which itself depends on orbax-checkpoint
  # (circular dependency), so they are skipped.
  disabledTestPaths = [
    "orbax/checkpoint/transform_utils_test.py"
    "orbax/checkpoint/utils_test.py"
  ];

  meta = {
    description = "Orbax provides common utility libraries for JAX users";
    homepage = "https://github.com/google/orbax/tree/main/checkpoint";
    changelog = "https://github.com/google/orbax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.fab ];
  };
}