# Package expression for dm-haiku (imported in Python as `haiku`).
#
# The main derivation is built with its test suite switched off; the suite is
# exposed separately as `passthru.tests.pytest` (see the note next to it).
{
  lib,
  buildPythonPackage,
  pythonAtLeast,
  fetchFromGitHub,
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  jaxlib,
  jmp,
  numpy,
  tabulate,

  # optional-dependencies
  jax,
  flax,

  # tests (same order as `nativeCheckInputs` below)
  bsuite,
  chex,
  cloudpickle,
  dill,
  dm-env,
  dm-tree,
  optax,
  pytest-xdist,
  pytestCheckHook,
  rlax,
  tensorflow,
}:

let
  # Bound to a name so the derivation can refer to itself, both in
  # `nativeCheckInputs` and when deriving the `passthru.tests.pytest` variant.
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.13";
    pyproject = true;

    # Not built for Python 3.13 and newer. Recorded failure:
    #   ImportError: `haiku.experimental.flax` features require `flax` to be installed.
    disabled = pythonAtLeast "3.13";

    src = fetchFromGitHub {
      owner = "deepmind";
      repo = "dm-haiku";
      tag = "v${version}";
      hash = "sha256-RJpQ9BzlbQ4X31XoJFnsZASiaC9fP2AdyuTAGINhMxs=";
    };

    patches = [
      # Upstream PR: https://github.com/deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    # Works around:
    #   AttributeError: jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead,
    #   and see https://docs.jax.dev/en/latest/jax.extend.html for details.
    # The same change is already on upstream master:
    #   https://github.com/google-deepmind/dm-haiku/commit/cfe8480d253a93100bf5e2d24c40435a95399c96
    # TODO: remove at the next release
    #
    # The three replacements run in the order written: `JaxprEqn` is rewritten
    # first, and the final `Jaxpr` pattern does not match the already-rewritten
    # `jax.extend.core.*` names.
    postPatch = ''
      substituteInPlace haiku/_src/jaxpr_info.py \
        --replace-fail "jax.core.JaxprEqn" "jax.extend.core.JaxprEqn" \
        --replace-fail "jax.core.Var" "jax.extend.core.Var" \
        --replace-fail "jax.core.Jaxpr" "jax.extend.core.Jaxpr"
    '';

    build-system = [ setuptools ];

    dependencies = [
      absl-py
      jaxlib # implicit runtime dependency
      jmp
      numpy
      tabulate
    ];

    optional-dependencies = {
      jax = [
        jax
        jaxlib
      ];
      flax = [ flax ];
    };

    pythonImportsCheck = [ "haiku" ];

    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      dm-haiku
      dm-tree
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      rlax
      tensorflow
    ];

    disabledTests = [
      # Tracked upstream: https://github.com/deepmind/dm-haiku/issues/366
      "test_jit_Recurrent"

      # Fail with assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # Requires a more recent tensorflow; the current one (2.13) is not enough.
      "test_reshape_convert"

      # Requires JAX double precision (64bit) support, but enabling it makes
      # several other tests fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"
    ];

    disabledTestPaths = [
      # These tests require a more recent tensorflow; the current one (2.13) is not enough.
      "haiku/_src/integration/jax2tf_test.py"
    ];

    # The check phase is off for the main derivation and is only turned on in
    # the passthru variant below.
    doCheck = false;

    passthru = {
      # Tests live in passthru.tests.pytest to escape infinite recursion with bsuite.
      tests.pytest = dm-haiku.overridePythonAttrs (_: {
        pname = pname + "-tests";
        doCheck = true;

        # Nothing needs installing: this variant exists purely to run the
        # tests. Skipping the install also avoids having to set
        # `catchConflicts` to false.
        dontInstall = true;
      });
    };

    meta = {
      description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet";
      homepage = "https://github.com/deepmind/dm-haiku";
      license = lib.licenses.asl20;
      maintainers = [ lib.maintainers.ndl ];
    };
  };
in
dm-haiku