# optax package expression, recovered from a nixpkgs web view
# (ref "23.11-beta"); the viewer's fused line numbers have been removed.
#
# Arguments are supplied by callPackage from the Python package set.
{ lib
, absl-py
, buildPythonPackage
, chex
, fetchFromGitHub
, jaxlib
, numpy
, callPackage
, pythonOlder
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.1.7";
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-zSMJxagPe2rkhrawJ+TWXUzk6V58IY6MhWmEqLVtOoA=";
  };

  # Besides the normal output, a "testsout" output holds a copy of the
  # upstream examples directory (filled in postInstall below).
  # NOTE(review): presumably consumed by ./tests.nix — confirm.
  outputs = [
    "out"
    "testsout"
  ];

  # jaxlib is available at build time but is NOT propagated to dependents.
  # NOTE(review): presumably so consumers can choose their own jaxlib
  # variant — confirm before moving it to propagatedBuildInputs.
  buildInputs = [
    jaxlib
  ];

  propagatedBuildInputs = [
    absl-py
    chex
    numpy
  ];

  # Populate the "testsout" output with the examples from the source tree.
  # Output paths are quoted for shell hygiene.
  postInstall = ''
    mkdir "$testsout"
    cp -R examples "$testsout/examples"
  '';

  pythonImportsCheck = [
    "optax"
  ];

  # Tests are run from passthru.tests.pytest instead of here to escape an
  # infinite recursion with flax.
  doCheck = false;

  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    changelog = "https://github.com/deepmind/optax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}