1{
2 lib,
3 pkgs,
4 stdenv,
5
6 # Build-time dependencies:
7 addDriverRunpath,
8 autoAddDriverRunpath,
9 bazel_6,
10 binutils,
11 buildBazelPackage,
12 buildPythonPackage,
13 cctools,
14 curl,
15 cython,
16 fetchFromGitHub,
17 git,
18 jsoncpp,
19 nsync,
20 openssl,
21 pybind11,
22 setuptools,
23 symlinkJoin,
24 wheel,
25 build,
26 which,
27
28 # Python dependencies:
29 absl-py,
30 flatbuffers,
31 ml-dtypes,
32 numpy,
33 scipy,
34 six,
35
36 # Runtime dependencies:
37 double-conversion,
38 giflib,
39 libjpeg_turbo,
40 python,
41 snappy,
42 zlib,
43
44 config,
45 # CUDA flags:
46 cudaSupport ? config.cudaSupport,
47 cudaPackages,
48
49 # MKL:
50 mklSupport ? true,
51}@inputs:
52
53let
54 inherit (cudaPackages)
55 cudaMajorMinorVersion
56 flags
57 nccl
58 ;
59
60 pname = "jaxlib";
61 version = "0.4.28";
62
63 # It's necessary to consistently use backendStdenv when building with CUDA
64 # support, otherwise we get libstdc++ errors downstream
65 stdenv = throw "Use effectiveStdenv instead";
66 effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
67
68 meta = with lib; {
69 description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
70 homepage = "https://github.com/google/jax";
71 license = licenses.asl20;
72 maintainers = with maintainers; [ ndl ];
73
74 # Make this platforms.unix once Darwin is supported.
75 # The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
76 # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
77 # however even with that fix applied, it doesn't work for everyone:
78 # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
79 platforms = platforms.linux;
80 };
81
82 # Bazel wants a merged cudnn at configuration time
83 cudnnMerged = symlinkJoin {
84 name = "cudnn-merged";
85 paths = with cudaPackages; [
86 (lib.getDev cudnn)
87 (lib.getLib cudnn)
88 ];
89 };
90
91 # These are necessary at build time and run time.
92 cuda_libs_joined = symlinkJoin {
93 name = "cuda-joined";
94 paths = with cudaPackages; [
95 (lib.getLib cuda_cudart) # libcudart.so
96 (lib.getLib cuda_cupti) # libcupti.so
97 (lib.getLib libcublas) # libcublas.so
98 (lib.getLib libcufft) # libcufft.so
99 (lib.getLib libcurand) # libcurand.so
100 (lib.getLib libcusolver) # libcusolver.so
101 (lib.getLib libcusparse) # libcusparse.so
102 ];
103 };
104 # These are only necessary at build time.
105 cuda_build_deps_joined = symlinkJoin {
106 name = "cuda-build-deps-joined";
107 paths = with cudaPackages; [
108 cuda_libs_joined
109
110 # Binaries
111 (lib.getBin cuda_nvcc) # nvcc
112
113 # Archives
114 (lib.getOutput "static" cuda_cudart) # libcudart_static.a
115
116 # Headers
117 (lib.getDev cuda_cccl) # block_load.cuh
118 (lib.getDev cuda_cudart) # cuda.h
119 (lib.getDev cuda_cupti) # cupti.h
120 (lib.getDev cuda_nvcc) # See https://github.com/google/jax/issues/19811
121 (lib.getDev cuda_nvml_dev) # nvml.h
122 (lib.getDev cuda_nvtx) # nvToolsExt.h
123 (lib.getDev libcublas) # cublas_api.h
124 (lib.getDev libcufft) # cufft.h
125 (lib.getDev libcurand) # curand.h
126 (lib.getDev libcusolver) # cusolver_common.h
127 (lib.getDev libcusparse) # cusparse.h
128 ];
129 };
130
131 backend_cc_joined = symlinkJoin {
132 name = "cuda-cc-joined";
133 paths = [
134 effectiveStdenv.cc
135 binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
136 ];
137 };
138
139 # Copy-paste from TF derivation.
140 # Most of these are not really used in jaxlib compilation but it's simpler to keep it
141 # 'as is' so that it's more compatible with TF derivation.
142 tf_system_libs = [
143 "absl_py"
144 "astor_archive"
145 "astunparse_archive"
146 # Not packaged in nixpkgs
147 # "com_github_googleapis_googleapis"
148 # "com_github_googlecloudplatform_google_cloud_cpp"
149 # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
150 # "com_github_grpc_grpc"
151 # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
152 # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
153 # "com_google_protobuf"
154 # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
155 # "com_googlesource_code_re2"
156 "curl"
157 "cython"
158 "dill_archive"
159 "double_conversion"
160 "flatbuffers"
161 "functools32_archive"
162 "gast_archive"
163 "gif"
164 "hwloc"
165 "icu"
166 "jsoncpp_git"
167 "libjpeg_turbo"
168 "lmdb"
169 "nasm"
170 "opt_einsum_archive"
171 "org_sqlite"
172 "pasta"
173 "png"
174 # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
175 # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
176 # "pybind11"
177 "six_archive"
178 "snappy"
179 "tblib_archive"
180 "termcolor_archive"
181 "typing_extensions_archive"
182 "wrapt"
183 "zlib"
184 ];
185
186 arch =
187 # KeyError: ('Linux', 'arm64')
188 if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
189 "aarch64"
190 else
191 effectiveStdenv.hostPlatform.linuxArch;
192
193 xla = effectiveStdenv.mkDerivation {
194 pname = "xla-src";
195 version = "unstable";
196
197 src = fetchFromGitHub {
198 owner = "openxla";
199 repo = "xla";
200 # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
201 rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
202 hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
203 };
204
205 dontBuild = true;
206
207 # This is necessary for patchShebangs to know the right path to use.
208 nativeBuildInputs = [ python ];
209
210 # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
211 postPatch = ''
212 patchShebangs .
213 '';
214
215 installPhase = ''
216 cp -r . $out
217 '';
218 };
219
220 bazel-build = buildBazelPackage rec {
221 name = "bazel-build-${pname}-${version}";
222
223 # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
224 bazel = bazel_6;
225
226 src = fetchFromGitHub {
227 owner = "google";
228 repo = "jax";
229 # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
230 rev = "refs/tags/${pname}-v${version}";
231 hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
232 };
233
234 nativeBuildInputs = [
235 cython
236 pkgs.flatbuffers
237 git
238 setuptools
239 wheel
240 build
241 which
242 ] ++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [ cctools ];
243
244 buildInputs = [
245 curl
246 double-conversion
247 giflib
248 jsoncpp
249 libjpeg_turbo
250 numpy
251 openssl
252 pkgs.flatbuffers
253 pkgs.protobuf
254 pybind11
255 scipy
256 six
257 snappy
258 zlib
259 ] ++ lib.optionals (!effectiveStdenv.hostPlatform.isDarwin) [ nsync ];
260
261 # We don't want to be quite so picky regarding bazel version
262 postPatch = ''
263 rm -f .bazelversion
264 '';
265
266 bazelRunTarget = "//jaxlib/tools:build_wheel";
267 runTargetFlags = [
268 "--output_path=$out"
269 "--cpu=${arch}"
270 # This has no impact whatsoever...
271 "--jaxlib_git_hash='12345678'"
272 ];
273
274 removeRulesCC = false;
275
276 GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
277 GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";
278
279 # The version is automatically set to ".dev" if this variable is not set.
280 # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
281 JAXLIB_RELEASE = "1";
282
283 preConfigure =
284 # Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
285 ''
286 mkdir dummy-ldconfig
287 echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
288 chmod +x dummy-ldconfig/ldconfig
289 export PATH="$PWD/dummy-ldconfig:$PATH"
290 ''
291 +
292
293 # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
294 # for more info. We assume
295 # * `cpu = None`
296 # * `enable_nccl = True`
297 # * `target_cpu_features = "release"`
298 # * `rocm_amdgpu_targets = None`
299 # * `enable_rocm = False`
300 # * `build_gpu_plugin = False`
301 # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
302 #
303 # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
304 # instead of duplicating the logic here. Perhaps we can leverage the
305 # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
306 ''
307 cat <<CFG > ./.jax_configure.bazelrc
308 build --strategy=Genrule=standalone
309 build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
310 build --action_env=PYENV_ROOT
311 build --python_path="${python}/bin/python"
312 build --distinct_host_configuration=false
313 build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
314 ''
315 + lib.optionalString cudaSupport ''
316 build --config=cuda
317 build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
318 build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}"
319 build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}"
320 build --action_env TF_CUDA_VERSION="${cudaMajorMinorVersion}"
321 build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}"
322 build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," flags.realArches}"
323 ''
324 +
325 # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
326 # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
327 # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
328 # for upstream's version.
329 lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
330 ''
331 build --config=avx_posix
332 ''
333 + lib.optionalString mklSupport ''
334 build --config=mkl_open_source_only
335 ''
336 + ''
337 CFG
338 '';
339
340 # Make sure Bazel knows about our configuration flags during fetching so that the
341 # relevant dependencies can be downloaded.
342 bazelFlags =
343 [
344 "-c opt"
345 # See https://bazel.build/external/advanced#overriding-repositories for
346 # information on --override_repository flag.
347 "--override_repository=xla=${xla}"
348 ]
349 ++ lib.optionals effectiveStdenv.cc.isClang [
350 # bazel depends on the compiler frontend automatically selecting these flags based on file
351 # extension but our clang doesn't.
352 # https://github.com/NixOS/nixpkgs/issues/150655
353 "--cxxopt=-x"
354 "--cxxopt=c++"
355 "--host_cxxopt=-x"
356 "--host_cxxopt=c++"
357 ];
358
359 # We intentionally overfetch so we can share the fetch derivation across all the different configurations
360 fetchAttrs = {
361 TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
362 # we have to force @mkl_dnn_v1 since it's not needed on darwin
363 bazelTargets = [
364 bazelRunTarget
365 "@mkl_dnn_v1//:mkl_dnn"
366 ];
367 bazelFlags =
368 bazelFlags
369 ++ [
370 "--config=avx_posix"
371 "--config=mkl_open_source_only"
372 ]
373 ++ lib.optionals cudaSupport [
374 # ideally we'd add this unconditionally too, but it doesn't work on darwin
375 # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
376 # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
377 # have access to darwin machines
378 "--config=cuda"
379 ];
380
381 sha256 =
382 (
383 if cudaSupport then
384 { x86_64-linux = "sha256-Uf0VMRE0jgaWEYiuphWkWloZ5jMeqaWBl3lSvk2y1HI="; }
385 else
386 {
387 x86_64-linux = "sha256-NzJJg6NlrPGMiR8Fn8u4+fu0m+AulfmN5Xqk63Um6sw=";
388 aarch64-linux = "sha256-Ro3qzrUxSR+3TH6ROoJTq+dLSufrDN/9oEo2MRkx7wM=";
389 }
390 ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
391
392 # Non-reproducible fetch https://github.com/NixOS/nixpkgs/issues/321920#issuecomment-2184940546
393 preInstall = ''
394 cat << \EOF > "$bazelOut/external/go_sdk/versions.json"
395 []
396 EOF
397 '';
398 };
399
400 buildAttrs = {
401 outputs = [ "out" ];
402
403 TF_SYSTEM_LIBS = lib.concatStringsSep "," (
404 tf_system_libs
405 ++ lib.optionals (!effectiveStdenv.hostPlatform.isDarwin) [
406 "nsync" # fails to build on darwin
407 ]
408 );
409
410 # Note: we cannot do most of this patching at `patch` phase as the deps
411 # are not available yet.
412 preBuild = lib.optionalString effectiveStdenv.hostPlatform.isDarwin ''
413 substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
414 --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
415 substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
416 --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
417 '';
418 };
419
420 inherit meta;
421 };
422 platformTag =
423 if effectiveStdenv.hostPlatform.isLinux then
424 "manylinux2014_${arch}"
425 else if effectiveStdenv.system == "x86_64-darwin" then
426 "macosx_10_9_${arch}"
427 else if effectiveStdenv.system == "aarch64-darwin" then
428 "macosx_11_0_${arch}"
429 else
430 throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
431in
432buildPythonPackage {
433 inherit pname version;
434 format = "wheel";
435
436 src =
437 let
438 cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
439 in
440 "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
441
442 # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
443 # for more info.
444 postInstall = lib.optionalString cudaSupport ''
445 mkdir -p $out/bin
446 ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/bin/ptxas
447
448 find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
449 patchelf --add-rpath "${
450 lib.makeLibraryPath [
451 cuda_libs_joined
452 (lib.getLib cudaPackages.cudnn)
453 nccl
454 ]
455 }" "$lib"
456 done
457 '';
458
459 nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];
460
461 dependencies = [
462 absl-py
463 curl
464 double-conversion
465 flatbuffers
466 giflib
467 jsoncpp
468 libjpeg_turbo
469 ml-dtypes
470 numpy
471 scipy
472 six
473 snappy
474 ];
475
476 pythonImportsCheck = [
477 "jaxlib"
478 # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
479 "jaxlib.cpu_feature_guard"
480 "jaxlib.xla_client"
481 ];
482
483 # Without it there are complaints about libcudart.so.11.0 not being found
484 # because RPATH path entries added above are stripped.
485 dontPatchELF = cudaSupport;
486
487 passthru = {
488 # Note "bazel.*.tar.gz" can be accessed as `jaxlib.bazel-build.deps`
489 inherit bazel-build;
490 };
491
492 inherit meta;
493}