Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work. Those wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.

# For future reference, the easiest way to test the GPU backend is to run
# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }"
# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
# python -c "from jax import random; random.PRNGKey(0)"
# python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
# There's no convenient way to test the GPU backend in the derivation since the
# nix build environment blocks access to the GPU. See also:
# * https://github.com/google/jax/issues/971#issuecomment-508216439
# * https://github.com/google/jax/issues/5723#issuecomment-913038780

{ absl-py
, addOpenGLRunpath
, autoPatchelfHook
, buildPythonPackage
, config
, fetchPypi
, fetchurl
, flatbuffers
, jaxlib-build
, lib
, ml-dtypes
, python
, scipy
, stdenv
  # Options:
, cudaSupport ? config.cudaSupport
, cudaPackages ? {}
}:

let
  inherit (cudaPackages) cudatoolkit cudnn;
in

assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;

let
  version = "0.4.16";

  inherit (python) pythonVersion;

  # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
  # official instructions recommend installing CPU-only versions via PyPI.
  cpuSrcs =
    let
      getSrcFromPypi = { platform, hash }: fetchPypi {
        inherit version platform hash;
        pname = "jaxlib";
        format = "wheel";
        # See the `disabled` attr comment below.
        dist = "cp310";
        python = "cp310";
        abi = "cp310";
      };
    in
    {
      "x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        hash = "sha256-4XyaDnKEMhAbfPEvN3RCDEjXTWbOL6tWrTlyYeiboVs=";
      };
      "aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        hash = "sha256-IG2pCui/Yj+LDMbQwBVlu7yl2llqnaxMzz/MtBvBr6U=";
      };
      "x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        hash = "sha256-x5DqsmHqEb7Dl7dnxT5N0l30GKt5OPZpq3HGX9MFKmo=";
      };
    };


  # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
  # When upgrading, you can get these hashes from prefetch.sh. See
  # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
  gpuSrc = fetchurl {
    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
    hash = "sha256-eLOprP2kv6roodwRKZXVZFQCD1wC26TSTEDJBjMu/Uo=";
  };

in
buildPythonPackage {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # Every wheel fetched above (the PyPI CPU wheels and the CUDA wheel) is
  # tagged cp310, so only a Python 3.10 interpreter can use them. When bumping
  # the Python version here, update the dist/python/abi tags and the gpuSrc URL
  # as well.
  disabled = pythonVersion != "3.10";

  # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
  src =
    if !cudaSupport then
      (
        cpuSrcs."${stdenv.hostPlatform.system}"
          or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
      ) else gpuSrc;

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
  nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
    ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
  # Dynamic link dependencies
  buildInputs = [ stdenv.cc.cc.lib ];

  # jaxlib contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts: autoPatchelfHook runs in postFixup and auto-stripping runs during
  # fixupPhase, while preInstallCheck runs in installCheckPhase, which comes
  # after fixupPhase. Dependencies:
  # * libcudart.so.11.0 -> cudatoolkit_11.lib
  # * libcublas.so.11 -> cudatoolkit_11
  # * libcuda.so.1 -> opengl driver in /run/opengl-driver/lib
  # `optionalString` (not `optional`) because shell hooks must be strings, not
  # lists; a list only works by accident when attrs are space-joined.
  preInstallCheck = lib.optionalString cudaSupport ''
    shopt -s globstar

    addOpenGLRunpath $out/**/*.so

    for file in $out/**/*.so; do
      rpath=$(patchelf --print-rpath "$file")
      # For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
      # <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
      patchelf --set-rpath "$rpath:${cudatoolkit}/lib:${lib.makeLibraryPath [ cudatoolkit.lib cudnn ]}" "$file"
    done
  '';

  propagatedBuildInputs = [
    absl-py
    flatbuffers
    ml-dtypes
    scipy
  ];

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
  '';

  inherit (jaxlib-build) pythonImportsCheck;

  meta = with lib; {
    description = "XLA library for JAX";
    homepage = "https://github.com/google/jax";
    sourceProvenance = with sourceTypes; [ binaryNativeCode ];
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
    platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
  };
}