# Listing header from the code viewer (not part of the expression): at 25.11-pre, 1.2 kB
# Package expression for the `optax` Python library.
#
# Builds optax 0.2.4 from its GitHub release tag as a pyproject-based package
# using flit-core, and additionally ships the upstream `examples` directory in
# a separate `testsout` output.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  chex,
  jax,
  jaxlib,
  numpy,
  etils,

  # tests
  callPackage,
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.2.4";
  pyproject = true;

  src = fetchFromGitHub {
    # Upstream's GitHub organisation is `google-deepmind`; the former
    # `deepmind` name only resolves through GitHub's rename redirect.
    # The fetched tree is the same, so the hash below is unchanged.
    owner = "google-deepmind";
    repo = "optax";
    tag = "v${version}";
    hash = "sha256-7UPWeo/Q9/tjewaM7HN8/e7U1U1QzAliuk95+9GOi0E=";
  };

  # `testsout` is populated in postInstall with the upstream examples.
  outputs = [
    "out"
    "testsout"
  ];

  build-system = [ flit-core ];

  # Besides the direct dependencies, also pull in the packages listed under
  # etils' `epy` optional-dependency group.
  dependencies = [
    absl-py
    chex
    etils
    jax
    jaxlib
    numpy
  ] ++ etils.optional-dependencies.epy;

  # Copy the source tree's `examples` directory into the extra output.
  postInstall = ''
    mkdir "$testsout"
    cp -R examples "$testsout/examples"
  '';

  pythonImportsCheck = [ "optax" ];

  # check in passthru.tests.pytest to escape infinite recursion with flax
  doCheck = false;

  # The test derivation is defined in ./tests.nix and exposed as
  # `passthru.tests.pytest` instead of running during this build.
  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/google-deepmind/optax";
    changelog = "https://github.com/google-deepmind/optax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}