# Nix expression for the `flax` Python package (neural network library for JAX).
# The arguments below are package-set dependencies; presumably this file is
# invoked via `callPackage`, which fills them in by name.
{ buildPythonPackage
, fetchFromGitHub
, jax
, jaxlib
, keras
, lib
, matplotlib
, msgpack
, numpy
, optax
, pytest-xdist
, pytestCheckHook
, rich
, tensorflow
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.6.1";

  # Upstream release tag `v<version>` from GitHub.
  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    # SRI-format hash, so it goes in `hash` rather than the legacy `sha256`.
    hash = "sha256-fZiODo+izOwGjCCTvi11GvUG/VQL1DV9bNXKjvIIw4A=";
  };

  # jaxlib is available during the build (and therefore to the import and test
  # checks below) but, unlike the list that follows, it is not propagated to
  # dependents. NOTE(review): presumably this lets consumers pick their own
  # jaxlib variant — confirm.
  buildInputs = [ jaxlib ];

  # Runtime Python dependencies, propagated to anything that depends on flax.
  propagatedBuildInputs = [
    jax
    matplotlib
    msgpack
    numpy
    optax
    rich
  ];

  # Smoke test: the top-level module must import after installation.
  pythonImportsCheck = [
    "flax"
  ];

  # Dependencies needed only for running the test suite.
  checkInputs = [
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Ignore FutureWarning and DeprecationWarning during the test run.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test: it needs extra dependencies and is not relevant to this package.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to run in a single test session.
    # Depending on how they are invoked, they either contain modules that
    # conflict with each other or use wrong import paths. Many of them also
    # depend on packages that are not in nixpkgs (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`), so running them would have limited
    # benefit anyway.
    "examples/*"
  ];

  disabledTests = [
    # Disabled; see https://github.com/google/flax/issues/2554 for details.
    "test_async_save_checkpoints"
    "test_jax_array0"
    "test_jax_array1"
    "test_keep0"
    "test_keep1"
    "test_optimized_lstm_cell_matches_regular"
    "test_overwrite_checkpoints"
    "test_save_restore_checkpoints_target_empty"
    "test_save_restore_checkpoints_target_none"
    "test_save_restore_checkpoints_target_singular"
    "test_save_restore_checkpoints_w_float_steps"
    "test_save_restore_checkpoints"
  ];

  meta = with lib; {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}