# Snapshot at 23.11-beta, 2.1 kB (raw view).
# Package expression for flax ("Neural network library for JAX").
#
# Takes its dependencies as function arguments and returns the result of
# `buildPythonPackage` for the tagged upstream GitHub release.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  jaxlib,
  pythonRelaxDepsHook,
  setuptools-scm,
  jax,
  msgpack,
  numpy,
  optax,
  pyyaml,
  rich,
  tensorstore,
  typing-extensions,
  matplotlib,
  cloudpickle,
  einops,
  keras,
  pytest-xdist,
  pytestCheckHook,
  tensorflow,
}:

let
  pname = "flax";
  # Interpolated into both the source tag and the changelog URL below.
  version = "0.7.5";
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  # Sources come from the upstream release tag `v<version>`.
  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-NDah0ayQbiO1/sTU1DDf/crPq5oLTnSuosV7cFHlTM8=";
  };

  nativeBuildInputs = [
    jaxlib
    pythonRelaxDepsHook
    setuptools-scm
  ];

  propagatedBuildInputs = [
    jax
    msgpack
    numpy
    optax
    pyyaml
    rich
    tensorstore
    typing-extensions
  ];

  # Exposed under the extra named `all`.
  passthru.optional-dependencies.all = [ matplotlib ];

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Silence FutureWarning and DeprecationWarning during the test run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Documentation-extension test: it needs additional dependencies and is
    # of no interest for this package.
    "docs/_ext/codediff_test.py"

    # The `examples` tests are not written to run within one test session:
    # depending on how they are invoked, their modules either clash with one
    # another or are imported through the wrong paths. Many of them also
    # depend on packages that are not available in `nixpkgs` (`clu`,
    # `jgraph`, `tensorflow_datasets`, `vocabulary`), so running them would
    # bring little benefit anyway.
    "examples/*"

    # Skipped because of https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"

    # Depends on orbax, which was not packaged as of 2023-07-27.
    "tests/checkpoints_test.py"
  ];

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.ndl ];
  };
}