# Package definition for dm-haiku, a neural network library for JAX.
#
# The package is bound to a local name (`dm-haiku`) so that the derivation can
# refer to itself: both the check inputs and the `passthru.tests.pytest`
# override below reuse the same binding.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  absl-py,
  flax,
  jaxlib,
  jmp,
  numpy,
  tabulate,
  pytest-xdist,
  pytestCheckHook,
  bsuite,
  chex,
  cloudpickle,
  dill,
  dm-env,
  dm-tree,
  optax,
  rlax,
  tensorflow,
}:

let
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.12";
    format = "setuptools";

    src = fetchFromGitHub {
      owner = "deepmind";
      repo = "dm-haiku";
      rev = "refs/tags/v${version}";
      hash = "sha256-aJRXlMq4CNMH3ZSTDP8MgnVltdSc8l5raw4//KccL48=";
    };

    patches = [
      # https://github.com/deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    # Runtime dependencies, propagated to everything that depends on dm-haiku.
    propagatedBuildInputs = [
      absl-py
      flax
      jaxlib
      jmp
      numpy
      tabulate
    ];

    pythonImportsCheck = [ "haiku" ];

    # Only consulted when checks are enabled, i.e. in `passthru.tests.pytest`.
    # `dm-haiku` here is the let-bound package itself (built with
    # `doCheck = false`).
    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      dm-haiku
      dm-tree
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      rlax
      tensorflow
    ];

    disabledTests = [
      # See https://github.com/deepmind/dm-haiku/issues/366.
      "test_jit_Recurrent"

      # Assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # This test requires a more recent version of tensorflow. The current one (2.13) is not enough.
      "test_reshape_convert"

      # This test requires JAX support for double precision (64bit), but enabling this causes several
      # other tests to fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"
    ];

    disabledTestPaths = [
      # These tests require a more recent version of tensorflow. The current one (2.13) is not enough.
      "haiku/_src/integration/jax2tf_test.py"
    ];

    # Checks are off for the package proper; they run in the passthru test below.
    doCheck = false;

    # check in passthru.tests.pytest to escape infinite recursion with bsuite
    passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
      pname = "${pname}-tests";
      doCheck = true;

      # We don't have to install because the only purpose
      # of this passthru test is to, well, test.
      # This fixes having to set `catchConflicts` to false.
      dontInstall = true;
    });

    # `lib` is referenced explicitly instead of `with lib;` so that no
    # unrelated names are brought into scope for the whole attribute set.
    meta = {
      description = "Simple neural network library for JAX developed by some of the authors of Sonnet";
      homepage = "https://github.com/deepmind/dm-haiku";
      license = lib.licenses.asl20;
      maintainers = with lib.maintainers; [ ndl ];
    };
  };
in
dm-haiku