# Captured from the nixpkgs 23.11-beta branch (raw file, 1.2 kB).
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, pytestCheckHook
, absl-py
, cvxpy
, jax
, jaxlib
, matplotlib
, numpy
, optax
, scipy
, scikit-learn
}:

buildPythonPackage rec {
  pname = "jaxopt";
  version = "0.8.2";
  format = "setuptools";

  # Mark the package as unavailable for interpreters older than Python 3.8.
  disabled = pythonOlder "3.8";

  # Upstream release tags are named "jaxopt-v<version>".
  src = fetchFromGitHub {
    owner = "google";
    repo = "jaxopt";
    rev = "refs/tags/jaxopt-v${version}";
    hash = "sha256-uVOd3knoku5fKBNXOhCikGtjDuW3TtRqev94OM/8Pgk=";
  };

  # Python dependencies propagated alongside the package.
  propagatedBuildInputs = [
    absl-py
    jax
    jaxlib
    matplotlib
    numpy
    scipy
  ];

  # Dependencies used only by the pytest-driven check phase.
  nativeCheckInputs = [
    pytestCheckHook
    cvxpy
    optax
    scikit-learn
  ];

  # Modules that must import successfully after installation.
  pythonImportsCheck = [
    "jaxopt"
    "jaxopt.implicit_diff"
    "jaxopt.linear_solve"
    "jaxopt.loss"
    "jaxopt.tree_util"
  ];

  disabledTests = [
    # Stack frame issue
    "test_bisect"
  ];

  # Explicit `lib.` qualification instead of `with lib;` keeps the origin of
  # each name visible; the resulting values are unchanged.
  meta = {
    homepage = "https://jaxopt.github.io";
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    changelog = "https://github.com/google/jaxopt/releases/tag/jaxopt-v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}