1{ lib
2, pkgs
3, stdenv
4
5 # Build-time dependencies:
6, addOpenGLRunpath
7, bazel_5
8, binutils
9, buildBazelPackage
10, buildPythonPackage
11, cctools
12, curl
13, cython
14, fetchFromGitHub
15, git
16, IOKit
17, jsoncpp
18, nsync
19, openssl
20, pybind11
21, setuptools
22, symlinkJoin
23, wheel
24, which
25
26 # Python dependencies:
27, absl-py
28, flatbuffers
29, numpy
30, scipy
31, six
32
33 # Runtime dependencies:
34, double-conversion
35, giflib
36, grpc
37, libjpeg_turbo
38, protobuf
39, python
40, snappy
41, zlib
42
43 # CUDA flags:
44, cudaSupport ? false
45, cudaPackages ? {}
46
47 # MKL:
48, mklSupport ? true
49}:
50
51let
  # Pull the CUDA toolchain pieces out of the (possibly empty) cudaPackages
  # set. These bindings are lazy, so they are only forced when cudaSupport
  # actually causes them to be used.
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  # Shared between the bazel wheel build and the final Python package.
  pname = "jaxlib";
  version = "0.4.4";
56
57 meta = with lib; {
58 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
59 homepage = "https://github.com/google/jax";
60 license = licenses.asl20;
61 maintainers = with maintainers; [ ndl ];
62 platforms = platforms.unix;
63 # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
64 # however even with that fix applied, it doesn't work for everyone:
65 # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
66 broken = stdenv.isDarwin;
67 };
68
69 cudatoolkit_joined = symlinkJoin {
70 name = "${cudatoolkit.name}-merged";
71 paths = [
72 cudatoolkit.lib
73 cudatoolkit.out
74 ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
75 # for some reason some of the required libs are in the targets/x86_64-linux
76 # directory; not sure why but this works around it
77 "${cudatoolkit}/targets/${stdenv.system}"
78 ];
79 };
80
81 cudatoolkit_cc_joined = symlinkJoin {
82 name = "${cudatoolkit.cc.name}-merged";
83 paths = [
84 backendStdenv.cc
85 binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
86 ];
87 };
88
  # Copy-paste from TF derivation.
  # Most of these are not really used in jaxlib compilation but it's simpler to keep it
  # 'as is' so that it's more compatible with TF derivation.
  # NOTE: joined with "," into TF_SYSTEM_LIBS below, so the list order is part
  # of the derivation's environment — do not reorder gratuitously.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    "boringssl"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];
131
132 arch =
133 # KeyError: ('Linux', 'arm64')
134 if stdenv.targetPlatform.isLinux && stdenv.targetPlatform.linuxArch == "arm64" then "aarch64"
135 else stdenv.targetPlatform.linuxArch;
136
  # Builds the jaxlib wheel with bazel; the resulting store path contains the
  # .whl file that the buildPythonPackage below installs.
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    bazel = bazel_5;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      which
    ] ++ lib.optionals stdenv.isDarwin [
      cctools
    ];

    buildInputs = [
      curl
      double-conversion
      giflib
      grpc
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ] ++ lib.optionals cudaSupport [
      cudatoolkit
      cudnn
    ] ++ lib.optionals stdenv.isDarwin [
      IOKit
    ] ++ lib.optionals (!stdenv.isDarwin) [
      nsync
    ];

    # Drop upstream's bazel version pin so the bazel_5 provided above is used.
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelTargets = [ "//build:build_wheel" ];

    # Keep the bundled rules_cc instead of buildBazelPackage's default removal.
    removeRulesCC = false;

    # Point TF's CUDA crosstool at the merged compiler/binutils tree. These
    # evaluate to "" when cudaSupport is off (lib.optionalString).
    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

    # Write the .jax_configure.bazelrc that jax's build/build.py would normally
    # generate, and put a no-op ldconfig on PATH for the sandboxed build.
    preConfigure = ''
      # dummy ldconfig
      mkdir dummy-ldconfig
      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
      cat <<CFG > ./.jax_configure.bazelrc
      build --strategy=Genrule=standalone
      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
    '' + lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
    '' + ''
      CFG
    '';

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
    ] ++ lib.optionals stdenv.cc.isClang [
      # bazel depends on the compiler frontend automatically selecting these flags based on file
      # extension but our clang doesn't.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
    ];

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # we have to force @mkl_dnn_v1 since it's not needed on darwin
      bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
      bazelFlags = bazelFlags ++ [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        # ideally we'd add this unconditionally too, but it doesn't work on darwin
        # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
        # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
        # have access to darwin machines
        "--config=cuda"
      ] ++ [
        "--config=mkl_open_source_only"
      ];

      # Fixed-output hash of the fetched bazel dependencies; the CUDA and
      # non-CUDA dependency sets differ, hence two hashes.
      sha256 =
        if cudaSupport then
          "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
        else
          "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!stdenv.isDarwin) [
        "nsync" # fails to build on darwin
      ]);

      # Unlike fetchAttrs, these configs are applied only where they are
      # actually supported (avx on x86_64-unix, mkl behind mklSupport).
      bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        "--config=cuda"
      ] ++ lib.optionals mklSupport [
        "--config=mkl_open_source_only"
      ];

      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Fix pybind11 include paths.
      # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
      # 3) Patch python path in the compiler driver.
      preBuild = ''
        for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
          sed -i 's@include/pybind11@pybind11@g' $src
        done
      '' + lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '' + (if stdenv.cc.isGNU then ''
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      '' else if stdenv.cc.isClang then ''
        sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

      # build_wheel writes the finished .whl into $out.
      installPhase = ''
        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
      '';
    };

    inherit meta;
  };
305 platformTag =
306 if stdenv.targetPlatform.isLinux then
307 "manylinux2014_${arch}"
308 else if stdenv.system == "x86_64-darwin" then
309 "macosx_10_9_${arch}"
310 else if stdenv.system == "aarch64-darwin" then
311 "macosx_11_0_${arch}"
312 else throw "Unsupported target platform: ${stdenv.targetPlatform}";
313
314in
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  # The wheel built by the bazel derivation above; the filename encodes the
  # CPython tag (e.g. "cp310", derived from python.pythonVersion) twice plus
  # the platform tag.
  src =
    let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
    in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  # addOpenGLRunpath is only needed for the postInstall hook above.
  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}