1{ lib
2, pkgs
3, stdenv
4
5 # Build-time dependencies:
6, addOpenGLRunpath
7, bazel_5
8, binutils
9, buildBazelPackage
10, buildPythonPackage
11, cctools
12, curl
13, cython
14, fetchFromGitHub
15, git
16, IOKit
17, jsoncpp
18, nsync
19, openssl
20, pybind11
21, setuptools
22, symlinkJoin
23, wheel
24, which
25
26 # Python dependencies:
27, absl-py
28, flatbuffers
29, numpy
30, scipy
31, six
32
33 # Runtime dependencies:
34, double-conversion
35, giflib
36, grpc
37, libjpeg_turbo
38, protobuf
39, python
40, snappy
41, zlib
42
43 # CUDA flags:
44, cudaSupport ? false
45, cudaPackages ? {}
46
47 # MKL:
48, mklSupport ? true
49}:
50
let
  # Pull the CUDA toolchain out of the (by default empty) cudaPackages set.
  # `inherit` is lazy, so these attributes are only forced when cudaSupport
  # actually causes them to be referenced.
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  # NOTE: this is the jaxlib (C++/XLA backend) version, not the version of
  # the pure-Python `jax` frontend — the google/jax repo tags both.
  pname = "jaxlib";
  version = "0.4.4";
56
57 meta = with lib; {
58 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
59 homepage = "https://github.com/google/jax";
60 license = licenses.asl20;
61 maintainers = with maintainers; [ ndl ];
62 platforms = platforms.unix;
63 # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
64 # however even with that fix applied, it doesn't work for everyone:
65 # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
66 broken = stdenv.isAarch64 || stdenv.isDarwin;
67 };
68
69 cudatoolkit_joined = symlinkJoin {
70 name = "${cudatoolkit.name}-merged";
71 paths = [
72 cudatoolkit.lib
73 cudatoolkit.out
74 ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
75 # for some reason some of the required libs are in the targets/x86_64-linux
76 # directory; not sure why but this works around it
77 "${cudatoolkit}/targets/${stdenv.system}"
78 ];
79 };
80
  # Compiler toolchain for the CUDA build: the cudaPackages-compatible cc
  # (backendStdenv.cc) merged with the binutils Bazel's crosstool wrapper
  # expects to find next to the compiler.
  cudatoolkit_cc_joined = symlinkJoin {
    name = "${cudatoolkit.cc.name}-merged";
    paths = [
      backendStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };
88
  # Copy-paste from TF derivation.
  # Most of these are not really used in jaxlib compilation but it's simpler to keep it
  # 'as is' so that it's more compatible with TF derivation.
  # Each name listed here is passed to Bazel via TF_SYSTEM_LIBS (see
  # fetchAttrs/buildAttrs below), telling it to use the host-provided library
  # instead of downloading and building a vendored copy.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    "boringssl"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];
131
  # First stage: run the upstream Bazel build and produce the jaxlib wheel.
  # The resulting store path contains the .whl, which is then installed by
  # the buildPythonPackage call at the bottom of this file.
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # Pin the Bazel major version this jaxlib release builds with.
    bazel = bazel_5;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      which
    ] ++ lib.optionals stdenv.isDarwin [
      cctools
    ];

    buildInputs = [
      curl
      double-conversion
      giflib
      grpc
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ] ++ lib.optionals cudaSupport [
      cudatoolkit
      cudnn
    ] ++ lib.optionals stdenv.isDarwin [
      IOKit
    ] ++ lib.optionals (!stdenv.isDarwin) [
      nsync
    ];

    # Drop upstream's pinned Bazel version file so our bazel_5 is accepted.
    postPatch = ''
      rm -f .bazelversion
    '';

    # Upstream's wheel-builder entry point; invoked again in installPhase.
    bazelTargets = [ "//build:build_wheel" ];

    # Keep the vendored rules_cc instead of letting buildBazelPackage strip it.
    removeRulesCC = false;

    # Point TF's CUDA crosstool at the merged host compiler; both expand to
    # the empty string when cudaSupport is off.
    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

    # Writes .jax_configure.bazelrc (via the CFG heredoc), which is how jax's
    # build/build.py normally records configure-time choices for Bazel.  Also
    # puts a no-op ldconfig on PATH since the build tries to invoke it.
    preConfigure = ''
      # dummy ldconfig
      mkdir dummy-ldconfig
      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
      cat <<CFG > ./.jax_configure.bazelrc
      build --strategy=Genrule=standalone
      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
    '' + lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
    '' + ''
      CFG
    '';

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
    ] ++ lib.optionals stdenv.cc.isClang [
      # bazel depends on the compiler frontend automatically selecting these flags based on file
      # extension but our clang doesn't.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
    ];

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # we have to force @mkl_dnn_v1 since it's not needed on darwin
      bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
      bazelFlags = bazelFlags ++ [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        # ideally we'd add this unconditionally too, but it doesn't work on darwin
        # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
        # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
        # have access to darwin machines
        "--config=cuda"
      ] ++ [
        "--config=mkl_open_source_only"
      ];

      # Fixed-output hash of the fetched dependency tree; differs between the
      # CUDA and non-CUDA fetch sets because --config=cuda changes what is
      # downloaded.
      sha256 =
        if cudaSupport then
          "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
        else
          "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!stdenv.isDarwin) [
        "nsync" # fails to build on darwin
      ]);

      # Unlike fetchAttrs, these flags are chosen per-configuration: AVX only
      # where the target supports it, MKL only when requested.
      bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        "--config=cuda"
      ] ++ lib.optionals mklSupport [
        "--config=mkl_open_source_only"
      ];
      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Fix pybind11 include paths.
      # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      # loading multiple extensions in the same python program due to duplicate protobuf DBs.
      # 3) Patch python path in the compiler driver.
      preBuild = ''
        for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
          sed -i 's@include/pybind11@pybind11@g' $src
        done
      '' + lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '' + (if stdenv.cc.isGNU then ''
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      '' else if stdenv.cc.isClang then ''
        sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

      # Run the wheel builder Bazel just produced; the .whl lands in $out.
      installPhase = ''
        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
      '';
    };

    inherit meta;
  };
300 platformTag =
301 if stdenv.targetPlatform.isLinux then
302 "manylinux2014_${stdenv.targetPlatform.linuxArch}"
303 else if stdenv.system == "x86_64-darwin" then
304 "macosx_10_9_${stdenv.targetPlatform.linuxArch}"
305 else if stdenv.system == "aarch64-darwin" then
306 "macosx_11_0_${stdenv.targetPlatform.linuxArch}"
307 else throw "Unsupported target platform: ${stdenv.targetPlatform}";
308
in
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  # The wheel built by the Bazel stage above.  Its filename embeds the
  # CPython tag twice (interpreter and ABI, e.g. "cp310") plus the platform
  # tag computed above.
  src =
    let
      pythonTag = "cp" + builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;
    in
    "${bazel-build}/jaxlib-${version}-${pythonTag}-${pythonTag}-${platformTag}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  nativeBuildInputs = lib.optionals cudaSupport [ addOpenGLRunpath ];

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}