Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
at flake-libs · 126 lines · 2.7 kB · view raw
# Flax (google/flax), packaged with buildPythonPackage from the tagged GitHub source.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  keras,
  einops,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.10.6";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-HhepJp7y2YN05XcZhB/L08g+yOfTJPRzd2m4ALQJGvw=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    keras
    einops
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  # Passed to pytest as `-W ignore::DeprecationWarning`, silencing every DeprecationWarning.
  # The warning that motivated this:
  #   DeprecationWarning: Triggering of __jax_array__() during abstractification is deprecated.
  #   To avoid this error, either explicitly convert your object using jax.numpy.array(), or register your object as a pytree.
  pytestFlagsArray = [
    "-W"
    "ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"
  ];

  passthru = {
    # Update script, in three steps:
    #   1. bump flax with nix-update;
    #   2. build flax.src (nix-build leaves a `result` link to the source tree);
    #   3. bump flaxlib to the version that tomlq reads from result/Cargo.toml.
    # NOTE(review): the tomlq query `.something.version` looks like a placeholder.
    # A Cargo manifest keeps its version under `[package]` (i.e. `.package.version`).
    # Confirm the query, and the manifest's location in the source tree, before relying on this.
    updateScript = writeScript "update.sh" ''
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    # Derived from the fetched tag so the link cannot drift from `src`.
    changelog = "https://github.com/google/flax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}