# Package expression for dm-haiku, built from the upstream GitHub release tag
# with buildPythonPackage.
{ buildPythonPackage
, fetchFromGitHub
, callPackage
, lib
, jmp
, tabulate
, jaxlib
}:

buildPythonPackage rec {
  pname = "dm-haiku";
  version = "0.0.9";

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "dm-haiku";
    rev = "refs/tags/v${version}";
    hash = "sha256-d5THbfMRrbBL/2sQ99l2yeaTI9gT+bSkcxmVdRJT5bA=";
  };

  # "testsout" is a second output that receives the upstream `examples`
  # directory (populated in postInstall below).
  outputs = [
    "out"
    "testsout"
  ];

  propagatedBuildInputs = [
    jaxlib
    jmp
    tabulate
  ];

  # Sanity check that the installed package can be imported.
  pythonImportsCheck = [
    "haiku"
  ];

  # Copy the examples from the source tree into the separate "testsout" output.
  postInstall = ''
    mkdir "$testsout"
    cp -R examples "$testsout/examples"
  '';

  # The check phase is disabled here; the test suite is exposed through
  # passthru.tests.pytest instead, to escape infinite recursion with bsuite.
  doCheck = false;

  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    # nixpkgs convention: no leading package name, no trailing period.
    description = "Simple neural network library for JAX developed by some of the authors of Sonnet";
    homepage = "https://github.com/deepmind/dm-haiku";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}