{
  lib,
  buildPythonPackage,
  cloudpickle,
  einops,
  fetchFromGitHub,
  jax,
  jaxlib,
  keras,
  matplotlib,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pytest-xdist,
  pytestCheckHook,
  pythonOlder,
  pyyaml,
  rich,
  setuptools-scm,
  tensorflow,
  tensorstore,
  typing-extensions,
}:

# Derivation for the `flax` Python package, built from the upstream GitHub
# release tag (PEP 517 build via setuptools-scm) and checked with its pytest
# suite.
buildPythonPackage rec {
  pname = "flax";
  version = "0.8.3";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU=";
  };

  build-system = [
    # NOTE(review): jaxlib is not a build backend. It is presumably listed here
    # so that it is importable during `pythonImportsCheck` and the test phase
    # (the `jax` input does not appear to bring it in by itself). Confirm
    # before moving it to another input list.
    jaxlib
    setuptools-scm
  ];

  dependencies = [
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    typing-extensions
  ];

  passthru.optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Silence noisy future/deprecation warnings emitted during the test run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test; needs extra dependencies and is not relevant to the package.
    "docs/_ext/codediff_test.py"
    # The tests in `examples` are not designed to be executed from a single test
    # session, so they either contain modules that conflict with each other or
    # use wrong import paths, depending on how they are invoked. Many tests also
    # have dependencies that are not packaged in nixpkgs (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so the benefit of trying to run them
    # would be limited anyway.
    "examples/*"
    "flax/experimental/nnx/examples/*"
    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"
    # Requires the `tree` module, which is not among the check inputs above.
    "tests/tensorboard_test.py"
  ];

  disabledTests = [
    # ValueError: Checkpoint path should be absolute
    "test_overwrite_checkpoints0"
  ];

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}