{
  lib,
  absl-py,
  buildPythonPackage,
  flit-core,
  chex,
  fetchFromGitHub,
  jaxlib,
  numpy,
  callPackage,
  pythonOlder,
}:

# Nixpkgs derivation for the `optax` Python package (JAX gradient processing /
# optimization library), built from the upstream GitHub release tag via a
# pyproject build with the flit-core backend.
buildPythonPackage rec {
  pname = "optax";
  version = "0.2.2";
  pyproject = true;

  # Mark the package as unavailable for interpreters older than Python 3.9.
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "optax";
    rev = "refs/tags/v${version}";
    hash = "sha256-sBiKUuQR89mttc9Njrh1aeUJOYdlcF7Nlj3/+Y7OMb4=";
  };

  # "out" stays first so it remains the default output; "testsout" carries the
  # upstream examples copied in postInstall below.
  outputs = [
    "out"
    "testsout"
  ];

  # Build backend only.
  nativeBuildInputs = [ flit-core ];

  # jaxlib is available at build time but deliberately NOT propagated.
  # NOTE(review): presumably so consumers can pick their own jaxlib variant —
  # confirm against nixpkgs conventions for JAX packages.
  buildInputs = [ jaxlib ];

  # Runtime Python dependencies exposed to anything that depends on optax.
  propagatedBuildInputs = [
    absl-py
    chex
    numpy
  ];

  # Ship the upstream examples in the separate testsout output.
  # NOTE(review): presumably consumed by ./tests.nix — confirm there.
  postInstall = ''
    mkdir $testsout
    cp -R examples $testsout/examples
  '';

  # Smoke test: the installed package must at least be importable.
  pythonImportsCheck = [ "optax" ];

  # The real test suite is not run here; it runs from passthru.tests.pytest
  # instead, which avoids an infinite recursion with flax.
  doCheck = false;

  passthru.tests.pytest = callPackage ./tests.nix { };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    changelog = "https://github.com/deepmind/optax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.ndl ];
  };
}