{
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  addDriverRunpath,
  autoAddDriverRunpath,
  bazel_6,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  IOKit,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy,
  zlib,

  config,
  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackages,

  # MKL:
  mklSupport ? true,
}@inputs:

let
  inherit (cudaPackages)
    cudaFlags
    cudaVersion
    nccl
    ;

  pname = "jaxlib";
  version = "0.4.28";

  # It's necessary to consistently use backendStdenv when building with CUDA
  # support; otherwise we get libstdc++ errors downstream.
  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
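  # Rebinding `stdenv` to a `throw` turns any accidental use of the plain stdenv below into an
  # evaluation-time error, so every consumer has to opt into `effectiveStdenv` explicitly.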

  meta = with lib; {
    description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];

    # Make this platforms.unix once Darwin is supported.
    # The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136;
    # however, even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    platforms = platforms.linux;
  };

  # Bazel wants a merged cudnn at configuration time.
  cudnnMerged = symlinkJoin {
    name = "cudnn-merged";
    paths = with cudaPackages; [
      (lib.getDev cudnn)
      (lib.getLib cudnn)
    ];
  };
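  # That is, a single store path holding both the cudnn headers (include/) and libraries (lib/),
  # so the CUDNN_INSTALL_PATH handed to Bazel below points at one directory containing both.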

  # These are necessary at build time and run time.
  cuda_libs_joined = symlinkJoin {
    name = "cuda-joined";
    paths = with cudaPackages; [
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib libcublas) # libcublas.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcurand) # libcurand.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
    ];
  };
  # These are only necessary at build time.
  cuda_build_deps_joined = symlinkJoin {
    name = "cuda-build-deps-joined";
    paths = with cudaPackages; [
      cuda_libs_joined

      # Binaries
      (lib.getBin cuda_nvcc) # nvcc

      # Archives
      (lib.getOutput "static" cuda_cudart) # libcudart_static.a

      # Headers
      (lib.getDev cuda_cccl) # block_load.cuh
      (lib.getDev cuda_cudart) # cuda.h
      (lib.getDev cuda_cupti) # cupti.h
      (lib.getDev cuda_nvcc) # See https://github.com/google/jax/issues/19811
      (lib.getDev cuda_nvml_dev) # nvml.h
      (lib.getDev cuda_nvtx) # nvToolsExt.h
      (lib.getDev libcublas) # cublas_api.h
      (lib.getDev libcufft) # cufft.h
      (lib.getDev libcurand) # curand.h
      (lib.getDev libcusolver) # cusolver_common.h
      (lib.getDev libcusparse) # cusparse.h
    ];
  };

  backend_cc_joined = symlinkJoin {
    name = "cuda-cc-joined";
    paths = [
      effectiveStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };
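  # One prefix holding both the compiler and the binutils, so the CUDA crosstool pointed at
  # GCC_HOST_COMPILER_PREFIX below finds gcc and ar/nm/objcopy/... in the same bin/.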

  # Copied from the TF derivation. Most of these are not actually used when compiling jaxlib,
  # but keeping the list as-is makes it easier to stay in sync with the TF derivation.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
    # "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];

  arch =
    # The wheel build has no platform entry for ('Linux', 'arm64') and fails with
    # KeyError: ('Linux', 'arm64'), so translate the kernel-style name to the GNU-style one.
    if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
      "aarch64"
    else
      effectiveStdenv.hostPlatform.linuxArch;
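  # e.g. x86_64-linux -> "x86_64"; aarch64-linux -> linuxArch "arm64" -> "aarch64".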

  xla = effectiveStdenv.mkDerivation {
    pname = "xla-src";
    version = "unstable";

    src = fetchFromGitHub {
      owner = "openxla";
      repo = "xla";
      # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
    };

    dontBuild = true;

    # This is necessary for patchShebangs to know the right path to use.
    nativeBuildInputs = [ python ];

    # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
    postPatch = ''
      patchShebangs .
    '';

    installPhase = ''
      cp -r . $out
    '';
  };
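  # This checkout replaces the `@xla` workspace via `--override_repository=xla=...` in bazelFlags
  # below, so Bazel never fetches XLA on its own.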

  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for both jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals effectiveStdenv.isDarwin [ cctools ];

    buildInputs =
      [
        curl
        double-conversion
        giflib
        jsoncpp
        libjpeg_turbo
        numpy
        openssl
        pkgs.flatbuffers
        pkgs.protobuf
        pybind11
        scipy
        six
        snappy
        zlib
      ]
      ++ lib.optionals effectiveStdenv.isDarwin [ IOKit ]
      ++ lib.optionals (!effectiveStdenv.isDarwin) [ nsync ];

    # We don't want to be quite so picky about the Bazel version.
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [
      "--output_path=$out"
      "--cpu=${arch}"
      # This has no impact whatsoever...
      "--jaxlib_git_hash='12345678'"
    ];

    removeRulesCC = false;

    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    preConfigure =
      # Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
      ''
        mkdir dummy-ldconfig
        echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
        chmod +x dummy-ldconfig/ldconfig
        export PATH="$PWD/dummy-ldconfig:$PATH"
      ''
      +

      # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
      # for more info. We assume
      # * `cpu = None`
      # * `enable_nccl = True`
      # * `target_cpu_features = "release"`
      # * `rocm_amdgpu_targets = None`
      # * `enable_rocm = False`
      # * `build_gpu_plugin = False`
      # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
      #
      # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
      # instead of duplicating the logic here. Perhaps we can leverage the
      # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
      ''
        cat <<CFG > ./.jax_configure.bazelrc
        build --strategy=Genrule=standalone
        build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
        build --action_env=PYENV_ROOT
        build --python_path="${python}/bin/python"
        build --distinct_host_configuration=false
        build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
      ''
      + lib.optionalString cudaSupport ''
        build --config=cuda
        build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
        build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}"
        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}"
        build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}"
        build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
      ''
      +
      # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
      # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
      # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
      # for upstream's version.
      lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
        ''
          build --config=avx_posix
        ''
      + lib.optionalString mklSupport ''
        build --config=mkl_open_source_only
      ''
      + ''
        CFG
      '';
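    # For illustration (store paths abbreviated), a non-CUDA x86_64 build with MKL ends up with:
    #   build --strategy=Genrule=standalone
    #   build --repo_env PYTHON_BIN_PATH="/nix/store/...-python3/bin/python"
    #   ...
    #   build --config=avx_posix
    #   build --config=mkl_open_source_only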

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags =
      [
        "-c opt"
        # See https://bazel.build/external/advanced#overriding-repositories for
        # information on the --override_repository flag.
        "--override_repository=xla=${xla}"
      ]
      ++ lib.optionals effectiveStdenv.cc.isClang [
        # Bazel depends on the compiler frontend automatically selecting these flags based on
        # the file extension, but our clang doesn't.
        # https://github.com/NixOS/nixpkgs/issues/150655
        "--cxxopt=-x"
        "--cxxopt=c++"
        "--host_cxxopt=-x"
        "--host_cxxopt=c++"
      ];

    # We intentionally overfetch so we can share the fetch derivation across all the different
    # configurations.
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # Fetch @mkl_dnn_v1 explicitly: it's not needed on darwin, so it wouldn't be fetched
      # there otherwise, and we want the fetch output to be identical on every platform.
      bazelTargets = [
        bazelRunTarget
        "@mkl_dnn_v1//:mkl_dnn"
      ];
      bazelFlags =
        bazelFlags
        ++ [
          "--config=avx_posix"
          "--config=mkl_open_source_only"
        ]
        ++ lib.optionals cudaSupport [
          # Ideally we'd add this unconditionally too, but it doesn't work on darwin.
          # We make this conditional on `cudaSupport` instead of the system so that the hashes
          # for both the cuda and the non-cuda deps can be computed on linux, since a lot of
          # contributors don't have access to darwin machines.
          "--config=cuda"
        ];

      sha256 =
        (
          if cudaSupport then
            { x86_64-linux = "sha256-Uf0VMRE0jgaWEYiuphWkWloZ5jMeqaWBl3lSvk2y1HI="; }
          else
            {
              x86_64-linux = "sha256-NzJJg6NlrPGMiR8Fn8u4+fu0m+AulfmN5Xqk63Um6sw=";
              aarch64-linux = "sha256-Ro3qzrUxSR+3TH6ROoJTq+dLSufrDN/9oEo2MRkx7wM=";
            }
        ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");

      # Non-reproducible fetch: https://github.com/NixOS/nixpkgs/issues/321920#issuecomment-2184940546
      preInstall = ''
        cat << \EOF > "$bazelOut/external/go_sdk/versions.json"
        []
        EOF
      '';
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (
        tf_system_libs
        ++ lib.optionals (!effectiveStdenv.isDarwin) [
          "nsync" # fails to build on darwin
        ]
      );

      # Note: we cannot do most of this patching in the `patch` phase, as the deps are not
      # available yet. Framework search paths aren't added by the bintools hook.
      # See https://github.com/NixOS/nixpkgs/pull/41914.
      preBuild = lib.optionalString effectiveStdenv.isDarwin ''
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };

  platformTag =
    if effectiveStdenv.hostPlatform.isLinux then
      "manylinux2014_${arch}"
    else if effectiveStdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if effectiveStdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    else
      throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
in
buildPythonPackage {
  inherit pname version;
  format = "wheel";

  src =
    let
      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
    in
    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      patchelf --add-rpath "${
        lib.makeLibraryPath [
          cuda_libs_joined
          (lib.getLib cudaPackages.cudnn)
          nccl
        ]
      }" "$lib"
    done
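    # (Sanity check: `patchelf --print-rpath` on one of the patched libraries should now list
    # the joined CUDA, cudnn, and nccl paths.)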
  '';

  nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

  dependencies = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little; importing these submodules catches breakage
    # like the bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without this, there are complaints about libcudart.so.11.0 not being found, because the
  # RPATH entries added above would otherwise be stripped.
  dontPatchELF = cudaSupport;

  passthru = {
    # Note: the "bazel.*.tar.gz" fetch output can be accessed as `jaxlib.bazel-build.deps`.
    inherit bazel-build;
  };

  inherit meta;
}