# Snapshot of this file at nixpkgs 24.05-pre (3.9 kB).
# Package definition for the `jax` Python library, built with
# buildPythonPackage. The function arguments are the nixpkgs packages and
# helpers referenced below.
{ lib
, blas
, buildPythonPackage
, setuptools
, importlib-metadata
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, lapack
, matplotlib
, ml-dtypes
, numpy
, opt-einsum
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, stdenv
}:

let
  # True when either BLAS or LAPACK is provided by MKL; used below to disable
  # tests that are known to misbehave with MKL.
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";

  # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
  # fine. jaxlib is only used in the checkPhase, so switching backends does not
  # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
  jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.20";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-WLYXUtchOaA6SGnKuVhN9CmV06xMCLQTEuEtL13ttZU=";
  };

  nativeBuildInputs = [
    setuptools
  ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAX_RELEASE = "1";

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

  nativeCheckInputs = [
    jaxlib'
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism will result in the tests getting stuck, so the
  # pytest-xdist hook's own parallelism is turned off and a fixed worker count
  # is passed in pytestFlagsArray instead.
  dontUsePytestXdist = true;

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See
  # https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
    # Fixed, modest worker count (see dontUsePytestXdist above).
    "--numprocesses=4"
    # Option and value are separate elements so each reaches pytest as its own
    # argument whether or not the hook word-splits array entries.
    "-W"
    "ignore::DeprecationWarning"
    "tests/"
  ];

  disabledTests = [
    # Exceeds tolerance when the machine is busy
    "test_custom_linear_solve_aux"
    # UserWarning: Explicitly requested dtype <class 'numpy.float64'>
    # requested in astype is not available, and will be truncated to
    # dtype float32. (With numpy 1.24)
    "testKde3"
    "testKde5"
    "testKde6"
  ] ++ lib.optionals usingMKL [
    # See
    # * https://github.com/google/jax/issues/9705
    # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
    # * https://github.com/NixOS/nixpkgs/issues/161960
    "test_custom_linear_solve_cholesky"
    "test_custom_root_with_aux"
    "testEigvalsGrad_shape"
  ] ++ lib.optionals stdenv.hostPlatform.isAarch64 [
    # See https://github.com/google/jax/issues/14793.
    "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
    "testQdwhWithRandomMatrix3"
    "testScanGrad_jit_scan"

    # See https://github.com/google/jax/issues/17867.
    "test_array"
    "test_async"
    "test_copy0"
    "test_device_put"
    "test_make_array_from_callback"
    "test_make_array_from_single_device_arrays"

    # Fails on some hardware due to some numerical error
    # See https://github.com/google/jax/issues/18535
    "testQdwhWithOnRankDeficientInput5"
  ];

  # Only skipped on aarch64-darwin.
  disabledTestPaths = lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [
    # RuntimeWarning: invalid value encountered in cast
    "tests/lax_test.py"
  ];

  pythonImportsCheck = [ "jax" ];

  meta = {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
  };
}