# nixpkgs expression for the `jaxopt` Python package ("Hardware accelerated,
# batchable and differentiable optimizers in JAX"), built with
# buildPythonPackage from the google/jaxopt GitHub release tag.
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, pytestCheckHook
, absl-py
, cvxpy
, jax
, jaxlib
, matplotlib
, numpy
, optax
, scipy
, scikit-learn
}:

buildPythonPackage rec {
  pname = "jaxopt";
  version = "0.5.5";
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "google";
    # Spelled out instead of reusing `pname`, so that overriding `pname`
    # cannot silently change which repository/tag is fetched.
    repo = "jaxopt";
    # Upstream tags releases as "jaxopt-v<version>".
    rev = "refs/tags/jaxopt-v${version}";
    hash = "sha256-WOsr/Dvguu9/qX6+LMlAKM3EANtYPtDu8Uo2157+bs0=";
  };

  # Runtime dependencies, propagated to anything that depends on jaxopt.
  propagatedBuildInputs = [
    absl-py
    jax
    jaxlib
    matplotlib
    numpy
    scipy
  ];

  # Only needed to run the test suite during the build.
  nativeCheckInputs = [
    pytestCheckHook
    cvxpy
    optax
    scikit-learn
  ];

  # Smoke test: these modules must import after installation.
  pythonImportsCheck = [
    "jaxopt"
    "jaxopt.implicit_diff"
    "jaxopt.linear_solve"
    "jaxopt.loss"
    "jaxopt.tree_util"
  ];

  disabledTests = [
    # Stack frame issue (original note; NOTE(review): root cause and
    # whether it still reproduces are not recorded here — re-check on bumps).
    "test_bisect"
  ];

  meta = {
    homepage = "https://jaxopt.github.io";
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    changelog = "https://github.com/google/jaxopt/releases/tag/jaxopt-v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}