at 24.11-pre 5.5 kB view raw
{
  lib,
  blas,
  buildPythonPackage,
  callPackage,
  setuptools,
  importlib-metadata,
  fetchFromGitHub,
  jaxlib,
  jaxlib-bin,
  hypothesis,
  lapack,
  matplotlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  pytestCheckHook,
  pytest-xdist,
  pythonOlder,
  scipy,
  stdenv,
}:

let
  # True when either the BLAS or the LAPACK provider is MKL; used below to
  # skip tests that give inconsistent results with MKL.
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";

  # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
  # fine. jaxlib is only used in the checkPhase, so switching backends does not
  # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
  jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.28";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/jax-v${version}";
    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
  };

  # PEP 517 build backend (pyproject = true).
  build-system = [ setuptools ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAX_RELEASE = "1";

  # jaxlib is _not_ included in dependencies because there are different
  # versions of jaxlib depending on the desired target hardware. The JAX
  # project ships separate wheels for CPU, GPU, and TPU.
  dependencies = [
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

  nativeCheckInputs = [
    hypothesis
    jaxlib'
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism results in the tests getting stuck, so opt out of the
  # pytest-xdist hook's default worker count and cap it explicitly at 4 via
  # `--numprocesses=4` in pytestFlagsArray below.
  dontUsePytestXdist = true;

  # NOTE: Don't run the tests in the experimental directory, as they require
  # flax, which creates a circular dependency. See
  # https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal; this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
    "--numprocesses=4"
    "-W ignore::DeprecationWarning"
    "tests/"
  ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
  # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
  # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
  # NOTE: this doesn't seem to be an issue on linux
  preCheck = lib.optionalString stdenv.hostPlatform.isDarwin ''
    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
  '';

  disabledTests =
    [
      # Exceeds tolerance when the machine is busy
      "test_custom_linear_solve_aux"
      # UserWarning: Explicitly requested dtype <class 'numpy.float64'>
      # requested in astype is not available, and will be truncated to
      # dtype float32. (With numpy 1.24)
      "testKde3"
      "testKde5"
      "testKde6"
      # Invokes python manually in a subprocess, which does not have the correct dependencies
      # ImportError: This version of jax requires jaxlib version >= 0.4.19.
      "test_no_log_spam"
    ]
    ++ lib.optionals usingMKL [
      # See
      # * https://github.com/google/jax/issues/9705
      # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
      # * https://github.com/NixOS/nixpkgs/issues/161960
      "test_custom_linear_solve_cholesky"
      "test_custom_root_with_aux"
      "testEigvalsGrad_shape"
    ]
    ++ lib.optionals stdenv.hostPlatform.isAarch64 [
      # See https://github.com/google/jax/issues/14793.
      "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
      "testQdwhWithRandomMatrix3"
      "testScanGrad_jit_scan"

      # See https://github.com/google/jax/issues/17867.
      "test_array"
      "test_async"
      "test_copy0"
      "test_device_put"
      "test_make_array_from_callback"
      "test_make_array_from_single_device_arrays"

      # Fails on some hardware due to some numerical error
      # See https://github.com/google/jax/issues/18535
      "testQdwhWithOnRankDeficientInput5"
    ];

  disabledTestPaths = lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [
    # RuntimeWarning: invalid value encountered in cast
    "tests/lax_test.py"
  ];

  pythonImportsCheck = [ "jax" ];

  # Test CUDA-enabled jax and jaxlib. Running CUDA-enabled tests is not
  # currently feasible within the nix build environment so we have to maintain
  # this script separately. See https://github.com/NixOS/nixpkgs/pull/256230
  # for a possible remedy to this situation.
  #
  # Run these tests with eg
  #
  #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
  passthru.tests = {
    test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      jaxlib = jaxlib.override { cudaSupport = true; };
    };
    test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-bin.override { cudaSupport = true; };
    };
  };

  # updater fails to pick the correct branch
  passthru.skipBulkUpdate = true;

  meta = {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
  };
}