{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  setuptools,
  absl-py,
  flax,
  jax,
  jaxlib,
  jmp,
  numpy,
  tabulate,
  pytest-xdist,
  pytestCheckHook,
  bsuite,
  chex,
  cloudpickle,
  dill,
  dm-env,
  dm-tree,
  optax,
  rlax,
  tensorflow,
}:

let
  # The derivation is bound in a `let` (instead of being returned directly) so
  # that it can refer to itself: it appears in its own `nativeCheckInputs` and
  # `passthru.tests.pytest` is built by overriding it.
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.13";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "deepmind";
      repo = "dm-haiku";
      rev = "refs/tags/v${version}";
      hash = "sha256-RJpQ9BzlbQ4X31XoJFnsZASiaC9fP2AdyuTAGINhMxs=";
    };

    patches = [
      # https://github.com/deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    build-system = [ setuptools ];

    dependencies = [
      absl-py
      jaxlib # implicit runtime dependency
      jmp
      numpy
      tabulate
    ];

    optional-dependencies = {
      jax = [
        jax
        jaxlib
      ];
      flax = [ flax ];
    };

    pythonImportsCheck = [ "haiku" ];

    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      dm-haiku
      dm-tree
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      rlax
      tensorflow
    ];

    disabledTests = [
      # See https://github.com/deepmind/dm-haiku/issues/366.
      "test_jit_Recurrent"

      # Assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # This test requires a more recent version of tensorflow. The current one (2.13) is not enough.
      "test_reshape_convert"

      # This test requires JAX support for double precision (64bit), but enabling this causes several
      # other tests to fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"
    ];

    disabledTestPaths = [
      # These tests require a more recent version of tensorflow. The current one (2.13) is not enough.
      "haiku/_src/integration/jax2tf_test.py"
    ];

    # The test suite is not run as part of this build; it runs in
    # `passthru.tests.pytest` instead.
    doCheck = false;

    # Check in passthru.tests.pytest to escape infinite recursion with bsuite.
    passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
      pname = "${pname}-tests";
      doCheck = true;

      # We don't have to install because the only purpose
      # of this passthru test is to, well, test.
      # This fixes having to set `catchConflicts` to false.
      dontInstall = true;
    });

    meta = {
      description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet";
      homepage = "https://github.com/deepmind/dm-haiku";
      # Follows the same `v${version}` tag scheme used for `src` above.
      changelog = "https://github.com/deepmind/dm-haiku/releases/tag/v${version}";
      license = lib.licenses.asl20;
      maintainers = with lib.maintainers; [ ndl ];
    };
  };
in
dm-haiku