{ lib
, buildPythonPackage
, fetchFromGitHub
, pythonOlder
  # Propagated (runtime) Python dependencies.
, absl-py
, numpy
, opt-einsum
, scipy
, typing-extensions
  # Needed only while running the test suite.
, jaxlib
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "jax";
  version = "0.2.24";
  format = "setuptools";

  # The package is marked unavailable on interpreters older than Python 3.7.
  disabled = pythonOlder "3.7";

  # Upstream release tags have the form "jax-v<version>".
  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    rev = "jax-v${version}";
    sha256 = "1mmn1m4mprpwqlb1smjfdy3f74zm9p3l9dhhn25x6jrcj2cgc5pi";
  };

  # jaxlib is deliberately left out of the propagated inputs: upstream ships
  # distinct jaxlib builds (separate wheels for CPU, GPU and TPU), so the user
  # must pick the one matching their hardware. Only the CPU wheel is packaged
  # at the moment.
  propagatedBuildInputs = [
    absl-py
    numpy
    opt-einsum
    scipy
    typing-extensions
  ];

  # jaxlib is pulled in here solely so that the tests can run.
  checkInputs = [
    jaxlib
    pytestCheckHook
  ];

  pytestFlagsArray = [
    # Suppress DeprecationWarning output during the test run.
    "-W" "ignore::DeprecationWarning"
    # Collect only tests/. The tests in the experimental directory need flax,
    # and depending on flax here would create a dependency cycle; see
    # https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
    # Running just tests/ is also how the JAX docs suggest invoking the suite.
    "tests/"
  ];

  # Smoke test: the top-level module must be importable after installation.
  pythonImportsCheck = [ "jax" ];

  meta = {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.samuela ];
  };
}