{
  lib,
  stdenv,
  absl-py,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  etils,
  humanize,
  importlib-resources,
  jax,
  msgpack,
  nest-asyncio,
  numpy,
  protobuf,
  pyyaml,
  simplejson,
  tensorstore,
  typing-extensions,

  # tests
  aiofiles,
  chex,
  google-cloud-logging,
  mock,
  optax,
  portpicker,
  pytest-xdist,
  pytestCheckHook,
}:

let
  # Test modules skipped because running them would introduce a circular
  # dependency on flax.
  flaxCircularTestPaths = [
    "orbax/checkpoint/_src/metadata/empty_values_test.py"
    "orbax/checkpoint/_src/metadata/tree_rich_types_test.py"
    "orbax/checkpoint/_src/metadata/tree_test.py"
    "orbax/checkpoint/_src/testing/test_tree_utils.py"
    "orbax/checkpoint/_src/tree/parts_of_test.py"
    "orbax/checkpoint/_src/tree/utils_test.py"
    "orbax/checkpoint/single_host_test.py"
    "orbax/checkpoint/transform_utils_test.py"
  ];
in
buildPythonPackage rec {
  pname = "orbax-checkpoint";
  version = "0.11.13";
  pyproject = true;

  # Fetched from the `google/orbax` repository at the release tag.
  src = fetchFromGitHub {
    owner = "google";
    repo = "orbax";
    tag = "v${version}";
    hash = "sha256-qmq0Kz8wXUFFE4CqsdFwKXAIvysFbv7JomQSrNj1QCc=";
  };

  # Build from the `checkpoint` subdirectory of the fetched repository.
  sourceRoot = "${src.name}/checkpoint";

  build-system = [ flit-core ];

  # Do not enforce the upstream version constraint on jax.
  pythonRelaxDeps = [ "jax" ];

  dependencies = [
    absl-py
    etils
    humanize
    importlib-resources
    jax
    msgpack
    nest-asyncio
    numpy
    protobuf
    pyyaml
    simplejson
    tensorstore
    typing-extensions
  ];

  nativeCheckInputs = [
    aiofiles
    chex
    google-cloud-logging
    mock
    optax
    portpicker
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [
    "orbax"
    "orbax.checkpoint"
  ];

  # Darwin only. This test probably fails because of a filesystem impurity:
  #   self.assertFalse(os.path.exists(dst_dir))
  #   AssertionError: True is not false
  disabledTests = lib.optional stdenv.hostPlatform.isDarwin "test_create_snapshot";

  disabledTestPaths = [
    # E absl.flags._exceptions.DuplicateFlagError: The flag 'num_processes' is defined twice.
    # First from multiprocess_test, Second from orbax.checkpoint._src.testing.multiprocess_test.
    # Description from first occurrence: Number of processes to use.
    # https://github.com/google/orbax/issues/1580
    "orbax/checkpoint/experimental/emergency/"
  ]
  ++ flaxCircularTestPaths;

  meta = {
    description = "Orbax provides common utility libraries for JAX users";
    homepage = "https://github.com/google/orbax/tree/main/checkpoint";
    changelog = "https://github.com/google/orbax/blob/v${version}/checkpoint/CHANGELOG.md";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.fab ];
  };
}