{ lib
, buildPythonPackage
, fetchFromGitHub
, pythonOlder
  # propagated (runtime) dependencies
, absl-py
, jax
, jaxlib
, matplotlib
, numpy
, scipy
  # check-phase-only dependencies
, cvxpy
, optax
, pytestCheckHook
, scikit-learn
}:

let
  # Release being packaged; also interpolated into the git tag and changelog URL.
  version = "0.8.2";
in
buildPythonPackage {
  pname = "jaxopt";
  inherit version;
  format = "setuptools";

  # Package is marked unavailable for Python interpreters older than 3.8.
  disabled = pythonOlder "3.8";

  # Upstream release tags have the form `jaxopt-v<version>`.
  src = fetchFromGitHub {
    owner = "google";
    repo = "jaxopt";
    rev = "refs/tags/jaxopt-v${version}";
    hash = "sha256-uVOd3knoku5fKBNXOhCikGtjDuW3TtRqev94OM/8Pgk=";
  };

  propagatedBuildInputs = [
    absl-py
    jax
    jaxlib
    matplotlib
    numpy
    scipy
  ];

  # Only needed to run the upstream test suite via pytest.
  nativeCheckInputs = [
    pytestCheckHook
    cvxpy
    optax
    scikit-learn
  ];

  # Smoke-test that the top-level package and these submodules import.
  pythonImportsCheck = [
    "jaxopt"
    "jaxopt.implicit_diff"
    "jaxopt.linear_solve"
    "jaxopt.loss"
    "jaxopt.tree_util"
  ];

  disabledTests = [
    # Stack frame issue
    "test_bisect"
  ];

  meta = {
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    homepage = "https://jaxopt.github.io";
    changelog = "https://github.com/google/jaxopt/releases/tag/jaxopt-v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}