# Nixpkgs expression for optax, a gradient processing and optimization
# library for JAX (https://github.com/deepmind/optax).
{ absl-py
, buildPythonPackage
, chex
, fetchFromGitHub
, jaxlib
, lib
, numpy
, callPackage
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.1.3";

  src = fetchFromGitHub {
    owner = "deepmind";
    # Use a literal instead of `pname` so that overriding pname cannot
    # silently change which repository is fetched.
    repo = "optax";
    rev = "refs/tags/v${version}";
    hash = "sha256-XAYztMBQpLBHNuNED/iodbwIMJSN/0GxdmTGQ5jD9Ws=";
  };

  # Extra output holding the upstream `examples` directory (copied in
  # postInstall). NOTE(review): presumably consumed by ./tests.nix — confirm.
  outputs = [
    "out"
    "testsout"
  ];

  # jaxlib is only a build input and is deliberately not propagated.
  # NOTE(review): presumably so downstream users can choose their own jaxlib
  # variant — confirm before changing.
  buildInputs = [ jaxlib ];

  propagatedBuildInputs = [
    absl-py
    chex
    numpy
  ];

  # Ship the upstream examples in the separate "testsout" output.
  postInstall = ''
    mkdir -p "$testsout"
    cp -R examples "$testsout/examples"
  '';

  pythonImportsCheck = [
    "optax"
  ];

  # The test suite runs from passthru.tests.pytest rather than here, to avoid
  # an infinite recursion with flax.
  doCheck = false;

  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    changelog = "https://github.com/deepmind/optax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}