# nixpkgs mirror (for testing)
# github.com/NixOS/nixpkgs
# nix
# Package expression for `flax`, built with `buildPythonPackage` from the
# `v${version}` tag of the google/flax GitHub repository.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  keras,
  einops,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # update script
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.12.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-Wdfc35/iah98C5WNYZWiAd2FJUJlyGLJ8xELpuYD3GU=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    keras
    einops
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  pytestFlags = [
    # FutureWarning: In the future `np.object` will be defined as the corresponding NumPy scalar.
    "-Wignore::FutureWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"

    # AssertionError: nnx_model.kernel.value.sharding = NamedSharding(...
    "test_linen_to_nnx_metadata"

    # AssertionError: 'Linear_0' not found in State({})
    "test_compact_basic"
    # KeyError: 'intermediates'
    "test_linen_submodule"
    "test_pure_nnx_submodule"
    # KeyError: 'counts'
    "test_mutable_state"
    # AttributeError: 'Top' object has no attribute '_pytree__state'. Did you mean: '_pytree__flatten'?
    "test_shared_modules"
    # AttributeError: 'MLP' object has no attribute 'scope'
    "test_transforms"
  ];

  passthru = {
    # Updates `flax`, builds its source, then updates `flaxlib` to the version
    # that `tomlq` reads from `result/Cargo.toml`.
    # NOTE(review): `.something.version` looks like a placeholder query path --
    # confirm it matches the actual layout of that TOML file.
    updateScript = writeScript "update.sh" ''
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}
139}