# Nix package expression for optax ("Gradient processing and optimization
# library for JAX"), built from the upstream GitHub release tag.
{ lib
, absl-py
, buildPythonPackage
, chex
, fetchFromGitHub
, jaxlib
, numpy
, callPackage
, pythonOlder
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.1.7";
  format = "setuptools";

  # Mark the package as unavailable for Python interpreters older than 3.7.
  disabled = pythonOlder "3.7";

  # Source tarball for the `v<version>` tag of deepmind/optax.
  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "optax";
    rev = "refs/tags/v${version}";
    hash = "sha256-zSMJxagPe2rkhrawJ+TWXUzk6V58IY6MhWmEqLVtOoA=";
  };

  # Besides the regular package output, a second `testsout` output carries a
  # copy of the upstream `examples` directory (populated in postInstall).
  # NOTE(review): presumably consumed by ./tests.nix — confirm.
  outputs = [ "out" "testsout" ];

  # jaxlib is listed as a plain build input; it is not propagated to dependents.
  buildInputs = [ jaxlib ];

  # Python dependencies propagated to every package that depends on optax.
  propagatedBuildInputs = [ absl-py chex numpy ];

  # Copy the upstream examples into the `testsout` output.
  postInstall = ''
    mkdir $testsout
    cp -R examples $testsout/examples
  '';

  # Smoke check: the installed package must be importable.
  pythonImportsCheck = [ "optax" ];

  # The test suite is deliberately not run during the build. It is exposed as
  # passthru.tests.pytest instead, to escape an infinite recursion with flax.
  doCheck = false;

  passthru.tests.pytest = callPackage ./tests.nix { };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    changelog = "https://github.com/deepmind/optax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.ndl ];
  };
}