Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
at gcc-offload 133 lines 3.1 kB view raw
# Package definition for the `flax` Python library ("Neural network library
# for JAX"), built from the upstream GitHub release tag with
# `buildPythonPackage`.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  jaxlib,
  setuptools-scm,

  # dependencies
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  typing-extensions,

  # checks
  cloudpickle,
  einops,
  flaxlib,
  keras,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,
  treescope,

  # optional-dependencies
  matplotlib,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.10.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-+URbQGnmqmSNgucEyWvI5DMnzXjpmJzLA+Pho2lX+S4=";
  };

  build-system = [
    jaxlib
    setuptools-scm
  ];

  dependencies = [
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    einops
    flaxlib
    keras
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
    treescope
  ];

  # Tell pytest to ignore FutureWarning and DeprecationWarning during the run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"
    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
    "flax/nnx/examples/*"
    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"
    # The packaged tensorflow is too old for this test:
    # ModuleNotFoundError: No module named 'keras.api._v2'
    "tests/tensorboard_test.py"
  ];

  disabledTests = [
    # ValueError: Checkpoint path should be absolute
    "test_overwrite_checkpoints0"
    # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
    # TODO: Re-enable once jax>0.4.28 is available in nixpkgs
    "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
    "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
  ];

  passthru = {
    # Bumps flax, builds its source, then bumps flaxlib to a version read from
    # a Cargo.toml in that source tree. `writeScript` does not add an
    # interpreter line by itself, and without one exec'ing the script fails
    # (ENOEXEC); the nix-shell shebang also puts `nix-update` on PATH.
    # `set -euo pipefail` stops the script if any step fails, so the flaxlib
    # bump never reads a stale or missing `result` link.
    updateScript = writeScript "update.sh" ''
      #!/usr/bin/env nix-shell
      #!nix-shell -i bash -p nix-update
      set -euo pipefail

      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      # NOTE(review): `.something.version` looks like a placeholder query; in a
      # Cargo manifest the crate version lives at `package.version`. Confirm the
      # tomlq query syntax and the Cargo.toml location inside flax.src.
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}