# nixpkgs derivation for the `flax` Python package (neural network library for JAX).
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  jaxlib,
  setuptools-scm,

  # dependencies
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  typing-extensions,

  # checks
  cloudpickle,
  einops,
  flaxlib,
  keras,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,
  treescope,

  # optional-dependencies
  matplotlib,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.10.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-+URbQGnmqmSNgucEyWvI5DMnzXjpmJzLA+Pho2lX+S4=";
  };

  build-system = [
    jaxlib
    setuptools-scm
  ];

  dependencies = [
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    typing-extensions
  ];

  # Extras exposed as `flax[all]`.
  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    einops
    flaxlib
    keras
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
    treescope
  ];

  # Suppress FutureWarning and DeprecationWarning for the whole pytest run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"
    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
    "flax/nnx/examples/*"
    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"
    # Too old version of tensorflow:
    # ModuleNotFoundError: No module named 'keras.api._v2'
    "tests/tensorboard_test.py"
  ];

  disabledTests = [
    # ValueError: Checkpoint path should be absolute
    "test_overwrite_checkpoints0"
    # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
    # TODO: Re-enable when jax>0.4.28 is available in nixpkgs
    "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
    "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
  ];

  passthru = {
    # Update flow:
    #   1. bump flax (nix-update does not build by default);
    #   2. realise flax.src into ./result;
    #   3. bump flaxlib to the `[package] version` from the Cargo manifest in
    #      that source tree (presumably flaxlib is built from the same tree).
    # `set -e` aborts the script if an earlier step fails, so step 3 never
    # reads a stale or missing ./result.
    # NOTE(review): assumes the flaxlib Cargo manifest lives at
    # `flaxlib/Cargo.toml` inside the flax source tree — confirm against
    # flaxlib's sourceRoot.
    updateScript = writeScript "update.sh" ''
      set -e
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/flaxlib/Cargo.toml .package.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}