1# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
2# backend will require some additional work. Those wheels are located here:
3# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
4
5# For future reference, the easiest way to test the GPU backend is to run
6# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }"
7# export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
8# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
9# python -c "from jax import random; random.PRNGKey(0)"
10# python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
11# There's no convenient way to test the GPU backend in the derivation since the
12# nix build environment blocks access to the GPU. See also:
13# * https://github.com/google/jax/issues/971#issuecomment-508216439
14# * https://github.com/google/jax/issues/5723#issuecomment-913038780
15
16{ absl-py
17, addOpenGLRunpath
18, autoPatchelfHook
19, buildPythonPackage
20, config
21, fetchPypi
22, fetchurl
23, flatbuffers
24, jaxlib-build
25, lib
26, ml-dtypes
27, python
28, scipy
29, stdenv
30 # Options:
31, cudaSupport ? config.cudaSupport
32, cudaPackages ? {}
33}:
34
35let
36 inherit (cudaPackages) cudatoolkit cudnn;
37in
38
39assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;
40
41let
42 version = "0.4.16";
43
44 inherit (python) pythonVersion;
45
46 # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
47 # official instructions recommend installing CPU-only versions via PyPI.
48 cpuSrcs =
49 let
50 getSrcFromPypi = { platform, hash }: fetchPypi {
51 inherit version platform hash;
52 pname = "jaxlib";
53 format = "wheel";
54 # See the `disabled` attr comment below.
55 dist = "cp310";
56 python = "cp310";
57 abi = "cp310";
58 };
59 in
60 {
61 "x86_64-linux" = getSrcFromPypi {
62 platform = "manylinux2014_x86_64";
63 hash = "sha256-4XyaDnKEMhAbfPEvN3RCDEjXTWbOL6tWrTlyYeiboVs=";
64 };
65 "aarch64-darwin" = getSrcFromPypi {
66 platform = "macosx_11_0_arm64";
67 hash = "sha256-IG2pCui/Yj+LDMbQwBVlu7yl2llqnaxMzz/MtBvBr6U=";
68 };
69 "x86_64-darwin" = getSrcFromPypi {
70 platform = "macosx_10_14_x86_64";
71 hash = "sha256-x5DqsmHqEb7Dl7dnxT5N0l30GKt5OPZpq3HGX9MFKmo=";
72 };
73 };
74
75
76 # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
77 # When upgrading, you can get these hashes from prefetch.sh. See
78 # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
79 gpuSrc = fetchurl {
80 url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
81 hash = "sha256-eLOprP2kv6roodwRKZXVZFQCD1wC26TSTEDJBjMu/Uo=";
82 };
83
84in
85buildPythonPackage {
86 pname = "jaxlib";
87 inherit version;
88 format = "wheel";
89
90 disabled = !(pythonVersion == "3.10");
91
92 # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
93 src =
94 if !cudaSupport then
95 (
96 cpuSrcs."${stdenv.hostPlatform.system}"
97 or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
98 ) else gpuSrc;
99
100 # Prebuilt wheels are dynamically linked against things that nix can't find.
101 # Run `autoPatchelfHook` to automagically fix them.
102 nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
103 ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
104 # Dynamic link dependencies
105 buildInputs = [ stdenv.cc.cc.lib ];
106
107 # jaxlib contains shared libraries that open other shared libraries via dlopen
108 # and these implicit dependencies are not recognized by ldd or
109 # autoPatchelfHook. That means we need to sneak them into rpath. This step
110 # must be done after autoPatchelfHook and the automatic stripping of
111 # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
112 # patchPhase. Dependencies:
113 # * libcudart.so.11.0 -> cudatoolkit_11.lib
114 # * libcublas.so.11 -> cudatoolkit_11
115 # * libcuda.so.1 -> opengl driver in /run/opengl-driver/lib
116 preInstallCheck = lib.optional cudaSupport ''
117 shopt -s globstar
118
119 addOpenGLRunpath $out/**/*.so
120
121 for file in $out/**/*.so; do
122 rpath=$(patchelf --print-rpath $file)
123 # For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
124 # <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
125 patchelf --set-rpath "$rpath:${cudatoolkit}/lib:${lib.makeLibraryPath [ cudatoolkit.lib cudnn ]}" $file
126 done
127 '';
128
129 propagatedBuildInputs = [
130 absl-py
131 flatbuffers
132 ml-dtypes
133 scipy
134 ];
135
136 # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
137 # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
138 # more info.
139 postInstall = lib.optional cudaSupport ''
140 mkdir -p $out/bin
141 ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
142 '';
143
144 inherit (jaxlib-build) pythonImportsCheck;
145
146 meta = with lib; {
147 description = "XLA library for JAX";
148 homepage = "https://github.com/google/jax";
149 sourceProvenance = with sourceTypes; [ binaryNativeCode ];
150 license = licenses.asl20;
151 maintainers = with maintainers; [ samuela ];
152 platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
153 };
154}