at master 3.8 kB view raw
# Nix expression for dm-haiku, a neural network library for JAX.
{
  lib,
  buildPythonPackage,
  pythonAtLeast,
  fetchFromGitHub,
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  jaxlib,
  jmp,
  numpy,
  tabulate,

  # optional-dependencies
  jax,
  flax,

  # tests
  pytest-xdist,
  pytestCheckHook,
  bsuite,
  chex,
  cloudpickle,
  dill,
  dm-env,
  dm-tree,
  optax,
  rlax,
  tensorflow,
}:

# The package is bound in a `let` so it can refer to itself: it is listed in its
# own `nativeCheckInputs`, and `passthru.tests.pytest` is derived from it via
# `overridePythonAttrs`.
let
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.13";
    pyproject = true;

    # ImportError: `haiku.experimental.flax` features require `flax` to be installed.
    disabled = pythonAtLeast "3.13";

    src = fetchFromGitHub {
      owner = "deepmind";
      repo = "dm-haiku";
      tag = "v${version}";
      hash = "sha256-RJpQ9BzlbQ4X31XoJFnsZASiaC9fP2AdyuTAGINhMxs=";
    };

    patches = [
      # https://github.com/deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    # AttributeError: jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, and
    # see https://docs.jax.dev/en/latest/jax.extend.html for details.
    # Already on master: https://github.com/google-deepmind/dm-haiku/commit/cfe8480d253a93100bf5e2d24c40435a95399c96
    # TODO: remove at the next release
    postPatch = ''
      substituteInPlace haiku/_src/jaxpr_info.py \
        --replace-fail "jax.core.JaxprEqn" "jax.extend.core.JaxprEqn" \
        --replace-fail "jax.core.Var" "jax.extend.core.Var" \
        --replace-fail "jax.core.Jaxpr" "jax.extend.core.Jaxpr"
    '';

    build-system = [ setuptools ];

    dependencies = [
      absl-py
      jaxlib # implicit runtime dependency
      jmp
      numpy
      tabulate
    ];

    optional-dependencies = {
      jax = [
        jax
        jaxlib
      ];
      flax = [ flax ];
    };

    pythonImportsCheck = [ "haiku" ];

    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      dm-haiku
      dm-tree
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      rlax
      tensorflow
    ];

    disabledTests = [
      # See https://github.com/deepmind/dm-haiku/issues/366.
      "test_jit_Recurrent"

      # Assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # This test requires a more recent version of tensorflow. The current one (2.13) is not enough.
      "test_reshape_convert"

      # This test requires JAX support for double precision (64bit), but enabling this causes several
      # other tests to fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"
    ];

    disabledTestPaths = [
      # These tests require a more recent version of tensorflow. The current one (2.13) is not enough.
      "haiku/_src/integration/jax2tf_test.py"
    ];

    # Tests are not run in the main build; they run in passthru.tests.pytest below,
    # which flips doCheck back on.
    doCheck = false;

    # check in passthru.tests.pytest to escape infinite recursion with bsuite
    passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
      pname = "${pname}-tests";
      doCheck = true;

      # We don't have to install because the only purpose
      # of this passthru test is to, well, test.
      # This fixes having to set `catchConflicts` to false.
      dontInstall = true;
    });

    meta = {
      description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet";
      homepage = "https://github.com/google-deepmind/dm-haiku";
      changelog = "https://github.com/google-deepmind/dm-haiku/releases/tag/v${version}";
      license = lib.licenses.asl20;
      maintainers = with lib.maintainers; [ ndl ];
    };
  };
in
dm-haiku