nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at 22.05 100 lines 2.6 kB view raw
# Package definition for JAX, built from the upstream GitHub sources with
# buildPythonPackage. jaxlib is only pulled in for the test suite (see the note
# on propagatedBuildInputs below).
{ lib
, absl-py
, blas
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, jaxlib
, lapack
, numpy
, opt-einsum
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, typing-extensions
}:

let
  # True when MKL is the selected implementation of BLAS or of LAPACK; used
  # below to disable additional tests.
  mklInUse = builtins.elem "mkl" [
    blas.implementation
    lapack.implementation
  ];
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.3.6";
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    rev = "jax-v${version}";
    hash = "sha256-eGdAEZFHadNTHgciP4KMYHdwksz9g6un0Ar+A/KV5TE=";
  };

  patches = [
    # Local patch; background in https://github.com/google/jax/issues/7944
    ./cache-fix.patch

    # Upstream commit; background in https://github.com/google/jax/issues/10292
    (fetchpatch {
      url = "https://github.com/google/jax/commit/cadc8046d56e0c1433cf48a2f106947d5f4ecbfd.patch";
      hash = "sha256-jrpIqt4LzWAswt/Cpwtfa5d1Yn31HcXkVH3ETmaigA0=";
    })
  ];

  # jaxlib is deliberately left out of propagatedBuildInputs: which jaxlib is
  # wanted depends on the target hardware, and the JAX project ships separate
  # wheels for CPU, GPU, and TPU. Currently only the CPU wheel is packaged.
  propagatedBuildInputs = [
    absl-py
    numpy
    opt-einsum
    scipy
    typing-extensions
  ];

  checkInputs = [
    jaxlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism makes the tests get stuck; the number of workers is set
  # explicitly in pytestFlagsArray instead.
  dontUsePytestXdist = true;

  # NOTE: Only tests/ is run. The tests in the experimental directory require
  # flax, which would create a circular dependency. See
  # https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # This is also how the JAX docs suggest running the test suite.
  pytestFlagsArray = [
    "--numprocesses=4"
    "-W ignore::DeprecationWarning"
    "tests/"
  ];

  disabledTests = [
    # Exceeds tolerance when the machine is busy
    "test_custom_linear_solve_aux"
  ] ++ lib.optionals mklInUse [
    # Disabled only with MKL. See
    # * https://github.com/google/jax/issues/9705
    # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
    # * https://github.com/NixOS/nixpkgs/issues/161960
    "test_custom_linear_solve_cholesky"
    "test_custom_root_with_aux"
    "testEigvalsGrad_shape"
  ];

  pythonImportsCheck = [ "jax" ];

  meta = {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.samuela ];
  };
}