# nixpkgs 23.05-pre (2.8 kB raw file)
{ lib
, absl-py
, blas
, buildPythonPackage
, etils
, fetchFromGitHub
, jaxlib
, lapack
, matplotlib
, numpy
, opt-einsum
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, typing-extensions
}:

let
  # True when either BLAS or LAPACK is backed by MKL; a few extra tests are
  # disabled in that case (see `disabledTests` below).
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.3.23";
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    rev = "jax-v${version}";
    hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4=";
  };

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
  # CPU wheel is packaged.
  propagatedBuildInputs = [
    absl-py
    etils
    numpy
    opt-einsum
    scipy
    typing-extensions
  ] ++ etils.optional-dependencies.epath;

  checkInputs = [
    jaxlib
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism results in the tests getting stuck, so the pytest-xdist
  # setup hook is disabled and the worker count is instead pinned explicitly
  # via `--numprocesses=4` in pytestFlagsArray below.
  dontUsePytestXdist = true;

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
    "--numprocesses=4"
    "-W ignore::DeprecationWarning"
    "tests/"
  ];

  disabledTests = [
    # Exceeds tolerance when the machine is busy
    "test_custom_linear_solve_aux"
  ] ++ lib.optionals usingMKL [
    # See
    # * https://github.com/google/jax/issues/9705
    # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
    # * https://github.com/NixOS/nixpkgs/issues/161960
    "test_custom_linear_solve_cholesky"
    "test_custom_root_with_aux"
    "testEigvalsGrad_shape"
  ];

  # See https://github.com/google/jax/issues/11722. This is a temporary fix in
  # order to unblock etils, and upgrading jax/jaxlib to the latest version. See
  # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
  disabledTestPaths = [
    "tests/api_test.py"
    "tests/core_test.py"
    "tests/lax_numpy_indexing_test.py"
    "tests/lax_numpy_test.py"
    "tests/nn_test.py"
    "tests/random_test.py"
    "tests/sparse_test.py"
  ];

  # As of 0.3.22, `import jax` does not work without jaxlib being installed.
  pythonImportsCheck = [ ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
  };
}