1{
2 lib,
3 buildPythonPackage,
4 fetchFromGitHub,
5 fetchpatch,
6 absl-py,
7 flax,
8 jaxlib,
9 jmp,
10 numpy,
11 tabulate,
12 pytest-xdist,
13 pytestCheckHook,
14 bsuite,
15 chex,
16 cloudpickle,
17 dill,
18 dm-env,
19 dm-tree,
20 optax,
21 rlax,
22 tensorflow,
23}:
24
25let
26 dm-haiku = buildPythonPackage rec {
27 pname = "dm-haiku";
28 version = "0.0.12";
29 format = "setuptools";
30
31 src = fetchFromGitHub {
32 owner = "deepmind";
33 repo = "dm-haiku";
34 rev = "refs/tags/v${version}";
35 hash = "sha256-aJRXlMq4CNMH3ZSTDP8MgnVltdSc8l5raw4//KccL48=";
36 };
37
38 patches = [
39 # https://github.com/deepmind/dm-haiku/pull/672
40 (fetchpatch {
41 name = "fix-find-namespace-packages.patch";
42 url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
43 hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
44 })
45 ];
46
47 propagatedBuildInputs = [
48 absl-py
49 flax
50 jaxlib
51 jmp
52 numpy
53 tabulate
54 ];
55
56 pythonImportsCheck = [ "haiku" ];
57
58 nativeCheckInputs = [
59 bsuite
60 chex
61 cloudpickle
62 dill
63 dm-env
64 dm-haiku
65 dm-tree
66 jaxlib
67 optax
68 pytest-xdist
69 pytestCheckHook
70 rlax
71 tensorflow
72 ];
73
74 disabledTests = [
75 # See https://github.com/deepmind/dm-haiku/issues/366.
76 "test_jit_Recurrent"
77
78 # Assertion errors
79 "testShapeChecking0"
80 "testShapeChecking1"
81
82 # This test requires a more recent version of tensorflow. The current one (2.13) is not enough.
83 "test_reshape_convert"
84
85 # This test requires JAX support for double precision (64bit), but enabling this causes several
86 # other tests to fail.
87 # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
88 "test_doctest_haiku.experimental"
89 ];
90
91 disabledTestPaths = [
92 # Those tests requires a more recent version of tensorflow. The current one (2.13) is not enough.
93 "haiku/_src/integration/jax2tf_test.py"
94 ];
95
96 doCheck = false;
97
98 # check in passthru.tests.pytest to escape infinite recursion with bsuite
99 passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
100 pname = "${pname}-tests";
101 doCheck = true;
102
103 # We don't have to install because the only purpose
104 # of this passthru test is to, well, test.
105 # This fixes having to set `catchConflicts` to false.
106 dontInstall = true;
107 });
108
109 meta = with lib; {
110 description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet.";
111 homepage = "https://github.com/deepmind/dm-haiku";
112 license = licenses.asl20;
113 maintainers = with maintainers; [ ndl ];
114 };
115 };
116in
117dm-haiku