{ buildPythonPackage
, fetchFromGitHub
, jaxlib
, jax
, keras
, lib
, matplotlib
, msgpack
, numpy
, optax
, pytest-xdist
, pytestCheckHook
, tensorflow
, rich
}:

# Flax, a neural network library for JAX, built from its tagged GitHub release.
buildPythonPackage rec {
  pname = "flax";
  version = "0.6.5";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-Vv68BK83gTIKj0r9x+twdhqmRYziD0vxQCdHkYSeTak=";
  };

  # jaxlib is only a build input here; unlike the list below it is not
  # propagated to packages that depend on flax.
  # NOTE(review): presumably so consumers can supply their own jaxlib variant
  # (e.g. a binary or CUDA build) — confirm against the jax package convention.
  buildInputs = [ jaxlib ];

  # Runtime Python dependencies, propagated to everything that depends on flax.
  propagatedBuildInputs = [
    jax
    matplotlib
    msgpack
    numpy
    optax
    rich
  ];

  # Smoke test: the top-level `flax` module must import after installation.
  pythonImportsCheck = [
    "flax"
  ];

  # Dependencies used only while running the test suite.
  nativeCheckInputs = [
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Ignore FutureWarning and DeprecationWarning during the test run
  # (presumably so warnings from dependencies do not clutter or break tests).
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session: depending on how they are invoked, they either contain modules
    # that conflict with each other or use wrong import paths. Many tests also
    # have dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so the benefit of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # See https://github.com/google/flax/issues/2554.
    "test_async_save_checkpoints"
    "test_jax_array0"
    "test_jax_array1"
    "test_keep0"
    "test_keep1"
    "test_optimized_lstm_cell_matches_regular"
    "test_overwrite_checkpoints"
    "test_save_restore_checkpoints_target_empty"
    "test_save_restore_checkpoints_target_none"
    "test_save_restore_checkpoints_target_singular"
    "test_save_restore_checkpoints_w_float_steps"
    "test_save_restore_checkpoints"
  ];

  meta = with lib; {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    # Marked broken: flax requires orbax, which is not available.
    broken = true;
  };
}