nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at 22.05 146 lines 5.8 kB view raw
# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work. Those wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.

# For future reference, the easiest way to test the GPU backend is to run
#   NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
#   export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
#   python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
#   python -c "from jax import random; random.PRNGKey(0)"
#   python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
# There's no convenient way to test the GPU backend in the derivation since the
# nix build environment blocks access to the GPU. See also:
#   * https://github.com/google/jax/issues/971#issuecomment-508216439
#   * https://github.com/google/jax/issues/5723#issuecomment-913038780

{ absl-py
, addOpenGLRunpath
, autoPatchelfHook
, buildPythonPackage
, config
  # NOTE: `cudnn` and `isPy39` are not used below (`cudnn` is shadowed by the
  # one taken from `cudaPackages`). They are kept so that existing callers
  # passing them via `.override` keep evaluating.
, cudnn
, fetchurl
, flatbuffers
, isPy39
, lib
, python
, scipy
, stdenv
  # Options:
, cudaSupport ? config.cudaSupport or false
, cudaPackages ? {}
}:

let
  # Shadows the `cudnn` argument above; presumably so that the CUDA toolkit and
  # cuDNN both come from the same `cudaPackages` set.
  inherit (cudaPackages) cudatoolkit cudnn;
in

# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are
# wheels for cudatoolkit <11.1, we don't support them.
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5";

let
  version = "0.3.0";

  pythonVersion = python.pythonVersion;

  # Find new releases at https://storage.googleapis.com/jax-releases. When
  # upgrading, you can get these hashes from prefetch.sh.
  cpuSrcs = {
    "3.9" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
      hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q=";
    };
    "3.10" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl";
      hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ=";
    };
  };

  # Keyed by "<python version>-<cudnn tag>"; see `cudnnVersion` below for how
  # the cudnn tag is chosen.
  gpuSrcs = {
    "3.9-805" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
      hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw=";
    };
    "3.9-82" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
      hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4=";
    };
    "3.10-805" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl";
      hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U=";
    };
    "3.10-82" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl";
      hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go=";
    };
  };
in
buildPythonPackage rec {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # At the time of writing (2022-03-03), there are releases for <=3.10.
  # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs
  # python3 version, and 3.10. The supported interpreters are exactly the keys
  # of `cpuSrcs`, so adding a CPU wheel there enables a new interpreter (a
  # matching `gpuSrcs` entry is additionally needed for `cudaSupport`).
  disabled = !(builtins.hasAttr pythonVersion cpuSrcs);

  src =
    if !cudaSupport then cpuSrcs."${pythonVersion}" else
    let
      # jaxlib wheels are currently provided for cudnn versions at least 8.0.5
      # and 8.2. Try to use 8.2 whenever possible.
      cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805";
    in
    gpuSrcs."${pythonVersion}-${cudnnVersion}";

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
  nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
  # Dynamic link dependencies
  buildInputs = [ stdenv.cc.cc ];

  # jaxlib contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts, both of which happen during the fixup phase (autoPatchelfHook
  # runs in postFixup). The install-check phase runs after fixup, which is why
  # this lives in preInstallCheck. Dependencies:
  # * libcudart.so.11.0 -> cudatoolkit.lib
  # * libcublas.so.11 -> cudatoolkit
  # * libcuda.so.1 -> opengl driver in /run/opengl-driver/lib
  # NOTE(review): since this runs in the install-check phase, it is skipped when
  # install checks are disabled (e.g. `doCheck = false`); confirm whether that
  # can happen for this package.
  preInstallCheck = lib.optionalString cudaSupport ''
    shopt -s globstar

    addOpenGLRunpath $out/**/*.so

    for file in $out/**/*.so; do
      rpath=$(patchelf --print-rpath "$file")
      # For some reason `makeLibraryPath` on `cudatoolkit` maps to
      # <cudatoolkit.lib>/lib which is different from <cudatoolkit>/lib.
      patchelf --set-rpath "$rpath:${cudatoolkit}/lib:${lib.makeLibraryPath [ cudatoolkit.lib cudnn ]}" "$file"
    done
  '';

  propagatedBuildInputs = [ absl-py flatbuffers scipy ];

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
  '';

  pythonImportsCheck = [ "jaxlib" ];

  meta = with lib; {
    description = "XLA library for JAX";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
    platforms = [ "x86_64-linux" ];
  };
}