# Nix package definition for `jmp`, DeepMind's mixed-precision helper
# library for JAX, built from the upstream GitHub tag matching `version`.
{ buildPythonPackage
, fetchFromGitHub
, jax
, jaxlib
, lib
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "jmp";
  version = "0.0.4";
  # Explicit about the build backend; this is the implicit default when no
  # format is given, so the build itself is unchanged.
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "deepmind";
    # Literal repository name rather than `pname`, so renaming the attribute
    # cannot silently change what is fetched.
    repo = "jmp";
    rev = "refs/tags/v${version}";
    hash = "sha256-+PefZU1209vvf1SfF8DXiTvKYEnZ4y8iiIr8yKikx9Y=";
  };

  # The upstream wheel metadata declares only `numpy`, but `import jmp`
  # needs `jax`, so it is propagated explicitly.
  propagatedBuildInputs = [
    jax
  ];

  pythonImportsCheck = [
    "jmp"
  ];

  # `jaxlib` is only needed when running the test suite (presumably as the
  # runtime backend for `jax`); it is not propagated to dependents.
  nativeCheckInputs = [
    jaxlib
    pytestCheckHook
  ];

  meta = {
    description = "Library implementing support for mixed precision training in JAX";
    homepage = "https://github.com/deepmind/jmp";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}