# optax: gradient processing and optimization library for JAX.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  chex,
  etils,
  jax,
  jaxlib,
  numpy,

  # tests
  callPackage,
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.2.4";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "optax";
    tag = "v${version}";
    hash = "sha256-7UPWeo/Q9/tjewaM7HN8/e7U1U1QzAliuk95+9GOi0E=";
  };

  # Besides the library itself ("out"), ship the upstream `examples`
  # directory in a separate "testsout" output (see postInstall); presumably
  # consumed by ./tests.nix — confirm there.
  outputs = [
    "out"
    "testsout"
  ];

  build-system = [ flit-core ];

  # etils is needed both as a plain dependency and with its `epy` extra.
  dependencies = [
    absl-py
    chex
    etils
    jax
    jaxlib
    numpy
  ] ++ etils.optional-dependencies.epy;

  # Populate the "testsout" output with the upstream examples.
  postInstall = ''
    mkdir "$testsout"
    cp -R examples "$testsout/examples"
  '';

  pythonImportsCheck = [ "optax" ];

  # Tests are run from passthru.tests.pytest instead of here, to avoid an
  # infinite recursion with flax.
  doCheck = false;

  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/google-deepmind/optax";
    # Derived from the fetched tag so the link cannot drift from `src`.
    changelog = "https://github.com/google-deepmind/optax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}