# nixpkgs derivation for Flax, a neural network library built on JAX.
{ lib
, buildPythonPackage
, fetchFromGitHub
, jax
, jaxlib
, keras
, matplotlib
, msgpack
, numpy
, optax
, pytest-xdist
, pytestCheckHook
, pythonRelaxDepsHook
, rich
, tensorflow
, tensorstore
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.7.4";

  src = fetchFromGitHub {
    owner = "google";
    # Literal rather than `pname`, so overriding `pname` cannot silently
    # redirect the source fetch to a different repository.
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-i48omag/1Si3mCCGfsUD9qeejyeCLWzvvwKJqH8vm8k=";
  };

  # jaxlib is deliberately not part of propagatedBuildInputs below; listing it
  # here makes it available to the import check and the test suite that run
  # during the build.
  nativeBuildInputs = [
    jaxlib
    pythonRelaxDepsHook
  ];

  propagatedBuildInputs = [
    jax
    matplotlib
    msgpack
    numpy
    optax
    rich
    tensorstore
  ];

  # Upstream declares a dependency on orbax, which is not packaged in nixpkgs,
  # so drop it from the runtime dependency check.
  # See https://github.com/google/flax/pull/2882.
  pythonRemoveDeps = [ "orbax" ];

  pythonImportsCheck = [
    "flax"
  ];

  nativeCheckInputs = [
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Ignore FutureWarning/DeprecationWarning raised during the test run.
  # The spaceless `-W<filter>` form is a single argv entry, so it means the
  # same thing whether or not the check hook word-splits these flags.
  pytestFlagsArray = [
    "-Wignore::FutureWarning"
    "-Wignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test; needs extra dependencies and is not relevant for packaging.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single
    # test session: depending on how they are invoked, they either contain
    # modules that conflict with each other or use wrong import paths. Many
    # tests also depend on packages that are not in nixpkgs (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so running them would be of
    # limited benefit anyway.
    "examples/*"

    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"

    # Requires orbax, which is not packaged as of 2023-07-27.
    "tests/checkpoints_test.py"
  ];

  disabledTests = [
    # See https://github.com/google/flax/issues/2554.
    # NOTE(review): most of these names look like checkpoint tests, which may
    # already be excluded via tests/checkpoints_test.py above — confirm
    # whether they can be dropped.
    "test_async_save_checkpoints"
    "test_jax_array0"
    "test_jax_array1"
    "test_keep0"
    "test_keep1"
    "test_optimized_lstm_cell_matches_regular"
    "test_overwrite_checkpoints"
    "test_save_restore_checkpoints_target_empty"
    "test_save_restore_checkpoints_target_none"
    "test_save_restore_checkpoints_target_singular"
    "test_save_restore_checkpoints_w_float_steps"
    "test_save_restore_checkpoints"
  ];

  meta = with lib; {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}