Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# Nix package expression for the `jaxopt` Python library (github.com/google/jaxopt),
# built with setuptools from the `jaxopt-v<version>` git tag.
{
  lib,
  stdenv,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,
  fetchpatch,
  pytest-xdist,
  pytestCheckHook,
  setuptools,
  absl-py,
  cvxpy,
  jax,
  jaxlib,
  matplotlib,
  numpy,
  optax,
  scipy,
  scikit-learn,
}:

buildPythonPackage rec {
  pname = "jaxopt";
  version = "0.8.3";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    owner = "google";
    repo = "jaxopt";
    rev = "refs/tags/jaxopt-v${version}";
    hash = "sha256-T/BHSnuk3IRuLkBj3Hvb/tFIb7Au25jjQtvwL28OU1U=";
  };

  patches = [
    # fix failing tests from scipy 1.12 update
    # https://github.com/google/jaxopt/pull/574
    (fetchpatch {
      name = "scipy-1.12-fix-tests.patch";
      url = "https://github.com/google/jaxopt/commit/48b09dc4cc93b6bc7e6764ed5d333f9b57f3493b.patch";
      hash = "sha256-v+617W7AhxA1Dzz+DBtljA4HHl89bRTuGi1QfatobNY=";
    })
  ];

  build-system = [ setuptools ];

  dependencies = [
    absl-py
    jax
    jaxlib
    matplotlib
    numpy
    scipy
  ];

  # Test-only inputs: cvxpy, optax and scikit-learn are needed by the test
  # suite but are not runtime dependencies.
  nativeCheckInputs = [
    pytest-xdist
    pytestCheckHook
    cvxpy
    optax
    scikit-learn
  ];

  pythonImportsCheck = [
    "jaxopt"
    "jaxopt.implicit_diff"
    "jaxopt.linear_solve"
    "jaxopt.loss"
    "jaxopt.tree_util"
  ];

  disabledTests =
    [
      # https://github.com/google/jaxopt/issues/592
      # Disabled on every platform. This also covers aarch64-linux, where the
      # same test was tracked under https://github.com/google/jaxopt/issues/577,
      # so it is not repeated in the platform-specific list below.
      "test_solve_sparse"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
      # https://github.com/google/jaxopt/issues/577
      "test_binary_logit_log_likelihood"
      "test_logreg_with_intercept_manual_loop3"

      # https://github.com/google/jaxopt/issues/593
      # Makes the test suite crash
      "test_dtype_consistency"
      # AssertionError: Array(0.01411963, dtype=float32) not less than or equal to 0.01
      "test_multiclass_logreg6"
    ];

  meta = {
    homepage = "https://jaxopt.github.io";
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    changelog = "https://github.com/google/jaxopt/releases/tag/jaxopt-v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}