Merge pull request #162703 from samuela/samuela/jaxlib-bin

python3Packages.jaxlib-bin: add support for python 3.10 and cudnn >=8.2

authored by Samuel Ainsworth and committed by GitHub c5e29efa bc265bf8

+57 -27
+50 -27
pkgs/development/python-modules/jaxlib/bin.nix
··· 24 , flatbuffers 25 , isPy39 26 , lib 27 , scipy 28 , stdenv 29 # Options: 30 , cudaSupport ? config.cudaSupport or false 31 }: 32 33 - # Note that these values are tied to the specific version of the GPU wheel that 34 - # we fetch. When updating, try to go for the latest possible versions that are 35 - # still compatible with the cudatoolkit and cudnn versions available in nixpkgs. 36 assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1"; 37 assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; 38 39 let 40 - device = if cudaSupport then "gpu" else "cpu"; 41 - in 42 - buildPythonPackage rec { 43 - pname = "jaxlib"; 44 version = "0.3.0"; 45 - format = "wheel"; 46 47 - # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting 48 - # all of them is a pain, so we focus on 3.9, the current nixpkgs python3 49 - # version. 50 - disabled = !isPy39; 51 52 - # Find new releases at https://storage.googleapis.com/jax-releases. 53 - src = { 54 - cpu = fetchurl { 55 url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; 56 - sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01"; 57 }; 58 - gpu = fetchurl { 59 - # Note that there's also a release targeting cuDNN 8.2, but unfortunately 60 - # we don't yet have that packaged at the time of writing (02/03/2022). 61 - # Check pkgs/development/libraries/science/math/cudnn/default.nix for more 62 - # details. 63 url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; 64 - sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8"; 65 - 66 - # This is what the cuDNN 8.2 download looks like for future reference: 67 - # url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; 68 - # sha256 = "000mnm2masm3sx3haddcmgw43j4gxa3m4fcm14p9nb8dnncjkgpb"; 69 }; 70 - }.${device}; 71 72 # Prebuilt wheels are dynamically linked against things that nix can't find. 73 # Run `autoPatchelfHook` to automagically fix them.
··· 24 , flatbuffers 25 , isPy39 26 , lib 27 + , python 28 , scipy 29 , stdenv 30 # Options: 31 , cudaSupport ? config.cudaSupport or false 32 }: 33 34 + # There are no jaxlib wheels targeting cudnn <8.0.5, and although there are 35 + # wheels for cudatoolkit <11.1, we don't support them. 36 assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1"; 37 assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; 38 39 let 40 version = "0.3.0"; 41 42 + pythonVersion = python.pythonVersion; 43 44 + # Find new releases at https://storage.googleapis.com/jax-releases. When 45 + # upgrading, you can get these hashes from prefetch.sh. 46 + cpuSrcs = { 47 + "3.9" = fetchurl { 48 url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; 49 + hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q="; 50 }; 51 + "3.10" = fetchurl { 52 + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl"; 53 + hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ="; 54 + }; 55 + }; 56 + 57 + gpuSrcs = { 58 + "3.9-805" = fetchurl { 59 url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; 60 + hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw="; 61 + }; 62 + "3.9-82" = fetchurl { 63 + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; 64 + hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4="; 65 + }; 66 + "3.10-805" = fetchurl { 67 + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl"; 68 + hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U="; 69 + }; 70 + "3.10-82" = fetchurl { 71 + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl"; 72 + hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go="; 73 }; 74 + }; 75 + in 76 + buildPythonPackage rec { 77 + pname = "jaxlib"; 78 + inherit version; 79 + format = "wheel"; 80 + 81 + # At the time of writing (2022-03-03), there are releases for <=3.10. 82 + # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs 83 + # python3 version, and 3.10. 84 + disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10"); 85 + 86 + src = 87 + if !cudaSupport then cpuSrcs."${pythonVersion}" else 88 + let 89 + # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and 90 + # 8.2. Try to use 8.2 whenever possible. 91 + cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805"; 92 + in 93 + gpuSrcs."${pythonVersion}-${cudnnVersion}"; 94 95 # Prebuilt wheels are dynamically linked against things that nix can't find. 96 # Run `autoPatchelfHook` to automagically fix them.
+7
pkgs/development/python-modules/jaxlib/prefetch.sh
···
··· 1 + version="$1" 2 + nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" 3 + nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" 4 + nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" 5 + nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" 6 + nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" 7 + nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"