at 24.11-pre 2.3 kB view raw
# Package expression for `flax`, built from the tagged GitHub source with
# `buildPythonPackage` in pyproject mode.
{
  lib,
  buildPythonPackage,
  cloudpickle,
  einops,
  fetchFromGitHub,
  jax,
  jaxlib,
  keras,
  matplotlib,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pytest-xdist,
  pytestCheckHook,
  pythonOlder,
  pythonRelaxDepsHook,
  pyyaml,
  rich,
  setuptools-scm,
  tensorflow,
  tensorstore,
  typing-extensions,
}:

let
  pname = "flax";
  # Single source of truth for the release: reused by the fetched tag and by
  # the changelog URL below.
  version = "0.8.3";
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  # Refuse to build for interpreters older than Python 3.9.
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU=";
  };

  # NOTE(review): neither `pythonRelaxDeps` nor `pythonRemoveDeps` is set in
  # this expression, so `pythonRelaxDepsHook` appears to have nothing to
  # relax -- confirm before dropping it.
  build-system = [
    jaxlib
    pythonRelaxDepsHook
    setuptools-scm
  ];

  dependencies = [
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    typing-extensions
  ];

  passthru = {
    optional-dependencies.all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  # Inputs that are only needed while running the test suite.
  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Silence the two warning categories below during the test run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Documentation test: it needs extra dependencies and is of no interest
    # for this package.
    "docs/_ext/codediff_test.py"
    # The `examples` tests are not meant to run within one test session:
    # depending on how they are invoked, their modules either clash with each
    # other or use wrong import paths. Many of them additionally depend on
    # packages that `nixpkgs` does not provide (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so running them would bring little
    # benefit anyway.
    "examples/*"
    "flax/experimental/nnx/examples/*"
    # Tracked upstream: https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"
    # Needs `tree`, which is not among the check inputs.
    "tests/tensorboard_test.py"
  ];

  disabledTests = [
    # Fails with "ValueError: Checkpoint path should be absolute".
    "test_overwrite_checkpoints0"
  ];

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.ndl ];
  };
}