Merge pull request #196977 from samuela/samuela/jax

JAX upgrades

authored by Samuel Ainsworth and committed by GitHub 068b4774 9ad41618

+35 -61
+4 -5
pkgs/development/python-modules/jax/default.nix
··· 21 in 22 buildPythonPackage rec { 23 pname = "jax"; 24 - version = "0.3.16"; 25 format = "setuptools"; 26 27 disabled = pythonOlder "3.7"; ··· 30 owner = "google"; 31 repo = pname; 32 rev = "jax-v${version}"; 33 - hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I="; 34 }; 35 36 # jaxlib is _not_ included in propagatedBuildInputs because there are ··· 92 "tests/sparse_test.py" 93 ]; 94 95 - pythonImportsCheck = [ 96 - "jax" 97 - ]; 98 99 meta = with lib; { 100 description = "Differentiate, compile, and transform Numpy code";
··· 21 in 22 buildPythonPackage rec { 23 pname = "jax"; 24 + version = "0.3.23"; 25 format = "setuptools"; 26 27 disabled = pythonOlder "3.7"; ··· 30 owner = "google"; 31 repo = pname; 32 rev = "jax-v${version}"; 33 + hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4="; 34 }; 35 36 # jaxlib is _not_ included in propagatedBuildInputs because there are ··· 92 "tests/sparse_test.py" 93 ]; 94 95 + # As of 0.3.22, `import jax` does not work without jaxlib being installed. 96 + pythonImportsCheck = [ ]; 97 98 meta = with lib; { 99 description = "Differentiate, compile, and transform Numpy code";
+23 -44
pkgs/development/python-modules/jaxlib/bin.nix
··· 3 # https://storage.googleapis.com/jax-releases/libtpu_releases.html. 4 5 # For future reference, the easiest way to test the GPU backend is to run 6 - # NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }" 7 # export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 8 # python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'" 9 # python -c "from jax import random; random.PRNGKey(0)" ··· 35 inherit (cudaPackages) cudatoolkit cudnn; 36 in 37 38 - # There are no jaxlib wheels targeting cudnn <8.0.5, and although there are 39 - # wheels for cudatoolkit <11.1, we don't support them. 40 assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; 41 - assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; 42 43 let 44 - version = "0.3.0"; 45 46 pythonVersion = python.pythonVersion; 47 48 - # Find new releases at https://storage.googleapis.com/jax-releases. When 49 - # upgrading, you can get these hashes from prefetch.sh. 50 cpuSrcs = { 51 - "3.9" = fetchurl { 52 - url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; 53 - hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q="; 54 }; 55 - "3.10" = fetchurl { 56 - url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl"; 57 - hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ="; 58 }; 59 }; 60 61 - gpuSrcs = { 62 - "3.9-805" = fetchurl { 63 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; 64 - hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw="; 65 - }; 66 - "3.9-82" = fetchurl { 67 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; 68 - hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4="; 69 - }; 70 - "3.10-805" = fetchurl { 71 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl"; 72 - hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U="; 73 - }; 74 - "3.10-82" = fetchurl { 75 - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl"; 76 - hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go="; 77 - }; 78 }; 79 in 80 buildPythonPackage rec { ··· 82 inherit version; 83 format = "wheel"; 84 85 - # At the time of writing (2022-03-03), there are releases for <=3.10. 86 - # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs 87 - # python3 version, and 3.10. 88 - disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10"); 89 90 - src = 91 - if !cudaSupport then cpuSrcs."${pythonVersion}" else 92 - let 93 - # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and 94 - # 8.2. Try to use 8.2 whenever possible. 95 - cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805"; 96 - in 97 - gpuSrcs."${pythonVersion}-${cudnnVersion}"; 98 99 # Prebuilt wheels are dynamically linked against things that nix can't find. 100 # Run `autoPatchelfHook` to automagically fix them. 101 - nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath; 102 # Dynamic link dependencies 103 buildInputs = [ stdenv.cc.cc ]; 104 ··· 142 sourceProvenance = with sourceTypes; [ binaryNativeCode ]; 143 license = licenses.asl20; 144 maintainers = with maintainers; [ samuela ]; 145 - platforms = [ "x86_64-linux" ]; 146 }; 147 }
··· 3 # https://storage.googleapis.com/jax-releases/libtpu_releases.html. 4 5 # For future reference, the easiest way to test the GPU backend is to run 6 + # NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }" 7 # export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 8 # python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'" 9 # python -c "from jax import random; random.PRNGKey(0)" ··· 35 inherit (cudaPackages) cudatoolkit cudnn; 36 in 37 38 assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; 39 + assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; 40 41 let 42 + version = "0.3.22"; 43 44 pythonVersion = python.pythonVersion; 45 46 + # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. 47 + # When upgrading, you can get these hashes from prefetch.sh. See 48 + # https://github.com/google/jax/issues/12879 as to why this specific URL is 49 + # the correct index. 50 cpuSrcs = { 51 + "x86_64-linux" = fetchurl { 52 + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; 53 + hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4="; 54 }; 55 + "aarch64-darwin" = fetchurl { 56 + url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; 57 + hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0="; 58 }; 59 }; 60 61 + gpuSrc = fetchurl { 62 + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; 63 + hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs="; 64 }; 65 in 66 buildPythonPackage rec { ··· 68 inherit version; 69 format = "wheel"; 70 71 + # At the time of writing (2022-10-19), there are releases for <=3.10. 72 + # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs 73 + # python version. 74 + disabled = !(pythonVersion == "3.10"); 75 76 + src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc; 77 78 # Prebuilt wheels are dynamically linked against things that nix can't find. 79 # Run `autoPatchelfHook` to automagically fix them. 80 + nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; 81 # Dynamic link dependencies 82 buildInputs = [ stdenv.cc.cc ]; 83 ··· 121 sourceProvenance = with sourceTypes; [ binaryNativeCode ]; 122 license = licenses.asl20; 123 maintainers = with maintainers; [ samuela ]; 124 + platforms = [ "aarch64-darwin" "x86_64-linux" ]; 125 }; 126 }
+8 -6
pkgs/development/python-modules/jaxlib/default.nix
··· 53 inherit (cudaPackages) cudatoolkit cudnn nccl; 54 55 pname = "jaxlib"; 56 - version = "0.3.15"; 57 58 meta = with lib; { 59 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; ··· 96 owner = "google"; 97 repo = "jax"; 98 rev = "${pname}-v${version}"; 99 - sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s="; 100 }; 101 102 nativeBuildInputs = [ ··· 235 fetchAttrs = { 236 sha256 = 237 if cudaSupport then 238 - "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk=" 239 else if stdenv.isDarwin then 240 - "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34=" 241 else 242 - "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A="; 243 }; 244 245 buildAttrs = { ··· 293 inherit meta pname version; 294 format = "wheel"; 295 296 - src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl"; 297 298 # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH. 299 # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
··· 53 inherit (cudaPackages) cudatoolkit cudnn nccl; 54 55 pname = "jaxlib"; 56 + version = "0.3.22"; 57 58 meta = with lib; { 59 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; ··· 96 owner = "google"; 97 repo = "jax"; 98 rev = "${pname}-v${version}"; 99 + hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs="; 100 }; 101 102 nativeBuildInputs = [ ··· 235 fetchAttrs = { 236 sha256 = 237 if cudaSupport then 238 + "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ=" 239 else if stdenv.isDarwin then 240 + "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM=" 241 else 242 + "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM="; 243 }; 244 245 buildAttrs = { ··· 293 inherit meta pname version; 294 format = "wheel"; 295 296 + src = 297 + let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}"; 298 + in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl"; 299 300 # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH. 301 # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
-6
pkgs/top-level/python-packages.nix
··· 4780 4781 jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { 4782 cudaSupport = pkgs.config.cudaSupport or false; 4783 - # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we 4784 - # pin to `cudaPackages_11_6` instead. 4785 - cudaPackages = pkgs.cudaPackages_11_6; 4786 }; 4787 4788 jaxlib-build = callPackage ../development/python-modules/jaxlib rec { ··· 4792 }; 4793 # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. 4794 cudaSupport = pkgs.config.cudaSupport or false; 4795 - # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we 4796 - # pin to `cudaPackages_11_6` instead. 4797 - cudaPackages = pkgs.cudaPackages_11_6; 4798 IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; 4799 protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21 4800 };
··· 4780 4781 jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { 4782 cudaSupport = pkgs.config.cudaSupport or false; 4783 }; 4784 4785 jaxlib-build = callPackage ../development/python-modules/jaxlib rec { ··· 4789 }; 4790 # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. 4791 cudaSupport = pkgs.config.cudaSupport or false; 4792 IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; 4793 protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21 4794 };