nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at 22.05 59 lines 1.1 kB view raw
# Nix expression for the `optax` Python package, built from the upstream
# `v<version>` tag on GitHub.
{ absl-py
, buildPythonPackage
, chex
, dm-haiku
, fetchFromGitHub
, jaxlib
, lib
, numpy
, pytest-xdist
, pytestCheckHook
, tensorflow
, tensorflow-datasets
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.1.1";

  # Source of the tagged release `v${version}` from github.com/deepmind/optax.
  src = fetchFromGitHub {
    owner = "deepmind";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-s/BcqzhdfWzR61MStusUPQtuT4+t8NcC5gBGiGggFqw=";
  };

  # jaxlib is only a build input here and is NOT propagated to dependents,
  # unlike the runtime dependencies below.
  # NOTE(review): presumably so consumers can supply the jaxlib variant of
  # their choice (e.g. CPU vs. CUDA build) — confirm before changing.
  buildInputs = [ jaxlib ];

  # Runtime dependencies, propagated to anything that depends on optax.
  propagatedBuildInputs = [
    absl-py
    chex
    numpy
  ];

  # Only needed to run the upstream test suite through pytestCheckHook.
  checkInputs = [
    dm-haiku
    pytest-xdist
    pytestCheckHook
    tensorflow
    tensorflow-datasets
  ];

  # Smoke test: `import optax` must succeed in the installed package.
  pythonImportsCheck = [
    "optax"
  ];

  disabledTestPaths = [
    # Requires `flax`, which itself depends on `optax`, creating a circular
    # dependency.
    "optax/_src/equivalence_test.py"
    # See https://github.com/deepmind/optax/issues/323.
    "examples/lookahead_mnist_test.py"
  ];

  meta = with lib; {
    # nixpkgs convention: no leading package name, no trailing period.
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}