# Package expression for `jmp` (mixed precision training for JAX).
# This is a function from a set of dependencies to the result of
# `buildPythonPackage`; every argument must be supplied by the caller.
{ buildPythonPackage
, fetchFromGitHub
, jax
, jaxlib
, lib
, pytestCheckHook
}:

buildPythonPackage {
  pname = "jmp";
  # As of 2022-01-01 the latest stable release (0.0.2) fails its tests with
  # recent JAX versions, so an unstable snapshot is packaged instead.
  # The failure is believed to be fixed upstream by
  # https://github.com/deepmind/jmp/commit/4969392f618d7733b265677143d8c81e44085867
  # NOTE(review): the `rev` pinned below is a different commit from the one
  # cited above; presumably it already contains that fix -- confirm.
  version = "unstable-2021-10-03";

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "jmp";
    rev = "260e5ba01f46b10c579a61393e6c7e546aeae93e";
    sha256 = "sha256-BTHy/jNf6LeV+x3GTI9MDBWLK6A5z2Z1TQyBkHMTeuE=";
  };

  # The wheel declares only `numpy` as a requirement, but importing the
  # module needs `jax`, so it is propagated explicitly.
  propagatedBuildInputs = [
    jax
  ];

  # Verify that the top-level module can be imported after the build.
  pythonImportsCheck = [
    "jmp"
  ];

  # Dependencies that are only needed for the test phase.
  # NOTE(review): newer nixpkgs releases rename `checkInputs` to
  # `nativeCheckInputs` for hooks such as `pytestCheckHook`; confirm which
  # name the targeted nixpkgs revision expects before changing it.
  checkInputs = [
    jaxlib
    pytestCheckHook
  ];

  meta = {
    # Kept short, without a leading "This library ..." and without a trailing
    # period, following the nixpkgs convention for `meta.description`.
    description = "Mixed precision training support for JAX";
    homepage = "https://github.com/deepmind/jmp";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.ndl ];
  };
}