at 22.05-pre 1.5 kB view raw
# Nix package expression for the `jax` Python library (version 0.2.24),
# built from the upstream `google/jax` GitHub tag `jax-v<version>`.
# Takes the listed packages/helpers from the package set and returns a
# `buildPythonPackage` derivation.
{ lib
, absl-py
, buildPythonPackage
, fetchFromGitHub
, jaxlib
, numpy
, opt-einsum
, pytestCheckHook
, pythonOlder
, scipy
, typing-extensions
}:

buildPythonPackage rec {
  pname = "jax";
  version = "0.2.24";
  format = "setuptools";

  # The package is marked disabled for Python interpreters older than 3.7.
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    rev = "jax-v${version}";
    sha256 = "1mmn1m4mprpwqlb1smjfdy3f74zm9p3l9dhhn25x6jrcj2cgc5pi";
  };

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
  # CPU wheel is packaged.
  propagatedBuildInputs = [
    absl-py
    numpy
    opt-einsum
    scipy
    typing-extensions
  ];

  # jaxlib is needed only to run the test suite here (see note above).
  checkInputs = [
    jaxlib
    pytestCheckHook
  ];

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
    # Pass `-W` and its value as separate arguments so the warning filter is
    # parsed correctly whether the check hook word-splits the array or
    # quotes each element individually.
    "-W"
    "ignore::DeprecationWarning"
    "tests/"
  ];

  pythonImportsCheck = [
    "jax"
  ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
  };
}