{
  lib,
  absl-py,
  blas,
  buildPythonPackage,
  etils,
  fetchFromGitHub,
  jaxlib,
  lapack,
  matplotlib,
  numpy,
  opt-einsum,
  pytestCheckHook,
  pytest-xdist,
  pythonOlder,
  scipy,
  typing-extensions,
}:

let
  # True when either the BLAS or the LAPACK provider is MKL. A few numerical
  # tests are skipped in that configuration (see `disabledTests` below).
  usingMKL = lib.elem "mkl" [
    blas.implementation
    lapack.implementation
  ];
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.3.23";
  format = "setuptools";

  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    rev = "jax-v${version}";
    hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4=";
  };

  # jaxlib is intentionally absent from the runtime dependencies: upstream
  # ships different jaxlib builds for different target hardware (separate CPU,
  # GPU, and TPU wheels), and currently only the CPU wheel is packaged. It is
  # therefore only pulled in as a check input below.
  propagatedBuildInputs = [
    absl-py
    etils
    numpy
    opt-einsum
    scipy
    typing-extensions
  ] ++ etils.optional-dependencies.epath;

  # Since 0.3.22, `import jax` fails unless jaxlib is installed, and jaxlib is
  # not a runtime dependency here, so no import check is performed.
  pythonImportsCheck = [ ];

  checkInputs = [
    jaxlib
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism makes the test suite hang. The default pytest-xdist
  # setup is therefore turned off, and a fixed worker count of 4 is passed
  # explicitly in `pytestFlagsArray` instead.
  dontUsePytestXdist = true;

  # Only `tests/` is collected. The experimental directory is skipped because
  # its tests require flax, which would create a circular dependency; see
  # https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # This matches how the JAX docs suggest running the test suite anyway.
  pytestFlagsArray = [
    "--numprocesses=4"
    "-W ignore::DeprecationWarning"
    "tests/"
  ];

  # Temporary workaround for https://github.com/google/jax/issues/11722, which
  # unblocked etils and the jax/jaxlib upgrade; see
  # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
  disabledTestPaths = [
    "tests/api_test.py"
    "tests/core_test.py"
    "tests/lax_numpy_indexing_test.py"
    "tests/lax_numpy_test.py"
    "tests/nn_test.py"
    "tests/random_test.py"
    "tests/sparse_test.py"
  ];

  disabledTests =
    [
      # Exceeds its numerical tolerance when the build machine is busy.
      "test_custom_linear_solve_aux"
    ]
    ++ (
      if usingMKL then
        [
          # Results differ under MKL; see
          # * https://github.com/google/jax/issues/9705
          # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
          # * https://github.com/NixOS/nixpkgs/issues/161960
          "test_custom_linear_solve_cholesky"
          "test_custom_root_with_aux"
          "testEigvalsGrad_shape"
        ]
      else
        [ ]
    );

  meta = {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.samuela ];
  };
}