at 25.11-pre 5.3 kB view raw
# Package definition for the JAX Python frontend, built from the upstream
# GitHub sources with `buildPythonPackage`.
#
# Arguments:
#   cudaSupport  when true, the `cuda` optional-dependency set is added to the
#                regular dependencies (defaults to `config.cudaSupport`).
#   blas/lapack  only inspected to detect an MKL backend, which disables a few
#                extra tests.
#   jax, jaxlib-build, jaxlib-bin, callPackage
#                only used for `passthru.tests`.
{
  lib,
  config,
  stdenv,
  blas,
  lapack,
  buildPythonPackage,
  fetchFromGitHub,
  cudaSupport ? config.cudaSupport,

  # build-system
  setuptools,

  # dependencies
  jaxlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  scipy,

  # optional-dependencies
  jax-cuda12-plugin,

  # tests
  cloudpickle,
  hypothesis,
  matplotlib,
  pytestCheckHook,
  pytest-xdist,

  # passthru
  callPackage,
  jax,
  jaxlib-build,
  jaxlib-bin,
}:

let
  # True when either BLAS or LAPACK is provided by MKL; used below to disable
  # tests that are known to misbehave with that backend.
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.6.0";
  pyproject = true;

  src = fetchFromGitHub {
    # The repository moved from google/jax to jax-ml/jax.
    owner = "jax-ml";
    repo = "jax";
    # The repository contains tags for both jax and jaxlib. Only use jax tags!
    tag = "jax-v${version}";
    hash = "sha256-leadxLK21JF/cF/hMveUPgrwyumjR+zMm9Wsn7bWNLQ=";
  };

  build-system = [ setuptools ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  # Declared under `env` so that it is exported as an environment variable
  # regardless of whether structured attributes are in use.
  env.JAX_RELEASE = "1";

  dependencies = [
    jaxlib
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ]
  ++ lib.optionals cudaSupport optional-dependencies.cuda;

  # All CUDA extras resolve to the same plugin package.
  optional-dependencies = rec {
    cuda = [ jax-cuda12-plugin ];
    cuda12 = cuda;
    cuda12_pip = cuda;
    cuda12_local = cuda;
  };

  nativeCheckInputs = [
    cloudpickle
    hypothesis
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # high parallelism will result in the tests getting stuck
  # (the worker count is pinned explicitly via `--numprocesses` below instead)
  dontUsePytestXdist = true;

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
    "--numprocesses=4"
    "-W ignore::DeprecationWarning"
    "tests/"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    # reported at: https://github.com/jax-ml/jax/issues/26106
    "--deselect tests/pjit_test.py::PJitErrorTest::testAxisResourcesMismatch"
    "--deselect tests/shape_poly_test.py::ShapePolyTest"
    "--deselect tests/tree_util_test.py::TreeTest"
  ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
  # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
  # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
  # NOTE: this doesn't seem to be an issue on linux
  preCheck = lib.optionalString stdenv.hostPlatform.isDarwin ''
    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
  '';

  disabledTests = [
    # Exceeds tolerance when the machine is busy
    "test_custom_linear_solve_aux"
  ]
  ++ lib.optionals usingMKL [
    # See
    #  * https://github.com/google/jax/issues/9705
    #  * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
    #  * https://github.com/NixOS/nixpkgs/issues/161960
    "test_custom_linear_solve_cholesky"
    "test_custom_root_with_aux"
    "testEigvalsGrad_shape"
  ]
  ++ lib.optionals stdenv.hostPlatform.isAarch64 [
    # Fails on some hardware due to some numerical error
    # See https://github.com/google/jax/issues/18535
    "testQdwhWithOnRankDeficientInput5"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    # reported at: https://github.com/jax-ml/jax/issues/26106
    "testInAxesPyTreePrefixMismatchError"
    "testInAxesPyTreePrefixMismatchErrorKwargs"
    "testOutAxesPyTreePrefixMismatchError"
    "test_tree_map"
    "test_tree_prefix_error"
    "test_vjp_rule_inconsistent_pytree_structures_error"
    "test_vmap_in_axes_tree_prefix_error"
    "test_vmap_mismatched_axis_sizes_error_message_issue_705"
  ];

  pythonImportsCheck = [ "jax" ];

  passthru = {
    # Test CUDA-enabled jax and jaxlib. Running CUDA-enabled tests is not
    # currently feasible within the nix build environment so we have to maintain
    # this script separately. See https://github.com/NixOS/nixpkgs/pull/256230
    # for a possible remedy to this situation.
    #
    # Run these tests with e.g.
    #
    #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
    tests = {
      # jaxlib-build is broken as of 2024-12-20
      # test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      #   jax = jax.override { jaxlib = jaxlib-build; };
      # };
      test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
        jax = jax.override { jaxlib = jaxlib-bin; };
      };
    };

    # updater fails to pick the correct branch
    skipBulkUpdate = true;
  };

  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/jax-ml/jax";
    changelog = "https://github.com/jax-ml/jax/releases/tag/jax-v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
  };
}