at 23.05-pre 995 B view raw
# Nix expression for the `jmp` Python package, a mixed-precision training
# library for JAX from DeepMind (see `meta.homepage`).
{ buildPythonPackage
, fetchFromGitHub
, jax
, jaxlib
, lib
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "jmp";
  # As of 2022-01-01 the latest stable release (0.0.2) fails its tests with
  # recent JAX versions, so an unstable snapshot from GitHub is packaged
  # instead. The failure appears to be fixed upstream in
  # https://github.com/deepmind/jmp/commit/4969392f618d7733b265677143d8c81e44085867
  # NOTE(review): `rev` below is a different commit; confirm it contains that fix.
  version = "unstable-2021-10-03";

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = pname;
    rev = "260e5ba01f46b10c579a61393e6c7e546aeae93e";
    # SRI-format hash, so use the `hash` attribute rather than `sha256`.
    hash = "sha256-BTHy/jNf6LeV+x3GTI9MDBWLK6A5z2Z1TQyBkHMTeuE=";
  };

  # The wheel declares only `numpy` as a requirement, but importing the
  # package needs `jax`, so propagate it to consumers.
  propagatedBuildInputs = [
    jax
  ];

  # Smoke test: the built package must be importable.
  pythonImportsCheck = [
    "jmp"
  ];

  # jaxlib is only listed for the test run; pytestCheckHook runs the
  # upstream pytest suite during the check phase.
  checkInputs = [
    jaxlib
    pytestCheckHook
  ];

  meta = with lib; {
    description = "Mixed precision training in JAX";
    homepage = "https://github.com/deepmind/jmp";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
  };
}