Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work. Those wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.

# See `python3Packages.jax.passthru` for CUDA tests.

{
  absl-py,
  autoAddDriverRunpath,
  autoPatchelfHook,
  buildPythonPackage,
  config,
  fetchPypi,
  fetchurl,
  flatbuffers,
  jaxlib-build,
  lib,
  ml-dtypes,
  python,
  scipy,
  stdenv,
  # Options:
  cudaSupport ? config.cudaSupport,
  cudaPackages,
}:

let
  inherit (cudaPackages) cudaVersion;

  version = "0.4.28";

  inherit (python) pythonVersion;

  # CUDA libraries that jaxlib's shared objects load via dlopen. They are added
  # to the rpath in preInstallCheck below because autoPatchelfHook cannot see
  # dlopen'ed dependencies.
  cudaLibPath = lib.makeLibraryPath (
    with cudaPackages;
    [
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib cudnn) # libcudnn.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
    ]
  );

  # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels
  # to their GCS bucket. Instead the official instructions recommend installing
  # CPU-only versions via PyPI.
  #
  # Keys are "<pythonVersion>-<system>", matching the lookup in `src` below.
  cpuSrcs =
    let
      # Fetch the jaxlib wheel for one (platform, CPython tag) pair from PyPI.
      getSrcFromPypi =
        {
          platform,
          dist,
          hash,
        }:
        fetchPypi {
          inherit
            version
            platform
            dist
            hash
            ;
          pname = "jaxlib";
          format = "wheel";
          # CPython wheels carry the same tag (e.g. `cp311`) in both the python
          # and ABI fields of the wheel filename, so both are set to `dist`.
          python = dist;
          abi = dist;
        };
    in
    {
      "3.9-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp39";
        hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
      };
      "3.9-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp39";
        hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
      };
      "3.9-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp39";
        hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
      };

      "3.10-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp310";
        hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
      };
      "3.10-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp310";
        hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
      };
      "3.10-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp310";
        hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
      };

      "3.11-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp311";
        hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
      };
      "3.11-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp311";
        hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
      };
      "3.11-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp311";
        hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
      };

      "3.12-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp312";
        hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
      };
      "3.12-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp312";
        hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
      };
      "3.12-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp312";
        hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
      };
    };

  # Note that the prebuilt jaxlib binaries require a specific version of CUDA to
  # work: the cuda12 binaries only work with CUDA 12.2 (and upstream's cuda11
  # binaries only with CUDA 11.8). This is why we look up a binary that matches
  # the provided cudaVersion. Only CUDA 12.2 wheels are currently listed in
  # gpuSrcs below.
  gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}";

  # Find new releases at https://storage.googleapis.com/jax-releases
  # When upgrading, you can get these hashes from prefetch.sh. See
  # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
  gpuSrcs = {
    "cuda12.2-3.9" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
      hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
    };
    "cuda12.2-3.10" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
      hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
    };
    "cuda12.2-3.11" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
      hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
    };
    "cuda12.2-3.12" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
      hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
    };
  };
in
buildPythonPackage {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # Restrict to the Python versions that have wheels in cpuSrcs/gpuSrcs above.
  disabled =
    !(lib.elem pythonVersion [
      "3.9"
      "3.10"
      "3.11"
      "3.12"
    ]);

  # Throw an explicit error for unsupported combinations rather than relying on
  # meta.platforms alone. See
  # https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
  src =
    if !cudaSupport then
      (cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
        or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
      )
    else
      (gpuSrcs."${gpuSrcVersionString}"
        or (throw "jaxlib-bin has no prebuilt GPU wheel for ${gpuSrcVersionString}")
      );

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
  nativeBuildInputs =
    lib.optionals stdenv.hostPlatform.isLinux [ autoPatchelfHook ]
    ++ lib.optionals cudaSupport [ autoAddDriverRunpath ];
  # Dynamic link dependencies
  buildInputs = [ stdenv.cc.cc.lib ];

  # jaxlib contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts, both of which happen during the fixupPhase (autoPatchelfHook as a
  # postFixup hook). The installCheckPhase runs after the fixupPhase, hence
  # preInstallCheck.
  # NOTE(review): this hook only runs when the installCheck phase runs
  # (doInstallCheck); confirm it stays enabled for CUDA builds.
  preInstallCheck = lib.optionalString cudaSupport ''
    shopt -s globstar

    for file in $out/**/*.so; do
      patchelf --add-rpath "${cudaLibPath}" "$file"
    done
  '';

  propagatedBuildInputs = [
    absl-py
    flatbuffers
    ml-dtypes
    scipy
  ];

  # jaxlib looks for ptxas at runtime, eg when running `jax.random.PRNGKey(0)`.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
  '';

  inherit (jaxlib-build) pythonImportsCheck;

  meta = {
    description = "Prebuilt jaxlib backend from PyPi";
    homepage = "https://github.com/google/jax";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.samuela ];
    platforms = [
      "aarch64-darwin"
      "x86_64-linux"
      "x86_64-darwin"
    ];
    broken =
      !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
      || !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2")
      || !(cudaSupport -> stdenv.hostPlatform.isLinux)
      # Only CUDA/Python combinations with a wheel in gpuSrcs can be built.
      || !(cudaSupport -> (gpuSrcs ? "${gpuSrcVersionString}"))
      # Fails at pythonImportsCheckPhase:
      # ...-python-imports-check-hook.sh/nix-support/setup-hook: line 10: 28017 Illegal instruction: 4
      # /nix/store/5qpssbvkzfh73xih07xgmpkj5r565975-python3-3.11.9/bin/python3.11 -c
      # 'import os; import importlib; list(map(lambda mod: importlib.import_module(mod), os.environ["pythonImportsCheck"].split()))'
      || (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64);
  };
}