# Snapshot of this expression from nixpkgs at 24.11-pre (1.2 kB, raw view).
# Package expression for the `optax` Python library.
# The function arguments are the build tools, fetchers and Python
# dependencies this derivation uses.
{
  lib,
  absl-py,
  buildPythonPackage,
  flit-core,
  chex,
  fetchFromGitHub,
  jaxlib,
  numpy,
  callPackage,
  pythonOlder,
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.2.2";
  # Build via the PEP 517 (pyproject.toml) code path of buildPythonPackage.
  pyproject = true;

  # Not available for interpreters older than Python 3.9.
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "optax";
    rev = "refs/tags/v${version}";
    hash = "sha256-sBiKUuQR89mttc9Njrh1aeUJOYdlcF7Nlj3/+Y7OMb4=";
  };

  # `testsout` holds a copy of the upstream `examples` directory (see postInstall).
  # NOTE(review): presumably consumed by ./tests.nix — confirm.
  outputs = [
    "out"
    "testsout"
  ];

  # PEP 517 build backend.
  build-system = [ flit-core ];

  # jaxlib is listed here rather than in `dependencies`, so it is NOT
  # propagated to packages that depend on optax.
  # NOTE(review): presumably so consumers can choose their own jaxlib
  # variant — confirm before changing.
  buildInputs = [ jaxlib ];

  # Runtime Python dependencies (propagated to dependents).
  dependencies = [
    absl-py
    chex
    numpy
  ];

  # Copy the upstream examples into the dedicated `testsout` output.
  postInstall = ''
    mkdir "$testsout"
    cp -R examples "$testsout/examples"
  '';

  # Smoke test: the installed package must be importable.
  pythonImportsCheck = [ "optax" ];

  # Tests are not run in this derivation; they live in passthru.tests.pytest
  # (./tests.nix) to escape an infinite recursion with flax.
  doCheck = false;

  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    changelog = "https://github.com/deepmind/optax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}