# The JAX frontend (the `jax` Python package), built from the google/jax
# sources.
#
# Only the frontend is packaged here. A jaxlib backend is deliberately not a
# runtime dependency (see the note on propagatedBuildInputs); jaxlib is only
# used to run the test suite.
{
  lib,
  blas,
  buildPythonPackage,
  callPackage,
  setuptools,
  importlib-metadata,
  fetchFromGitHub,
  jaxlib,
  jaxlib-bin,
  jaxlib-build,
  hypothesis,
  lapack,
  matplotlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  pytestCheckHook,
  pytest-xdist,
  pythonOlder,
  scipy,
  stdenv,
}:

let
  # True when either the BLAS or the LAPACK implementation is MKL. A few
  # numerically sensitive tests are disabled in that configuration (see
  # disabledTests below).
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.28";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for both jax and jaxlib. Only use jax tags!
    rev = "refs/tags/jax-v${version}";
    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
  };

  nativeBuildInputs = [ setuptools ];

  # If this variable is not set, the version is automatically set to a ".dev"
  # version.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAX_RELEASE = "1";

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

  nativeCheckInputs = [
    hypothesis
    # Needed only to run the test suite; see the note on propagatedBuildInputs.
    jaxlib
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism causes the tests to get stuck. Disable the pytest-xdist
  # hook's automatic parallelism and instead request a fixed, smaller number of
  # workers via `--numprocesses` in pytestFlagsArray below.
  dontUsePytestXdist = true;

  # Only the top-level tests/ directory is run. This is how the JAX docs
  # suggest running the test suite anyway. In particular, the tests in the
  # experimental directory are not run because they require flax, which would
  # create a circular dependency. See
  # https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  pytestFlagsArray = [
    "--numprocesses=4"
    "-W ignore::DeprecationWarning"
    "tests/"
  ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from
  # failing on Darwin with
  #   PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
  # by pointing TEST_UNDECLARED_OUTPUTS_DIR at a fresh temporary directory.
  # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
  # NOTE: this doesn't seem to be an issue on Linux.
  preCheck = lib.optionalString stdenv.hostPlatform.isDarwin ''
    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
  '';

  disabledTests =
    [
      # Exceeds tolerance when the machine is busy
      "test_custom_linear_solve_aux"
      # UserWarning: Explicitly requested dtype <class 'numpy.float64'>
      # requested in astype is not available, and will be truncated to
      # dtype float32. (With numpy 1.24)
      "testKde3"
      "testKde5"
      "testKde6"
      # Invokes python manually in a subprocess, which does not have the
      # correct dependencies:
      # ImportError: This version of jax requires jaxlib version >= 0.4.19.
      "test_no_log_spam"
    ]
    ++ lib.optionals usingMKL [
      # See
      # * https://github.com/google/jax/issues/9705
      # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
      # * https://github.com/NixOS/nixpkgs/issues/161960
      "test_custom_linear_solve_cholesky"
      "test_custom_root_with_aux"
      "testEigvalsGrad_shape"
    ]
    ++ lib.optionals stdenv.hostPlatform.isAarch64 [
      # See https://github.com/google/jax/issues/14793.
      "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
      "testQdwhWithRandomMatrix3"
      "testScanGrad_jit_scan"

      # See https://github.com/google/jax/issues/17867.
      "test_array"
      "test_async"
      "test_copy0"
      "test_device_put"
      "test_make_array_from_callback"
      "test_make_array_from_single_device_arrays"

      # Fails on some hardware due to a numerical error.
      # See https://github.com/google/jax/issues/18535
      "testQdwhWithOnRankDeficientInput5"
    ];

  disabledTestPaths =
    [
      # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
      "tests/linalg_test.py"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [
      # RuntimeWarning: invalid value encountered in cast
      "tests/lax_test.py"
    ];

  pythonImportsCheck = [ "jax" ];

  # Test CUDA-enabled jax and jaxlib. Running CUDA-enabled tests is not
  # currently feasible within the nix build environment, so we have to maintain
  # this script separately. See https://github.com/NixOS/nixpkgs/pull/256230
  # for a possible remedy to this situation.
  #
  # Run these tests with e.g.
  #
  #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
  passthru.tests = {
    # CUDA-enabled jaxlib built from source.
    test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-build.override { cudaSupport = true; };
    };
    # CUDA-enabled jaxlib from the prebuilt binary package.
    test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-bin.override { cudaSupport = true; };
    };
  };

  # The updater fails to pick the correct branch (the repo also carries jaxlib
  # tags), so keep this package out of bulk updates.
  passthru.skipBulkUpdate = true;

  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    longDescription = ''
      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
    '';
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.samuela ];
  };
}