Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
{ lib
, pkgs
, stdenv

  # Tools and libraries used while building. `pkgs` is taken as a whole so
  # the native flatbuffers package can be reached as `pkgs.flatbuffers`; the
  # `flatbuffers` argument further down is the Python package.
, addOpenGLRunpath
, bazel_5
, binutils
, buildBazelPackage
, buildPythonPackage
, cctools
, curl
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, nsync
, openssl
, pybind11
, setuptools
, symlinkJoin
, wheel
, which

  # Python packages:
, absl-py
, flatbuffers
, numpy
, scipy
, six

  # Runtime dependencies:
, double-conversion
, giflib
, grpc
, libjpeg_turbo
, protobuf
, python
, snappy
, zlib

  # CUDA: when cudaSupport is set, cudaPackages must provide cudatoolkit,
  # cudaFlags, cudnn and nccl (selected from it in the let block below).
, cudaSupport ? false
, cudaPackages ? {}

  # MKL: when set, --config=mkl_open_source_only is added to the Bazel flags.
, mklSupport ? true
}:

let
  # Selected lazily from cudaPackages, so the empty default set only causes
  # an error if one of these is actually forced.
  cudatoolkit = cudaPackages.cudatoolkit;
  cudaFlags = cudaPackages.cudaFlags;
  cudnn = cudaPackages.cudnn;
  nccl = cudaPackages.nccl;

  pname = "jaxlib";
  version = "0.3.22";
56
meta = {
  description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
  homepage = "https://github.com/google/jax";
  license = lib.licenses.asl20;
  maintainers = [ lib.maintainers.ndl ];
  platforms = lib.platforms.unix;
  # Every aarch64 platform (Linux and Darwin) is marked broken.
  # On aarch64-darwin the build needs https://github.com/bazelbuild/rules_cc/pull/136,
  # and even with that fix applied it does not work for everyone:
  # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
  broken = stdenv.isAarch64;
};
68
# Single directory tree combining the CUDA toolkit's `lib` and `out` outputs.
# It is passed to Bazel as CUDA_TOOLKIT_PATH / TF_CUDA_PATHS.
cudatoolkit_joined = symlinkJoin {
  name = "${cudatoolkit.name}-merged";
  paths =
    [ cudatoolkit.lib cudatoolkit.out ]
    # Toolkits older than 11 keep some of the required libraries under
    # targets/<system> (e.g. targets/x86_64-linux). The reason is unknown;
    # adding that directory works around it.
    ++ lib.optional (lib.versionOlder cudatoolkit.version "11")
      "${cudatoolkit}/targets/${stdenv.system}";
};
80
# The CUDA toolkit's host compiler merged with binutils. binutils supplies ar,
# dwp, nm, objcopy, objdump and strip next to gcc, so that one `bin` directory
# can serve as GCC_HOST_COMPILER_PREFIX.
cudatoolkit_cc_joined = symlinkJoin {
  name = "${cudatoolkit.cc.name}-merged";
  paths = [ cudatoolkit.cc binutils.bintools ];
};
88
# CPU name passed to jax's build_wheel (--cpu) and used in the expected wheel
# file name (platformTag below). Both must agree.
# `linuxArch` is "arm64" on every aarch64 platform. That is the name the Darwin
# wheel tag uses, but Linux wheels use "aarch64" (manylinux2014_aarch64).
# NOTE(review): jax's build/build_wheel.py appears to key its platform table on
# ("Linux", "aarch64") and to reject ("Linux", "arm64"); confirm against the
# pinned jax sources when un-breaking aarch64-linux.
wheelCpu =
  if stdenv.targetPlatform.isLinux && stdenv.targetPlatform.linuxArch == "arm64"
  then "aarch64"
  else stdenv.targetPlatform.linuxArch;

# Bazel build of the jaxlib wheel. The output directory ($out) contains the .whl
# file that the buildPythonPackage expression consumes.
bazel-build = buildBazelPackage {
  name = "bazel-build-${pname}-${version}";

  bazel = bazel_5;

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    rev = "${pname}-v${version}";
    hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
  };

  nativeBuildInputs = [
    cython
    pkgs.flatbuffers
    git
    setuptools
    wheel
    which
  ] ++ lib.optionals stdenv.isDarwin [
    cctools
  ];

  buildInputs = [
    curl
    double-conversion
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    openssl
    pkgs.flatbuffers
    protobuf
    pybind11
    scipy
    six
    snappy
    zlib
  ] ++ lib.optionals cudaSupport [
    cudatoolkit
    cudnn
  ] ++ lib.optionals stdenv.isDarwin [
    IOKit
  ] ++ lib.optionals (!stdenv.isDarwin) [
    nsync
  ];

  # Remove upstream's .bazelversion pin, presumably so the `bazel` chosen above
  # is accepted.
  postPatch = ''
    rm -f .bazelversion
  '';

  bazelTarget = "//build:build_wheel";

  # Keep rules_cc: the Darwin branch of preBuild patches files under
  # ../output/external/rules_cc.
  removeRulesCC = false;

  # Host compiler for the CUDA toolchain. Both are empty strings when CUDA is off.
  GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
  GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

  # 1) Put a no-op `ldconfig` first on PATH.
  # 2) Write .jax_configure.bazelrc with the Python/protobuf settings and, for
  #    CUDA builds, the CUDA/cuDNN/NCCL locations and versions.
  preConfigure = ''
    # dummy ldconfig
    mkdir dummy-ldconfig
    echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
    chmod +x dummy-ldconfig/ldconfig
    export PATH="$PWD/dummy-ldconfig:$PATH"
    cat <<CFG > ./.jax_configure.bazelrc
    build --strategy=Genrule=standalone
    build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
    build --action_env=PYENV_ROOT
    build --python_path="${python}/bin/python"
    build --distinct_host_configuration=false
    build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
  '' + lib.optionalString cudaSupport ''
    build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
    build --action_env CUDNN_INSTALL_PATH="${cudnn}"
    build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
    build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
    build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
    build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${cudaFlags.cudaRealCapabilitiesCommaString}"
  '' + ''
    CFG
  '';

  # Copy-paste from TF derivation.
  # Most of these are not really used in jaxlib compilation but it's simpler to keep it
  # 'as is' so that it's more compatible with TF derivation.
  TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    "boringssl"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ] ++ lib.optionals (!stdenv.isDarwin) [
    "nsync" # fails to build on darwin
  ]);

  # Make sure Bazel knows about our configuration flags during fetching so that the
  # relevant dependencies can be downloaded.
  bazelFlags = [
    "-c opt"
  ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
    "--config=avx_posix"
  ] ++ lib.optionals cudaSupport [
    "--config=cuda"
  ] ++ lib.optionals mklSupport [
    "--config=mkl_open_source_only"
  ] ++ lib.optionals stdenv.cc.isClang [
    # bazel depends on the compiler frontend automatically selecting these flags based on file
    # extension but our clang doesn't.
    # https://github.com/NixOS/nixpkgs/issues/150655
    "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
  ];

  # Fixed-output hash of the fetched Bazel dependencies; it differs per variant.
  fetchAttrs = {
    sha256 =
      if cudaSupport then
        "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ="
      else if stdenv.isDarwin then
        "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM="
      else
        "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM=";
  };

  buildAttrs = {
    outputs = [ "out" ];

    # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
    # 1) Fix pybind11 include paths. The loop variable is `file`, not `src`, so
    #    the `$src` variable set from the derivation's `src` attribute is not overwritten.
    # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
    #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
    # 3) Patch python path in the compiler driver.
    preBuild = ''
      for file in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
        sed -i 's@include/pybind11@pybind11@g' "$file"
      done
    '' + lib.optionalString cudaSupport ''
      patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
    '' + lib.optionalString stdenv.isDarwin ''
      # Framework search paths aren't added by bintools hook
      # https://github.com/NixOS/nixpkgs/pull/41914
      export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
      substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
        --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
      substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
        --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
    '' + (if stdenv.cc.isGNU then ''
      sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
    '' else if stdenv.cc.isClang then ''
      sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
    '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

    # Write the wheel into $out. The CPU name must match the one in platformTag.
    installPhase = ''
      ./bazel-bin/build/build_wheel --output_path=$out --cpu=${wheelCpu}
    '';
  };

  inherit meta;
};

# Platform part of the wheel file name written by build_wheel. Uses the same
# CPU name that was passed to build_wheel via --cpu.
platformTag =
  if stdenv.targetPlatform.isLinux then
    "manylinux2014_${wheelCpu}"
  else if stdenv.system == "x86_64-darwin" then
    "macosx_10_9_${wheelCpu}"
  else if stdenv.system == "aarch64-darwin" then
    "macosx_11_0_${wheelCpu}"
  # Interpolate `.system` (a string). Interpolating the platform attribute set
  # itself fails with "cannot coerce a set to a string" and hides this message.
  else throw "Unsupported target platform: ${stdenv.targetPlatform.system}";
289
in
# Installs the wheel produced by `bazel-build` as a Python package.
buildPythonPackage {
  inherit pname version;
  inherit meta;

  format = "wheel";

  # e.g. jaxlib-<version>-cp310-cp310-<platformTag>.whl inside the Bazel output.
  src =
    let
      pyTag = "cp" + builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;
      wheelFile = "jaxlib-${version}-${pyTag}-${pyTag}-${platformTag}.whl";
    in
    "${bazel-build}/${wheelFile}";

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  nativeBuildInputs = lib.optionals cudaSupport [ addOpenGLRunpath ];

  # CUDA builds only:
  # - expose `ptxas` in $out/bin, because jaxlib looks for it in $PATH (see
  #   https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621);
  # - for every shared object, run addOpenGLRunpath and then prepend the CUDA
  #   toolkit, cuDNN and NCCL lib directories to its RPATH.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  # For CUDA builds, skip the patchelf fixup. Otherwise it strips the RPATH
  # entries added in postInstall, and libcudart.so.11.0 is then reported as missing.
  dontPatchELF = cudaSupport;

  pythonImportsCheck = [ "jaxlib" ];
}