{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  keras,
  einops,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.10.6";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-HhepJp7y2YN05XcZhB/L08g+yOfTJPRzd2m4ALQJGvw=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    keras
    einops
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so the benefit of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"
  ];

  passthru = {
    # Bumps flax, realises the new flax source as ./result, then bumps flaxlib
    # to the version read from a manifest inside that source tree.
    # The shebang is required: writeScript does not add one, so without it the
    # script cannot be exec'd directly. `set -euo pipefail` stops a failed fetch
    # or lookup from feeding an empty/bogus version to the second nix-update.
    updateScript = writeScript "update.sh" ''
      #!/usr/bin/env bash
      set -euo pipefail
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      # Cargo manifests declare the crate version under [package].
      # NOTE(review): confirm the manifest path inside the flax source tree.
      flaxlib_version="$(${lib.getExe tomlq} <result/Cargo.toml .package.version)"
      nix-update flaxlib --version="$flaxlib_version" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}