# Runs optax's test suite as a separate derivation, using the test sources
# exported by the optax derivation's `testsout` output.
{ stdenv # NOTE(review): unused; kept only so callers passing it explicitly keep working
, buildPythonPackage
, dm-haiku
, pytest-xdist
, pytestCheckHook
, tensorflow
, tensorflow-datasets
, flax
, optax
}:

buildPythonPackage {
  pname = "optax-tests";
  # Keep the version in lockstep with the package under test.
  inherit (optax) version;

  # Test sources come from optax's extra `testsout` output, not from a fetcher.
  src = optax.testsout;

  # Nothing to build or install: this derivation exists only for its checks.
  dontBuild = true;
  dontInstall = true;

  checkInputs = [
    dm-haiku
    # The package under test is listed explicitly so the test modules' import
    # of optax does not depend on another input (presumably flax) propagating
    # it. NOTE(review): confirm the examples import optax directly.
    optax
    pytest-xdist
    pytestCheckHook
    tensorflow
    tensorflow-datasets
    flax
  ];

  # Test files that pytestCheckHook should skip.
  disabledTestPaths = [
    # See https://github.com/deepmind/optax/issues/323
    "examples/lookahead_mnist_test.py"
  ];
}