nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
{ buildPythonPackage
, fetchFromGitHub
, jaxlib
, keras
, lib
, matplotlib
, msgpack
, numpy
, optax
, pytest-xdist
, pytestCheckHook
, tensorflow
}:

# Flax: neural network library for JAX, built from the upstream GitHub tag.
buildPythonPackage rec {
  pname = "flax";
  version = "0.4.1";

  src = fetchFromGitHub {
    owner = "google";
    # Spelled out rather than reusing `pname`, so renaming the attribute can
    # never silently change which repository is fetched.
    repo = "flax";
    rev = "v${version}";
    sha256 = "0j5ngdndm9nm49gcda7m36qzwk5lcbi4jnij9fi96vld54ip6f6v";
  };

  # jaxlib is a plain build input, not a propagated one, so it is available
  # while building/testing but is not propagated to packages depending on flax.
  # NOTE(review): presumably so consumers can supply their own jaxlib variant
  # (e.g. CPU vs. CUDA build) — confirm against the nixpkgs jax packaging.
  buildInputs = [ jaxlib ];

  # Runtime Python dependencies, propagated to anything that depends on flax.
  propagatedBuildInputs = [
    matplotlib
    msgpack
    numpy
    optax
  ];

  # Smoke test: `import flax` must succeed against the installed package.
  pythonImportsCheck = [
    "flax"
  ];

  # Extra dependencies needed only by the upstream test suite.
  checkInputs = [
    keras
    pytest-xdist
    pytestCheckHook
    tensorflow
  ];

  # Silence FutureWarning/DeprecationWarning during the test run.
  # NOTE(review): presumably these come from upstream dependencies rather than
  # flax itself — re-check whether they are still needed on version bumps.
  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  meta = with lib; {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    # Points at the same `v${version}` tag that `src` fetches.
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}