# Nixpkgs expression for DeepMind's `jmp` Python package (mixed precision
# utilities for JAX), built from the tagged GitHub source.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  jax,

  # tests
  jaxlib,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "jmp";
  version = "0.0.4";
  # The PEP 517 build path replaces the legacy `format = "setuptools"`;
  # setuptools is declared explicitly as the build backend below.
  pyproject = true;

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "jmp";
    tag = "v${version}";
    hash = "sha256-+PefZU1209vvf1SfF8DXiTvKYEnZ4y8iiIr8yKikx9Y=";
  };

  build-system = [ setuptools ];

  # Wheel requires only `numpy`, but the import needs `jax`.
  dependencies = [ jax ];

  # Smoke test: the installed package must be importable.
  pythonImportsCheck = [ "jmp" ];

  # jaxlib is only needed to run the test suite, not to import the package.
  nativeCheckInputs = [
    jaxlib
    pytestCheckHook
  ];

  meta = {
    description = "Mixed precision training library for JAX";
    homepage = "https://github.com/deepmind/jmp";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}