# nixpkgs mirror (for testing)
# Upstream: github.com/NixOS/nixpkgs
# Language: nix
{ absl-py
, buildPythonPackage
, chex
, dm-haiku
, fetchFromGitHub
, jaxlib
, lib
, numpy
, pytest-xdist
, pytestCheckHook
, tensorflow
, tensorflow-datasets
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.1.1";

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-s/BcqzhdfWzR61MStusUPQtuT4+t8NcC5gBGiGggFqw=";
  };

  # NOTE(review): jaxlib is only a (non-propagated) build input, presumably so
  # that dependents decide which jaxlib ends up in their closure. Confirm the
  # intent before moving it to propagatedBuildInputs.
  buildInputs = [ jaxlib ];

  # Runtime Python dependencies, propagated to packages that depend on optax.
  propagatedBuildInputs = [
    absl-py
    chex
    numpy
  ];

  # Dependencies needed only to run the test suite (via pytestCheckHook).
  checkInputs = [
    dm-haiku
    pytest-xdist
    pytestCheckHook
    tensorflow
    tensorflow-datasets
  ];

  pythonImportsCheck = [
    "optax"
  ];

  disabledTestPaths = [
    # Requires `flax`, which itself depends on `optax`; pulling it in here
    # would create a circular dependency.
    "optax/_src/equivalence_test.py"
    # See https://github.com/deepmind/optax/issues/323.
    "examples/lookahead_mnist_test.py"
  ];

  meta = with lib; {
    # nixpkgs convention: do not start with the package name, no trailing period.
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}