# Source: nixpkgs at 23.11-beta (769 B)
# Nix derivation for `jmp`, DeepMind's mixed-precision training library for JAX.
# Built from the tagged GitHub release; see `version` and `src` below.
{ buildPythonPackage
, fetchFromGitHub
, jax
, jaxlib
, lib
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "jmp";
  version = "0.0.4";

  # Fetch the release tag explicitly (`refs/tags/...`) so a branch with the
  # same name cannot shadow it.
  src = fetchFromGitHub {
    owner = "deepmind";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-+PefZU1209vvf1SfF8DXiTvKYEnZ4y8iiIr8yKikx9Y=";
  };

  # Wheel requires only `numpy`, but the import needs `jax`.
  propagatedBuildInputs = [
    jax
  ];

  # Smoke test: the built package must be importable as `jmp`.
  pythonImportsCheck = [
    "jmp"
  ];

  # `jaxlib` is supplied only for the check phase here (it is not propagated);
  # pytestCheckHook runs the upstream test suite.
  nativeCheckInputs = [
    jaxlib
    pytestCheckHook
  ];

  meta = with lib; {
    # nixpkgs convention: a short noun phrase, no leading "This library", no trailing period.
    description = "Library implementing support for mixed precision training in JAX";
    homepage = "https://github.com/deepmind/jmp";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}