# Web-viewer header residue (not part of the expression): revision "23.05-pre", file size 1.0 kB.
# Package expression for optax, built with buildPythonPackage from the
# GitHub tag matching `version`.
#
# Every argument below is a dependency or helper injected by the caller
# (presumably the Python package set through callPackage -- confirm against
# the file that imports this one).
{ absl-py
, buildPythonPackage
, chex
, fetchFromGitHub
, jaxlib
, lib
, numpy
, callPackage
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.1.3";

  # Source is fetched from the release tag "v<version>" of deepmind/optax.
  src = fetchFromGitHub {
    owner = "deepmind";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-XAYztMBQpLBHNuNED/iodbwIMJSN/0GxdmTGQ5jD9Ws=";
  };

  # "testsout" is a second output populated in postInstall with the
  # upstream examples directory.
  outputs = [
    "out"
    "testsout"
  ];

  # jaxlib is listed as a build input only; it is deliberately not part of
  # propagatedBuildInputs.
  # NOTE(review): presumably so that consumers pick their own jaxlib
  # variant -- confirm.
  buildInputs = [ jaxlib ];

  propagatedBuildInputs = [
    absl-py
    chex
    numpy
  ];

  # Copy the examples from the source tree into the separate "testsout"
  # output.
  postInstall = ''
    mkdir $testsout
    cp -R examples $testsout/examples
  '';

  pythonImportsCheck = [
    "optax"
  ];

  # check in passthru.tests.pytest to escape infinite recursion with flax
  doCheck = false;

  # The test suite lives in a separate expression (./tests.nix) exposed
  # through passthru instead of running in this derivation's check phase.
  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = with lib; {
    # Follows the nixpkgs convention: no leading package name, no trailing
    # period.
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}