1# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
2# backend will require some additional work. Those wheels are located here:
3# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
4
5# See `python3Packages.jax.passthru` for CUDA tests.
6
{
  lib,
  stdenv,
  config,
  python,
  buildPythonPackage,
  fetchPypi,
  fetchurl,

  # Hooks used to make the prebuilt binaries runnable on NixOS.
  autoAddDriverRunpath,
  autoPatchelfHook,

  # Runtime Python dependencies.
  absl-py,
  flatbuffers,
  ml-dtypes,
  scipy,

  # Source-built jaxlib; only its `pythonImportsCheck` is reused here.
  jaxlib-build,

  # CUDA package set (used only when `cudaSupport` is enabled).
  cudaPackages,

  # Options:
  cudaSupport ? config.cudaSupport,
}:
26
let
  version = "0.4.28";

  inherit (python) pythonVersion;
  inherit (cudaPackages) cudaVersion;

  # Runtime search path for CUDA libraries that jaxlib opens via dlopen.
  # autoPatchelfHook cannot see these, so the `lib` output of each package is
  # appended to the rpath of the installed shared objects by hand later on.
  cudaLibPath = lib.makeLibraryPath (
    map lib.getLib (
      with cudaPackages;
      [
        cuda_cudart # libcudart.so
        cuda_cupti # libcupti.so
        cudnn # libcudnn.so
        libcufft # libcufft.so
        libcusolver # libcusolver.so
        libcusparse # libcusparse.so
      ]
    )
  );
45
46 # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
47 # official instructions recommend installing CPU-only versions via PyPI.
48 cpuSrcs =
49 let
50 getSrcFromPypi =
51 {
52 platform,
53 dist,
54 hash,
55 }:
56 fetchPypi {
57 inherit
58 version
59 platform
60 dist
61 hash
62 ;
63 pname = "jaxlib";
64 format = "wheel";
65 # See the `disabled` attr comment below.
66 python = dist;
67 abi = dist;
68 };
69 in
70 {
71 "3.9-x86_64-linux" = getSrcFromPypi {
72 platform = "manylinux2014_x86_64";
73 dist = "cp39";
74 hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
75 };
76 "3.9-aarch64-darwin" = getSrcFromPypi {
77 platform = "macosx_11_0_arm64";
78 dist = "cp39";
79 hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
80 };
81 "3.9-x86_64-darwin" = getSrcFromPypi {
82 platform = "macosx_10_14_x86_64";
83 dist = "cp39";
84 hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
85 };
86
87 "3.10-x86_64-linux" = getSrcFromPypi {
88 platform = "manylinux2014_x86_64";
89 dist = "cp310";
90 hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
91 };
92 "3.10-aarch64-darwin" = getSrcFromPypi {
93 platform = "macosx_11_0_arm64";
94 dist = "cp310";
95 hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
96 };
97 "3.10-x86_64-darwin" = getSrcFromPypi {
98 platform = "macosx_10_14_x86_64";
99 dist = "cp310";
100 hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
101 };
102
103 "3.11-x86_64-linux" = getSrcFromPypi {
104 platform = "manylinux2014_x86_64";
105 dist = "cp311";
106 hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
107 };
108 "3.11-aarch64-darwin" = getSrcFromPypi {
109 platform = "macosx_11_0_arm64";
110 dist = "cp311";
111 hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
112 };
113 "3.11-x86_64-darwin" = getSrcFromPypi {
114 platform = "macosx_10_14_x86_64";
115 dist = "cp311";
116 hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
117 };
118
119 "3.12-x86_64-linux" = getSrcFromPypi {
120 platform = "manylinux2014_x86_64";
121 dist = "cp312";
122 hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
123 };
124 "3.12-aarch64-darwin" = getSrcFromPypi {
125 platform = "macosx_11_0_arm64";
126 dist = "cp312";
127 hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
128 };
129 "3.12-x86_64-darwin" = getSrcFromPypi {
130 platform = "macosx_10_14_x86_64";
131 dist = "cp312";
132 hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
133 };
134 };
135
136 # Note that the prebuilt jaxlib binary requires specific version of CUDA to
137 # work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11
138 # jaxlib binaries only works with CUDA 11.8. This is why we need to find a
139 # binary that matches the provided cudaVersion.
140 gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}";
141
142 # Find new releases at https://storage.googleapis.com/jax-releases
143 # When upgrading, you can get these hashes from prefetch.sh. See
144 # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
145 gpuSrcs = {
146 "cuda12.2-3.9" = fetchurl {
147 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
148 hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
149 };
150 "cuda12.2-3.10" = fetchurl {
151 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
152 hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
153 };
154 "cuda12.2-3.11" = fetchurl {
155 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
156 hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
157 };
158 "cuda12.2-3.12" = fetchurl {
159 url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
160 hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
161 };
162 };
163in
buildPythonPackage {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # Prebuilt wheels are pinned only for these CPython versions.
  disabled =
    !(lib.elem pythonVersion [
      "3.9"
      "3.10"
      "3.11"
      "3.12"
    ]);

  # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
  # Unsupported combinations `throw` (which `builtins.tryEval` can catch) rather
  # than failing with a bare missing-attribute error (which it cannot).
  src =
    if !cudaSupport then
      (cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
        or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
      )
    else
      gpuSrcs."${gpuSrcVersionString}"
        or (throw "jaxlib-bin has no prebuilt CUDA wheel for ${gpuSrcVersionString}");

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
  nativeBuildInputs =
    lib.optionals stdenv.hostPlatform.isLinux [ autoPatchelfHook ]
    ++ lib.optionals cudaSupport [ autoAddDriverRunpath ];
  # Dynamic link dependencies
  buildInputs = [ stdenv.cc.cc.lib ];

  # jaxlib contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts: autoPatchelfHook runs in postFixup and auto-stripping runs in
  # fixupPhase, both of which precede installCheckPhase.
  # NOTE(review): this hook only runs when install checks are enabled
  # (doInstallCheck); disabling checks would also skip the rpath patching.
  preInstallCheck = lib.optionalString cudaSupport ''
    shopt -s globstar

    for file in $out/**/*.so; do
      patchelf --add-rpath "${cudaLibPath}" "$file"
    done
  '';

  propagatedBuildInputs = [
    absl-py
    flatbuffers
    ml-dtypes
    scipy
  ];

  # jaxlib looks for ptxas at runtime, eg when running `jax.random.PRNGKey(0)`.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
  '';

  inherit (jaxlib-build) pythonImportsCheck;

  meta = {
    description = "Prebuilt jaxlib backend from PyPi";
    homepage = "https://github.com/google/jax";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.samuela ];
    platforms = [
      "aarch64-darwin"
      "x86_64-linux"
      "x86_64-darwin"
    ];
    broken =
      !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
      # The CUDA wheels are the `cudnn89` builds, so require at least cuDNN 8.9.
      || !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.9")
      # CUDA wheels are only published for x86_64 Linux.
      || !(cudaSupport -> stdenv.hostPlatform.isLinux)
      # A wheel must exist for this exact CUDA/Python combination.
      || !(cudaSupport -> (gpuSrcs ? "${gpuSrcVersionString}"))
      # Fails at pythonImportsCheckPhase:
      # ...-python-imports-check-hook.sh/nix-support/setup-hook: line 10: 28017 Illegal instruction: 4
      # /nix/store/5qpssbvkzfh73xih07xgmpkj5r565975-python3-3.11.9/bin/python3.11 -c
      # 'import os; import importlib; list(map(lambda mod: importlib.import_module(mod), os.environ["pythonImportsCheck"].split()))'
      || (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64);
  };
}