{ lib
, pkgs
, stdenv

  # Build-time dependencies:
, addOpenGLRunpath
, bazel_6
, binutils
, buildBazelPackage
, buildPythonPackage
, cctools
, curl
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, nsync
, openssl
, pybind11
, setuptools
, symlinkJoin
, wheel
, build
, which

  # Python dependencies:
, absl-py
, flatbuffers
, ml-dtypes
, numpy
, scipy
, six

  # Runtime dependencies:
, double-conversion
, giflib
, grpc
, libjpeg_turbo
, python
, snappy
, zlib

, config
  # CUDA flags:
, cudaSupport ? config.cudaSupport
, cudaPackages ? { }

  # MKL:
, mklSupport ? true
}:
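
# A minimal usage sketch (hypothetical caller, not part of this file): CUDA
# support can be forced on independently of the global nixpkgs config via an
# override, e.g.
#
#   jaxlib.override {
#     cudaSupport = true;
#     cudaPackages = pkgs.cudaPackages;
#   }
#
# Here `pkgs.cudaPackages` stands in for whichever CUDA package set the caller
# prefers; any set providing cudatoolkit, cudnn and nccl should work.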

let
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  pname = "jaxlib";
  version = "0.4.20";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    platforms = platforms.unix;
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136;
    # however, even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    broken = stdenv.isDarwin;
  };

  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-merged";
    paths = [
      cudatoolkit.lib
      cudatoolkit.out
    ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
      # For unclear reasons, some of the required libraries end up under the
      # targets/<system> directory rather than lib; merging it in works around that.
      "${cudatoolkit}/targets/${stdenv.system}"
    ];
  };

  cudatoolkit_cc_joined = symlinkJoin {
    name = "${cudatoolkit.cc.name}-merged";
    paths = [
      backendStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };

  # Copied from the TensorFlow derivation. Most of these are not actually used
  # when compiling jaxlib, but keeping the list unchanged makes it easier to
  # stay in sync with the TensorFlow derivation. (How the list is consumed is
  # sketched in the note after it.)
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];
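
  # Sketch of how the list above is consumed (the joining happens in
  # fetchAttrs/buildAttrs below): it is concatenated into the comma-separated
  # TF_SYSTEM_LIBS environment variable, roughly
  #
  #   TF_SYSTEM_LIBS="absl_py,astor_archive,astunparse_archive,...,zlib"
  #
  # which tells the TensorFlow/XLA workspace rules to use the system
  # (nixpkgs-provided) copies of these dependencies instead of fetching them.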

  arch =
    # Normalized to "aarch64" because the wheel build otherwise fails with
    # KeyError: ('Linux', 'arm64')
    if stdenv.hostPlatform.isLinux && stdenv.hostPlatform.linuxArch == "arm64" then "aarch64"
    else stdenv.hostPlatform.linuxArch;
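  # For illustration: arch is "x86_64" on x86_64-linux and "aarch64" on
  # aarch64-linux; it is reused below both as Bazel's --cpu value and in the
  # wheel's platform tag.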

  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for both jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-WLYXUtchOaA6SGnKuVhN9CmV06xMCLQTEuEtL13ttZU=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals stdenv.isDarwin [
      cctools
    ];

    buildInputs = [
      curl
      double-conversion
      giflib
      grpc
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      pkgs.protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ] ++ lib.optionals cudaSupport [
      cudatoolkit
      cudnn
    ] ++ lib.optionals stdenv.isDarwin [
      IOKit
    ] ++ lib.optionals (!stdenv.isDarwin) [
      nsync
    ];

    # Drop the pinned Bazel version so the build accepts nixpkgs' bazel_6.
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];

    removeRulesCC = false;

    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

    # If this variable is not set, the wheel version is automatically suffixed
    # with ".dev".
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    preConfigure = ''
      # Put a dummy (no-op) ldconfig on PATH, so that anything in the build
      # that invokes ldconfig succeeds harmlessly inside the sandbox.
      mkdir dummy-ldconfig
      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
      cat <<CFG > ./.jax_configure.bazelrc
      build --strategy=Genrule=standalone
      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
    '' + lib.optionalString (stdenv.hostPlatform.avxSupport && stdenv.hostPlatform.isUnix) ''
      build --config=avx_posix
    '' + lib.optionalString mklSupport ''
      build --config=mkl_open_source_only
    '' + lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
    '' + ''
      CFG
    '';
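
    # For illustration, a plain non-CUDA x86_64-linux build with mklSupport
    # enabled generates roughly this .jax_configure.bazelrc (store paths
    # elided):
    #
    #   build --strategy=Genrule=standalone
    #   build --repo_env PYTHON_BIN_PATH="/nix/store/...-python3/bin/python"
    #   build --action_env=PYENV_ROOT
    #   build --python_path="/nix/store/...-python3/bin/python"
    #   build --distinct_host_configuration=false
    #   build --define PROTOBUF_INCLUDE_PATH="/nix/store/...-protobuf/include"
    #   build --config=avx_posix
    #   build --config=mkl_open_source_only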

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
    ] ++ lib.optionals stdenv.cc.isClang [
      # Bazel expects the compiler frontend to pick the language based on the
      # file extension, but our clang doesn't do that automatically.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
    ];

    # We intentionally overfetch so we can share the fetch derivation across all
    # the different configurations.
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # Force-fetch @mkl_dnn_v1: it is not needed (and hence not fetched) on
      # darwin, but the fetch derivation is shared across configurations.
      bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
      bazelFlags = bazelFlags ++ [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        # Ideally we'd add this unconditionally too, but it doesn't work on darwin.
        # We make this conditional on `cudaSupport` instead of the system, so that
        # the hashes for both the cuda and the non-cuda deps can be computed on
        # linux, since a lot of contributors don't have access to darwin machines.
        "--config=cuda"
      ] ++ [
        "--config=mkl_open_source_only"
      ];

      sha256 = (if cudaSupport then {
        x86_64-linux = "sha256-QczClHxHElLZCqIZlHc3z3DXJ7rZQJaMs2XIb+lxarI=";
      } else {
        x86_64-linux = "sha256-mqiJe4u0NYh1PKCbQfbo0U2e9/kYiBqj98d+BPHFSxQ=";
        aarch64-linux = "sha256-EuLqamVBJ+qoVMCFIYUT846AghltZolfLGdtO9UeXSM=";
      }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
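      # When bumping `version` (or adding a platform), a common way to obtain
      # the new hash is to replace the matching value above with
      # `lib.fakeSha256` and rebuild: the hash-mismatch error then reports the
      # correct value. This is a general nixpkgs technique, not specific to
      # this file.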
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!stdenv.isDarwin) [
        "nsync" # fails to build on darwin
      ]);

      # Note: we cannot do most of this patching in the `patch` phase, as the
      # dependencies are not available yet.
      # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
      # 2) Patch python path in the compiler driver.
      preBuild = lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by the bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };

  platformTag =
    if stdenv.hostPlatform.isLinux then
      "manylinux2014_${arch}"
    else if stdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if stdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    else throw "Unsupported target platform: ${stdenv.hostPlatform}";

in
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  src =
    let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
    in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
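  # For example, with python 3.11 on x86_64-linux this resolves to
  #
  #   ${bazel-build}/jaxlib-0.4.20-cp311-cp311-manylinux2014_x86_64.whl
  #
  # (the ${bazel-build} store path prefix is left symbolic here).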

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';
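
  # To sanity-check the result, one can print the RPATH of every shared object
  # installed above, e.g. (from a shell, not part of the build):
  #
  #   find $out -name '*.so*' -exec patchelf --print-rpath {} \;
  #
  # Each entry should now start with the cudatoolkit/cudnn/nccl lib paths.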

  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little, so also check submodules whose
    # imports broke during the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without this, the RPATH entries added above are stripped again, and jaxlib
  # then complains that libcudart.so.11.0 cannot be found.
  dontPatchELF = cudaSupport;
}