Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# Nix expression for the JAX Python frontend (`python3Packages.jax`).
#
# Only the `jax` frontend is built here from the upstream jax-v* tag; the
# compiled backend (`jaxlib`) is intentionally not a runtime dependency and
# is only pulled in for the test suite (see the comment on
# `propagatedBuildInputs`).
{
  lib,
  blas,
  buildPythonPackage,
  callPackage,
  setuptools,
  importlib-metadata,
  fetchFromGitHub,
  jaxlib,
  jaxlib-bin,
  jaxlib-build,
  hypothesis,
  lapack,
  matplotlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  pytestCheckHook,
  pytest-xdist,
  pythonOlder,
  scipy,
  stdenv,
}:

let
  # True when either the BLAS or the LAPACK provider is Intel MKL. A few tests
  # are disabled in that configuration (see the issues linked in
  # `disabledTests` below).
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.28";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/jax-v${version}";
    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
  };

  nativeBuildInputs = [ setuptools ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAX_RELEASE = "1";

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

  nativeCheckInputs = [
    hypothesis
    jaxlib
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism will result in the tests getting stuck, so the
  # pytest-xdist hook's automatic worker count is disabled and a fixed
  # `--numprocesses=4` is passed in pytestFlagsArray instead.
  dontUsePytestXdist = true;

  # NOTE: Don't run the tests in the experimental directory as they require flax,
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
    "--numprocesses=4"
    "-W ignore::DeprecationWarning"
    "tests/"
  ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
  # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
  # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
  # NOTE: this doesn't seem to be an issue on linux
  preCheck = lib.optionalString stdenv.isDarwin ''
    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
  '';

  disabledTests =
    [
      # Exceeds tolerance when the machine is busy
      "test_custom_linear_solve_aux"
      # UserWarning: Explicitly requested dtype <class 'numpy.float64'>
      # requested in astype is not available, and will be truncated to
      # dtype float32. (With numpy 1.24)
      "testKde3"
      "testKde5"
      "testKde6"
      # Invokes python manually in a subprocess, which does not have the correct dependencies.
      # Example failure (recorded against an earlier release; the version varies):
      # ImportError: This version of jax requires jaxlib version >= 0.4.19.
      "test_no_log_spam"
    ]
    ++ lib.optionals usingMKL [
      # See
      # * https://github.com/google/jax/issues/9705
      # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
      # * https://github.com/NixOS/nixpkgs/issues/161960
      "test_custom_linear_solve_cholesky"
      "test_custom_root_with_aux"
      "testEigvalsGrad_shape"
    ]
    ++ lib.optionals stdenv.isAarch64 [
      # See https://github.com/google/jax/issues/14793.
      "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
      "testQdwhWithRandomMatrix3"
      "testScanGrad_jit_scan"

      # See https://github.com/google/jax/issues/17867.
      "test_array"
      "test_async"
      "test_copy0"
      "test_device_put"
      "test_make_array_from_callback"
      "test_make_array_from_single_device_arrays"

      # Fails on some hardware due to some numerical error
      # See https://github.com/google/jax/issues/18535
      "testQdwhWithOnRankDeficientInput5"
    ];

  disabledTestPaths =
    [
      # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
      "tests/linalg_test.py"
    ]
    ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
      # RuntimeWarning: invalid value encountered in cast
      "tests/lax_test.py"
    ];

  pythonImportsCheck = [ "jax" ];

  # Test CUDA-enabled jax and jaxlib. Running CUDA-enabled tests is not
  # currently feasible within the nix build environment so we have to maintain
  # this script separately. See https://github.com/NixOS/nixpkgs/pull/256230
  # for a possible remedy to this situation.
  #
  # Run these tests with eg
  #
  #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
  passthru.tests = {
    # CUDA jaxlib built from source.
    test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-build.override { cudaSupport = true; };
    };
    # CUDA jaxlib from the prebuilt upstream wheels.
    test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-bin.override { cudaSupport = true; };
    };
  };

  # updater fails to pick the correct branch
  passthru.skipBulkUpdate = true;

  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    longDescription = ''
      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
    '';
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.samuela ];
  };
}