1{ lib
2, pkgs
3, stdenv
4
5 # Build-time dependencies:
6, addOpenGLRunpath
7, bazel_5
8, binutils
9, buildBazelPackage
10, buildPythonPackage
11, cctools
12, curl
13, cython
14, fetchFromGitHub
15, git
16, IOKit
17, jsoncpp
18, nsync
19, openssl
20, pybind11
21, setuptools
22, symlinkJoin
23, wheel
24, which
25
26 # Python dependencies:
27, absl-py
28, flatbuffers
29, numpy
30, scipy
31, six
32
33 # Runtime dependencies:
34, double-conversion
35, giflib
36, grpc
37, libjpeg_turbo
38, protobuf
39, python
40, snappy
41, zlib
42
43 # CUDA flags:
44, cudaCapabilities ? [ "sm_35" "sm_50" "sm_60" "sm_70" "sm_75" "compute_80" ]
45, cudaSupport ? false
46, cudaPackages ? {}
47
48 # MKL:
49, mklSupport ? true
50}:
51
52let
53 inherit (cudaPackages) cudatoolkit cudnn nccl;
54
55 pname = "jaxlib";
56 version = "0.3.22";
57
58 meta = with lib; {
59 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
60 homepage = "https://github.com/google/jax";
61 license = licenses.asl20;
62 maintainers = with maintainers; [ ndl ];
63 platforms = platforms.unix;
64 # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
65 # however even with that fix applied, it doesn't work for everyone:
66 # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
67 broken = stdenv.isAarch64;
68 };
69
70 cudatoolkit_joined = symlinkJoin {
71 name = "${cudatoolkit.name}-merged";
72 paths = [
73 cudatoolkit.lib
74 cudatoolkit.out
75 ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
76 # for some reason some of the required libs are in the targets/x86_64-linux
77 # directory; not sure why but this works around it
78 "${cudatoolkit}/targets/${stdenv.system}"
79 ];
80 };
81
82 cudatoolkit_cc_joined = symlinkJoin {
83 name = "${cudatoolkit.cc.name}-merged";
84 paths = [
85 cudatoolkit.cc
86 binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
87 ];
88 };
89
90 bazel-build = buildBazelPackage {
91 name = "bazel-build-${pname}-${version}";
92
93 bazel = bazel_5;
94
95 src = fetchFromGitHub {
96 owner = "google";
97 repo = "jax";
98 rev = "${pname}-v${version}";
99 hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
100 };
101
102 nativeBuildInputs = [
103 cython
104 pkgs.flatbuffers
105 git
106 setuptools
107 wheel
108 which
109 ] ++ lib.optionals stdenv.isDarwin [
110 cctools
111 ];
112
113 buildInputs = [
114 curl
115 double-conversion
116 giflib
117 grpc
118 jsoncpp
119 libjpeg_turbo
120 numpy
121 openssl
122 pkgs.flatbuffers
123 protobuf
124 pybind11
125 scipy
126 six
127 snappy
128 zlib
129 ] ++ lib.optionals cudaSupport [
130 cudatoolkit
131 cudnn
132 ] ++ lib.optionals stdenv.isDarwin [
133 IOKit
134 ] ++ lib.optionals (!stdenv.isDarwin) [
135 nsync
136 ];
137
138 postPatch = ''
139 rm -f .bazelversion
140 '';
141
142 bazelTarget = "//build:build_wheel";
143
144 removeRulesCC = false;
145
146 GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
147 GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";
148
149 preConfigure = ''
150 # dummy ldconfig
151 mkdir dummy-ldconfig
152 echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
153 chmod +x dummy-ldconfig/ldconfig
154 export PATH="$PWD/dummy-ldconfig:$PATH"
155 cat <<CFG > ./.jax_configure.bazelrc
156 build --strategy=Genrule=standalone
157 build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
158 build --action_env=PYENV_ROOT
159 build --python_path="${python}/bin/python"
160 build --distinct_host_configuration=false
161 build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
162 '' + lib.optionalString cudaSupport ''
163 build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
164 build --action_env CUDNN_INSTALL_PATH="${cudnn}"
165 build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
166 build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
167 build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
168 build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${lib.concatStringsSep "," cudaCapabilities}"
169 '' + ''
170 CFG
171 '';
172
173 # Copy-paste from TF derivation.
174 # Most of these are not really used in jaxlib compilation but it's simpler to keep it
175 # 'as is' so that it's more compatible with TF derivation.
176 TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
177 "absl_py"
178 "astor_archive"
179 "astunparse_archive"
180 "boringssl"
181 # Not packaged in nixpkgs
182 # "com_github_googleapis_googleapis"
183 # "com_github_googlecloudplatform_google_cloud_cpp"
184 "com_github_grpc_grpc"
185 "com_google_protobuf"
186 # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
187 # "com_googlesource_code_re2"
188 "curl"
189 "cython"
190 "dill_archive"
191 "double_conversion"
192 "flatbuffers"
193 "functools32_archive"
194 "gast_archive"
195 "gif"
196 "hwloc"
197 "icu"
198 "jsoncpp_git"
199 "libjpeg_turbo"
200 "lmdb"
201 "nasm"
202 "opt_einsum_archive"
203 "org_sqlite"
204 "pasta"
205 "png"
206 "pybind11"
207 "six_archive"
208 "snappy"
209 "tblib_archive"
210 "termcolor_archive"
211 "typing_extensions_archive"
212 "wrapt"
213 "zlib"
214 ] ++ lib.optionals (!stdenv.isDarwin) [
215 "nsync" # fails to build on darwin
216 ]);
217
218 # Make sure Bazel knows about our configuration flags during fetching so that the
219 # relevant dependencies can be downloaded.
220 bazelFlags = [
221 "-c opt"
222 ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
223 "--config=avx_posix"
224 ] ++ lib.optionals cudaSupport [
225 "--config=cuda"
226 ] ++ lib.optionals mklSupport [
227 "--config=mkl_open_source_only"
228 ] ++ lib.optionals stdenv.cc.isClang [
229 # bazel depends on the compiler frontend automatically selecting these flags based on file
230 # extension but our clang doesn't.
231 # https://github.com/NixOS/nixpkgs/issues/150655
232 "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
233 ];
234
235 fetchAttrs = {
236 sha256 =
237 if cudaSupport then
238 "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ="
239 else if stdenv.isDarwin then
240 "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM="
241 else
242 "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM=";
243 };
244
245 buildAttrs = {
246 outputs = [ "out" ];
247
248 # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
249 # 1) Fix pybind11 include paths.
250 # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
251 # loading multiple extensions in the same python program due to duplicate protobuf DBs.
252 # 3) Patch python path in the compiler driver.
253 preBuild = ''
254 for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
255 sed -i 's@include/pybind11@pybind11@g' $src
256 done
257 '' + lib.optionalString cudaSupport ''
258 patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
259 '' + lib.optionalString stdenv.isDarwin ''
260 # Framework search paths aren't added by bintools hook
261 # https://github.com/NixOS/nixpkgs/pull/41914
262 export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
263 substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
264 --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
265 substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
266 --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
267 '' + (if stdenv.cc.isGNU then ''
268 sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
269 sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
270 '' else if stdenv.cc.isClang then ''
271 sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
272 sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
273 '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
274
275 installPhase = ''
276 ./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
277 '';
278 };
279
280 inherit meta;
281 };
282 platformTag =
283 if stdenv.targetPlatform.isLinux then
284 "manylinux2014_${stdenv.targetPlatform.linuxArch}"
285 else if stdenv.system == "x86_64-darwin" then
286 "macosx_10_9_${stdenv.targetPlatform.linuxArch}"
287 else if stdenv.system == "aarch64-darwin" then
288 "macosx_11_0_${stdenv.targetPlatform.linuxArch}"
289 else throw "Unsupported target platform: ${stdenv.targetPlatform}";
290
291in
292buildPythonPackage {
293 inherit meta pname version;
294 format = "wheel";
295
296 src =
297 let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
298 in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
299
300 # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
301 # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
302 # more info.
303 postInstall = lib.optionalString cudaSupport ''
304 mkdir -p $out/bin
305 ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
306
307 find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
308 addOpenGLRunpath "$lib"
309 patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
310 done
311 '';
312
313 nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;
314
315 propagatedBuildInputs = [
316 absl-py
317 curl
318 double-conversion
319 flatbuffers
320 giflib
321 grpc
322 jsoncpp
323 libjpeg_turbo
324 numpy
325 scipy
326 six
327 snappy
328 ];
329
330 pythonImportsCheck = [ "jaxlib" ];
331
332 # Without it there are complaints about libcudart.so.11.0 not being found
333 # because RPATH path entries added above are stripped.
334 dontPatchELF = cudaSupport;
335}