Clone of https://github.com/NixOS/nixpkgs.git (used to stress-test knotserver)
Branch: gcc-offload · 128 lines · 3.0 kB · view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  setuptools,
  absl-py,
  flax,
  jax,
  jaxlib,
  jmp,
  numpy,
  tabulate,
  pytest-xdist,
  pytestCheckHook,
  bsuite,
  chex,
  cloudpickle,
  dill,
  dm-env,
  dm-tree,
  optax,
  rlax,
  tensorflow,
}:

# The derivation is bound in a `let` so that it can refer to itself:
# it appears in its own `nativeCheckInputs`, and `passthru.tests.pytest`
# below is an override of this very package.
let
  dm-haiku = buildPythonPackage rec {
    pname = "dm-haiku";
    version = "0.0.13";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "deepmind";
      repo = "dm-haiku";
      rev = "refs/tags/v${version}";
      hash = "sha256-RJpQ9BzlbQ4X31XoJFnsZASiaC9fP2AdyuTAGINhMxs=";
    };

    patches = [
      # https://github.com/deepmind/dm-haiku/pull/672
      (fetchpatch {
        name = "fix-find-namespace-packages.patch";
        url = "https://github.com/deepmind/dm-haiku/commit/728031721f77d9aaa260bba0eddd9200d107ba5d.patch";
        hash = "sha256-qV94TdJnphlnpbq+B0G3KTx5CFGPno+8FvHyu/aZeQE=";
      })
    ];

    build-system = [ setuptools ];

    dependencies = [
      absl-py
      jaxlib # implicit runtime dependency
      jmp
      numpy
      tabulate
    ];

    optional-dependencies = {
      jax = [
        jax
        jaxlib
      ];
      flax = [ flax ];
    };

    pythonImportsCheck = [ "haiku" ];

    # Test-only inputs. This derivation sets `doCheck = false` (see below);
    # the checks actually run in `passthru.tests.pytest`, which re-enables them.
    nativeCheckInputs = [
      bsuite
      chex
      cloudpickle
      dill
      dm-env
      dm-haiku # the let-bound package defined in this file
      dm-tree
      jaxlib
      optax
      pytest-xdist
      pytestCheckHook
      rlax
      tensorflow
    ];

    disabledTests = [
      # See https://github.com/deepmind/dm-haiku/issues/366.
      "test_jit_Recurrent"

      # Assertion errors
      "testShapeChecking0"
      "testShapeChecking1"

      # This test requires a more recent version of tensorflow. The current one (2.13) is not enough.
      "test_reshape_convert"

      # This test requires JAX support for double precision (64bit), but enabling this causes several
      # other tests to fail.
      # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
      "test_doctest_haiku.experimental"
    ];

    disabledTestPaths = [
      # These tests require a more recent version of tensorflow. The current one (2.13) is not enough.
      "haiku/_src/integration/jax2tf_test.py"
    ];

    # Checks are disabled here and run in passthru.tests.pytest instead,
    # to escape the infinite recursion with bsuite (one of the check inputs).
    doCheck = false;

    passthru.tests.pytest = dm-haiku.overridePythonAttrs (_: {
      pname = "${pname}-tests";
      doCheck = true;

      # We don't have to install because the only purpose
      # of this passthru test is to, well, test.
      # This fixes having to set `catchConflicts` to false.
      dontInstall = true;
    });

    # Explicit `lib.` paths instead of `with lib;` so every name's origin is visible.
    meta = {
      description = "Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet";
      homepage = "https://github.com/deepmind/dm-haiku";
      license = lib.licenses.asl20;
      maintainers = [ lib.maintainers.ndl ];
    };
  };
in
dm-haiku