# nixpkgs at 25.11-pre — 2.4 kB (view raw)
# Nix expression for the `flax` Python package (neural network library for JAX),
# built from the upstream GitHub release tag with setuptools/setuptools-scm.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  keras,
  einops,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.10.6";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-HhepJp7y2YN05XcZhB/L08g+yOfTJPRzd2m4ALQJGvw=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    keras
    einops
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"
  ];

  passthru = {
    # Bumps flax, builds its source tree into ./result, then bumps flaxlib to
    # the version read from result/Cargo.toml with tomlq, and commits.
    # NOTE(review): `.something.version` looks like a placeholder query (a
    # crate's version normally lives under `package.version` in Cargo.toml),
    # and the Cargo.toml location inside the flax source tree is unverified —
    # confirm both before relying on this script.
    updateScript = writeScript "update.sh" ''
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    # src.tag is "v${version}", so this points at the matching release page.
    changelog = "https://github.com/google/flax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}