# jaxlib: JAX's XLA-backed support library, built from source with Bazel.
{
  lib,
  pkgs,
  stdenv, # NOTE(review): shadowed below — use `effectiveStdenv` instead

  # Build-time dependencies:
  addOpenGLRunpath,
  autoAddDriverRunpath,
  bazel_6,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  IOKit,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy,
  zlib,

  config,
  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackages,

  # MKL:
  mklSupport ? true,
}@inputs: # `inputs` keeps the raw arguments reachable even where names are shadowed
53
let
  # CUDA-toolchain attributes used throughout the file.
  inherit (cudaPackages)
    cudaFlags
    cudaVersion
    cudnn
    nccl
    ;

  pname = "jaxlib";
  version = "0.4.28";

  # It's necessary to consistently use backendStdenv when building with CUDA
  # support, otherwise we get libstdc++ errors downstream.
  # Shadow the `stdenv` argument so that any accidental use of it fails
  # loudly at evaluation time instead of silently picking the wrong stdenv.
  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
69
70 meta = with lib; {
71 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
72 homepage = "https://github.com/google/jax";
73 license = licenses.asl20;
74 maintainers = with maintainers; [ ndl ];
75 platforms = platforms.unix;
76 # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
77 # however even with that fix applied, it doesn't work for everyone:
78 # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
79 # NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
80 broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
81 };
82
83 # These are necessary at build time and run time.
84 cuda_libs_joined = symlinkJoin {
85 name = "cuda-joined";
86 paths = with cudaPackages; [
87 cuda_cudart.lib # libcudart.so
88 cuda_cudart.static # libcudart_static.a
89 cuda_cupti.lib # libcupti.so
90 libcublas.lib # libcublas.so
91 libcufft.lib # libcufft.so
92 libcurand.lib # libcurand.so
93 libcusolver.lib # libcusolver.so
94 libcusparse.lib # libcusparse.so
95 ];
96 };
97 # These are only necessary at build time.
98 cuda_build_deps_joined = symlinkJoin {
99 name = "cuda-build-deps-joined";
100 paths = with cudaPackages; [
101 cuda_libs_joined
102
103 # Binaries
104 cudaPackages.cuda_nvcc.bin # nvcc
105
106 # Headers
107 cuda_cccl.dev # block_load.cuh
108 cuda_cudart.dev # cuda.h
109 cuda_cupti.dev # cupti.h
110 cuda_nvcc.dev # See https://github.com/google/jax/issues/19811
111 cuda_nvml_dev # nvml.h
112 cuda_nvtx.dev # nvToolsExt.h
113 libcublas.dev # cublas_api.h
114 libcufft.dev # cufft.h
115 libcurand.dev # curand.h
116 libcusolver.dev # cusolver_common.h
117 libcusparse.dev # cusparse.h
118 ];
119 };
120
  # One prefix exposing the C compiler and the binutils tools side by side,
  # as the Bazel CUDA crosstool expects (see GCC_HOST_COMPILER_* below).
  backend_cc_joined = symlinkJoin {
    name = "cuda-cc-joined";
    paths = [
      effectiveStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };
128
  # Bazel external repositories that should be replaced by the system
  # (nixpkgs) copies via the TF_SYSTEM_LIBS environment variable.
  # Copy-paste from TF derivation.
  # Most of these are not really used in jaxlib compilation but it's simpler to keep it
  # 'as is' so that it's more compatible with TF derivation.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
    # "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];
175
176 arch =
177 # KeyError: ('Linux', 'arm64')
178 if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
179 "aarch64"
180 else
181 effectiveStdenv.hostPlatform.linuxArch;
182
  # Pinned source checkout of XLA; wired into the bazel build below via
  # --override_repository instead of letting Bazel fetch it over the network.
  xla = effectiveStdenv.mkDerivation {
    pname = "xla-src";
    version = "unstable";

    src = fetchFromGitHub {
      owner = "openxla";
      repo = "xla";
      # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
    };

    # The checkout is shipped as-is; only shebangs are patched.
    dontBuild = true;

    # This is necessary for patchShebangs to know the right path to use.
    nativeBuildInputs = [ python ];

    # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
    postPatch = ''
      patchShebangs .
    '';

    installPhase = ''
      cp -r . $out
    '';
  };
209
  # The actual Bazel build: produces the jaxlib wheel in $out, which the
  # buildPythonPackage at the bottom of this file then installs.
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals effectiveStdenv.isDarwin [ cctools ];

    buildInputs =
      [
        curl
        double-conversion
        giflib
        jsoncpp
        libjpeg_turbo
        numpy
        openssl
        pkgs.flatbuffers
        pkgs.protobuf
        pybind11
        scipy
        six
        snappy
        zlib
      ]
      ++ lib.optionals effectiveStdenv.isDarwin [ IOKit ]
      ++ lib.optionals (!effectiveStdenv.isDarwin) [ nsync ];

    # We don't want to be quite so picky regarding bazel version
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [
      "--output_path=$out"
      "--cpu=${arch}"
      # This has no impact whatsoever...
      "--jaxlib_git_hash='12345678'"
    ];

    # Keep the vendored rules_cc; NOTE(review): jaxlib appears to rely on it —
    # confirm before flipping this back to the buildBazelPackage default.
    removeRulesCC = false;

    # Point the CUDA crosstool at our joined cc/binutils prefix.
    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    preConfigure =
      # Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
      ''
        mkdir dummy-ldconfig
        echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
        chmod +x dummy-ldconfig/ldconfig
        export PATH="$PWD/dummy-ldconfig:$PATH"
      ''
      +

      # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
      # for more info. We assume
      # * `cpu = None`
      # * `enable_nccl = True`
      # * `target_cpu_features = "release"`
      # * `rocm_amdgpu_targets = None`
      # * `enable_rocm = False`
      # * `build_gpu_plugin = False`
      # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
      #
      # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
      # instead of duplicating the logic here. Perhaps we can leverage the
      # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
      ''
        cat <<CFG > ./.jax_configure.bazelrc
        build --strategy=Genrule=standalone
        build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
        build --action_env=PYENV_ROOT
        build --python_path="${python}/bin/python"
        build --distinct_host_configuration=false
        build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
      ''
      + lib.optionalString cudaSupport ''
        build --config=cuda
        build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
        build --action_env CUDNN_INSTALL_PATH="${cudnn}"
        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}"
        build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
        build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
      ''
      +
      # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
      # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
      # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
      # for upstream's version.
      lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
        ''
          build --config=avx_posix
        ''
      + lib.optionalString mklSupport ''
        build --config=mkl_open_source_only
      ''
      + ''
        CFG
      '';

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags =
      [
        "-c opt"
        # See https://bazel.build/external/advanced#overriding-repositories for
        # information on --override_repository flag.
        "--override_repository=xla=${xla}"
      ]
      ++ lib.optionals effectiveStdenv.cc.isClang [
        # bazel depends on the compiler frontend automatically selecting these flags based on file
        # extension but our clang doesn't.
        # https://github.com/NixOS/nixpkgs/issues/150655
        "--cxxopt=-x"
        "--cxxopt=c++"
        "--host_cxxopt=-x"
        "--host_cxxopt=c++"
      ];

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # we have to force @mkl_dnn_v1 since it's not needed on darwin
      bazelTargets = [
        bazelRunTarget
        "@mkl_dnn_v1//:mkl_dnn"
      ];
      bazelFlags =
        bazelFlags
        ++ [
          "--config=avx_posix"
          "--config=mkl_open_source_only"
        ]
        ++ lib.optionals cudaSupport [
          # ideally we'd add this unconditionally too, but it doesn't work on darwin
          # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
          # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
          # have access to darwin machines
          "--config=cuda"
        ];

      # Fixed-output hash of the fetched dependency tree, per configuration/system.
      sha256 =
        (
          if cudaSupport then
            { x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; }
          else
            {
              x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
              aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
            }
        ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (
        tf_system_libs
        ++ lib.optionals (!effectiveStdenv.isDarwin) [
          "nsync" # fails to build on darwin
        ]
      );

      # Note: we cannot do most of this patching at `patch` phase as the deps
      # are not available yet. Framework search paths aren't added by bintools
      # hook. See https://github.com/NixOS/nixpkgs/pull/41914.
      preBuild = lib.optionalString effectiveStdenv.isDarwin ''
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };
410 platformTag =
411 if effectiveStdenv.hostPlatform.isLinux then
412 "manylinux2014_${arch}"
413 else if effectiveStdenv.system == "x86_64-darwin" then
414 "macosx_10_9_${arch}"
415 else if effectiveStdenv.system == "aarch64-darwin" then
416 "macosx_11_0_${arch}"
417 else
418 throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
419in
# Install the wheel produced by the bazel build above.
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  src =
    let
      # e.g. Python 3.11 -> "cp311", matching the wheel filename's abi/python tags.
      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
    in
    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas

    # `IFS= read -r` keeps paths containing backslashes or surrounding
    # whitespace intact; a bare `read` would mangle them.
    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while IFS= read -r sofile; do
      patchelf --add-rpath "${
        lib.makeLibraryPath [
          cuda_libs_joined
          cudnn
          nccl
        ]
      }" "$sofile"
    done
  '';

  nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

  dependencies = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}