# Package expression for jaxopt, built from the upstream GitHub tag with
# buildPythonPackage.
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  pytestCheckHook,
  absl-py,
  cvxpy,
  jax,
  jaxlib,
  matplotlib,
  numpy,
  optax,
  scipy,
  scikit-learn,
}:

let
  # Bound once here so the source tag below can be derived from them without
  # making the whole attribute set recursive.
  pname = "jaxopt";
  version = "0.5.5";
in
buildPythonPackage {
  inherit pname version;

  # Not built for Python interpreters older than 3.7.
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    # Upstream tags are named "<pname>-v<version>".
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-WOsr/Dvguu9/qX6+LMlAKM3EANtYPtDu8Uo2157+bs0=";
  };

  # Runtime Python dependencies.
  propagatedBuildInputs = [
    absl-py
    jax
    jaxlib
    matplotlib
    numpy
    scipy
  ];

  # Dependencies used only by the check phase (test runner plus test-only libraries).
  nativeCheckInputs = [
    pytestCheckHook
    cvxpy
    optax
    scikit-learn
  ];

  # Modules that must be importable from the built package.
  pythonImportsCheck = [
    "jaxopt"
    "jaxopt.implicit_diff"
    "jaxopt.linear_solve"
    "jaxopt.loss"
    "jaxopt.tree_util"
  ];

  meta = {
    homepage = "https://jaxopt.github.io";
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}