nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at r-updates 197 lines 5.9 kB view raw
# nixpkgs expression for the JAX Python frontend (`jax`), built from the
# upstream jax-ml/jax repository. The compiled backend comes from `jaxlib`.
{
  lib,
  config,
  stdenv,
  blas,
  lapack,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch2,
  cudaSupport ? config.cudaSupport,

  # build-system
  setuptools,

  # dependencies
  jaxlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  scipy,

  # optional-dependencies
  jax-cuda12-plugin,

  # tests
  cloudpickle,
  hypothesis,
  matplotlib,
  pytestCheckHook,
  pytest-xdist,

  # passthru
  callPackage,
  jax,
  # Only referenced by the (currently commented-out) source-build CUDA test below.
  jaxlib-build,
  jaxlib-bin,
}:

let
  # True when either BLAS or LAPACK is provided by MKL. Used below to skip
  # tests whose numerical results differ under MKL.
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.8.2";
  pyproject = true;

  src = fetchFromGitHub {
    # Upstream moved from google/jax to jax-ml/jax. The tarball content (and so
    # this fixed-output hash) is identical.
    owner = "jax-ml";
    repo = "jax";
    # jax-ml/jax contains tags for both jax and jaxlib. Only use jax tags!
    tag = "jax-v${version}";
    hash = "sha256-WKdFEhOxJPLjOXOChZbLRGcw0GFeg/TT/FT6M72C6bo=";
  };

  patches = [
    # https://github.com/jax-ml/jax/pull/32840
    (fetchpatch2 {
      url = "https://github.com/Prince213/jax/commit/af5c211d49f3b99447db2252d2cc2b8e0fb54d1c.patch?full_index=1";
      hash = "sha256-ijEd+MDe91qyYfE+aMzR5rNmTeGadin6Io8PIfJWc3o=";
    })
  ];

  build-system = [ setuptools ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/jax-ml/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  env.JAX_RELEASE = "1";

  dependencies = [
    jaxlib
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ]
  ++ lib.optionals cudaSupport optional-dependencies.cuda;

  optional-dependencies = rec {
    cuda = [ jax-cuda12-plugin ];
    # Aliases of `cuda`, presumably mirroring the extra names upstream exposes.
    cuda12 = cuda;
    cuda12_pip = cuda;
    cuda12_local = cuda;
  };

  nativeCheckInputs = [
    cloudpickle
    hypothesis
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism will result in the tests getting stuck, so the
  # pytest-xdist hook is disabled and a fixed worker count is passed via
  # pytestFlags instead.
  dontUsePytestXdist = true;

  pytestFlags = [
    "--numprocesses=4"
    "-Wignore::DeprecationWarning"
  ];

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  enabledTestPaths = [
    "tests/"
  ];

  disabledTestPaths = lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    # reported at: https://github.com/jax-ml/jax/issues/26106
    "tests/pjit_test.py::PJitErrorTest::testAxisResourcesMismatch"
    "tests/shape_poly_test.py::ShapePolyTest"
    "tests/tree_util_test.py::TreeTest"

    # Mostly AssertionError on numerical tests failing since 0.7.0
    # https://github.com/jax-ml/jax/issues/31428
    "tests/export_back_compat_test.py"
    "tests/lax_numpy_test.py"
    "tests/lax_scipy_test.py"
    "tests/lax_test.py"
    "tests/linalg_test.py"
  ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
  # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
  # See https://github.com/jax-ml/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
  # NOTE: this doesn't seem to be an issue on linux.
  # NOTE: currently redundant, because tests/export_back_compat_test.py is in
  # disabledTestPaths on darwin above. It is kept so the workaround still applies
  # once that path is re-enabled.
  preCheck = lib.optionalString stdenv.hostPlatform.isDarwin ''
    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
  '';

  disabledTests = [
    # Exceeds tolerance when the machine is busy
    "test_custom_linear_solve_aux"
  ]
  ++ lib.optionals usingMKL [
    # See
    # * https://github.com/jax-ml/jax/issues/9705
    # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
    # * https://github.com/NixOS/nixpkgs/issues/161960
    "test_custom_linear_solve_cholesky"
    "test_custom_root_with_aux"
    "testEigvalsGrad_shape"
  ]
  ++ lib.optionals stdenv.hostPlatform.isAarch64 [
    # Fails on some hardware due to some numerical error
    # See https://github.com/jax-ml/jax/issues/18535
    "testQdwhWithOnRankDeficientInput5"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    # reported at: https://github.com/jax-ml/jax/issues/26106
    "testInAxesPyTreePrefixMismatchError"
    "testInAxesPyTreePrefixMismatchErrorKwargs"
    "testOutAxesPyTreePrefixMismatchError"
    "test_tree_map"
    "test_tree_prefix_error"
    "test_vjp_rule_inconsistent_pytree_structures_error"
    "test_vmap_in_axes_tree_prefix_error"
    "test_vmap_mismatched_axis_sizes_error_message_issue_705"
  ];

  pythonImportsCheck = [ "jax" ];

  # Test CUDA-enabled jax and jaxlib. Running CUDA-enabled tests is not
  # currently feasible within the nix build environment, so we have to maintain
  # this script separately. See https://github.com/NixOS/nixpkgs/pull/256230
  # for a possible remedy to this situation.
  #
  # Run these tests with e.g.
  #
  # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
  passthru.tests = {
    # jaxlib-build is broken as of 2024-12-20
    # test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
    #   jax = jax.override { jaxlib = jaxlib-build; };
    # };
    test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
      jax = jax.override { jaxlib = jaxlib-bin; };
    };
  };

  # updater fails to pick the correct branch
  passthru.skipBulkUpdate = true;

  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/jax-ml/jax";
    changelog = "https://docs.jax.dev/en/latest/changelog.html";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [
      GaetanLepage
      samuela
    ];
  };
}