# Nix expression for the `flax` Python package ("Neural network library for
# JAX"), built with buildPythonPackage from the upstream GitHub release tag.
{ lib
, buildPythonPackage
, fetchFromGitHub
, jaxlib
, pythonRelaxDepsHook
, setuptools-scm
, jax
, msgpack
, numpy
, optax
, pyyaml
, rich
, tensorstore
, typing-extensions
, matplotlib
, cloudpickle
, einops
, keras
, pytest-xdist
, pytestCheckHook
, tensorflow
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.7.5";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-NDah0ayQbiO1/sTU1DDf/crPq5oLTnSuosV7cFHlTM8=";
  };

  nativeBuildInputs = [
    # `jax` is propagated below but `jaxlib` is not; it is only available at
    # build time here, presumably so that `flax`/`jax` can be imported by the
    # import check and the test suite. NOTE(review): consumers must still
    # supply a jaxlib themselves — confirm this is intended.
    jaxlib
    # Applies `pythonRemoveDeps` below; without that attribute (or
    # `pythonRelaxDeps`) this hook does nothing.
    pythonRelaxDepsHook
    setuptools-scm
  ];

  propagatedBuildInputs = [
    jax
    msgpack
    numpy
    optax
    pyyaml
    rich
    tensorstore
    typing-extensions
  ];

  # orbax is not packaged in nixpkgs (see the disabled checkpoints test
  # below), so drop upstream's requirement on it from the installed wheel
  # metadata to keep that metadata consistent with what is actually provided.
  # NOTE(review): assumes the requirement is spelled `orbax-checkpoint` in
  # flax 0.7.5's pyproject.toml — confirm; removing an absent entry is a no-op.
  pythonRemoveDeps = [
    "orbax-checkpoint"
  ];

  passthru.optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [
    "flax"
  ];

  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Ignore FutureWarning and DeprecationWarning during the pytest run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test; needs extra dependencies and is not relevant for the package.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be run from a single test
    # session: depending on how they are invoked, they either contain modules
    # that conflict with each other or use wrong import paths. Many of them
    # also depend on packages that are not in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so running them would have limited
    # benefit anyway.
    "examples/*"

    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"

    # Requires orbax, which is not packaged as of 2023-07-27.
    "tests/checkpoints_test.py"
  ];

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.ndl ];
  };
}