···21in
22buildPythonPackage rec {
23 pname = "jax";
24- version = "0.3.16";
25 format = "setuptools";
2627 disabled = pythonOlder "3.7";
···30 owner = "google";
31 repo = pname;
32 rev = "jax-v${version}";
33- hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I=";
34 };
3536 # jaxlib is _not_ included in propagatedBuildInputs because there are
···92 "tests/sparse_test.py"
93 ];
9495- pythonImportsCheck = [
96- "jax"
97- ];
9899 meta = with lib; {
100 description = "Differentiate, compile, and transform Numpy code";
···21in
22buildPythonPackage rec {
23 pname = "jax";
24+ version = "0.3.23";
25 format = "setuptools";
2627 disabled = pythonOlder "3.7";
···30 owner = "google";
31 repo = pname;
32 rev = "jax-v${version}";
33+ hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4=";
34 };
3536 # jaxlib is _not_ included in propagatedBuildInputs because there are
···92 "tests/sparse_test.py"
93 ];
9495+ # As of 0.3.22, `import jax` does not work without jaxlib being installed.
96+ pythonImportsCheck = [ ];
09798 meta = with lib; {
99 description = "Differentiate, compile, and transform Numpy code";
+23-44
pkgs/development/python-modules/jaxlib/bin.nix
···3# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
45# For future reference, the easiest way to test the GPU backend is to run
6-# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
7# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
8# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
9# python -c "from jax import random; random.PRNGKey(0)"
···35 inherit (cudaPackages) cudatoolkit cudnn;
36in
3738-# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are
39-# wheels for cudatoolkit <11.1, we don't support them.
40assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
41-assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5";
4243let
44- version = "0.3.0";
4546 pythonVersion = python.pythonVersion;
4748- # Find new releases at https://storage.googleapis.com/jax-releases. When
49- # upgrading, you can get these hashes from prefetch.sh.
0050 cpuSrcs = {
51- "3.9" = fetchurl {
52- url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
53- hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q=";
54 };
55- "3.10" = fetchurl {
56- url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl";
57- hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ=";
58 };
59 };
6061- gpuSrcs = {
62- "3.9-805" = fetchurl {
63- url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
64- hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw=";
65- };
66- "3.9-82" = fetchurl {
67- url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
68- hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4=";
69- };
70- "3.10-805" = fetchurl {
71- url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl";
72- hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U=";
73- };
74- "3.10-82" = fetchurl {
75- url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl";
76- hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go=";
77- };
78 };
79in
80buildPythonPackage rec {
···82 inherit version;
83 format = "wheel";
8485- # At the time of writing (2022-03-03), there are releases for <=3.10.
86- # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs
87- # python3 version, and 3.10.
88- disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10");
8990- src =
91- if !cudaSupport then cpuSrcs."${pythonVersion}" else
92- let
93- # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and
94- # 8.2. Try to use 8.2 whenever possible.
95- cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805";
96- in
97- gpuSrcs."${pythonVersion}-${cudnnVersion}";
9899 # Prebuilt wheels are dynamically linked against things that nix can't find.
100 # Run `autoPatchelfHook` to automagically fix them.
101- nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
102 # Dynamic link dependencies
103 buildInputs = [ stdenv.cc.cc ];
104···142 sourceProvenance = with sourceTypes; [ binaryNativeCode ];
143 license = licenses.asl20;
144 maintainers = with maintainers; [ samuela ];
145- platforms = [ "x86_64-linux" ];
146 };
147}
···3# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
45# For future reference, the easiest way to test the GPU backend is to run
6+# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }"
7# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
8# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
9# python -c "from jax import random; random.PRNGKey(0)"
···35 inherit (cudaPackages) cudatoolkit cudnn;
36in
370038assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
39+assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
4041let
42+ version = "0.3.22";
4344 pythonVersion = python.pythonVersion;
4546+ # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
47+ # When upgrading, you can get these hashes from prefetch.sh. See
48+ # https://github.com/google/jax/issues/12879 as to why this specific URL is
49+ # the correct index.
50 cpuSrcs = {
51+ "x86_64-linux" = fetchurl {
52+ url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
53+ hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4=";
54 };
55+ "aarch64-darwin" = fetchurl {
56+ url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
57+ hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0=";
58 };
59 };
6061+ gpuSrc = fetchurl {
62+ url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
63+ hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs=";
0000000000000064 };
65in
66buildPythonPackage rec {
···68 inherit version;
69 format = "wheel";
7071+ # At the time of writing (2022-10-19), there are releases for <=3.10.
72+ # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
73+ # python version.
74+ disabled = !(pythonVersion == "3.10");
7576+ src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc;
00000007778 # Prebuilt wheels are dynamically linked against things that nix can't find.
79 # Run `autoPatchelfHook` to automagically fix them.
80+ nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
81 # Dynamic link dependencies
82 buildInputs = [ stdenv.cc.cc ];
83···121 sourceProvenance = with sourceTypes; [ binaryNativeCode ];
122 license = licenses.asl20;
123 maintainers = with maintainers; [ samuela ];
124+ platforms = [ "aarch64-darwin" "x86_64-linux" ];
125 };
126}
···53 inherit (cudaPackages) cudatoolkit cudnn nccl;
5455 pname = "jaxlib";
56- version = "0.3.15";
5758 meta = with lib; {
59 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
···96 owner = "google";
97 repo = "jax";
98 rev = "${pname}-v${version}";
99- sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s=";
100 };
101102 nativeBuildInputs = [
···235 fetchAttrs = {
236 sha256 =
237 if cudaSupport then
238- "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk="
239 else if stdenv.isDarwin then
240- "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34="
241 else
242- "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A=";
243 };
244245 buildAttrs = {
···293 inherit meta pname version;
294 format = "wheel";
295296- src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl";
00297298 # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
299 # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
···53 inherit (cudaPackages) cudatoolkit cudnn nccl;
5455 pname = "jaxlib";
56+ version = "0.3.22";
5758 meta = with lib; {
59 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
···96 owner = "google";
97 repo = "jax";
98 rev = "${pname}-v${version}";
99+ hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
100 };
101102 nativeBuildInputs = [
···235 fetchAttrs = {
236 sha256 =
237 if cudaSupport then
238+ "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ="
239 else if stdenv.isDarwin then
240+ "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM="
241 else
242+ "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM=";
243 };
244245 buildAttrs = {
···293 inherit meta pname version;
294 format = "wheel";
295296+ src =
297+ let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
298+ in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
299300 # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
301 # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
-6
pkgs/top-level/python-packages.nix
···47804781 jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {
4782 cudaSupport = pkgs.config.cudaSupport or false;
4783- # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
4784- # pin to `cudaPackages_11_6` instead.
4785- cudaPackages = pkgs.cudaPackages_11_6;
4786 };
47874788 jaxlib-build = callPackage ../development/python-modules/jaxlib rec {
···4792 };
4793 # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
4794 cudaSupport = pkgs.config.cudaSupport or false;
4795- # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
4796- # pin to `cudaPackages_11_6` instead.
4797- cudaPackages = pkgs.cudaPackages_11_6;
4798 IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
4799 protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
4800 };
···47804781 jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {
4782 cudaSupport = pkgs.config.cudaSupport or false;
0004783 };
47844785 jaxlib-build = callPackage ../development/python-modules/jaxlib rec {
···4789 };
4790 # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
4791 cudaSupport = pkgs.config.cudaSupport or false;
0004792 IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
4793 protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
4794 };