nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
{
  # Nixpkgs plumbing.
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  addOpenGLRunpath,
  bazel_5,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cython,
  fetchFromGitHub,
  git,
  jsoncpp,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  which,

  # Python dependencies:
  absl-py,
  # Python flatbuffers package; the native library is taken from pkgs.flatbuffers.
  flatbuffers,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  grpc,
  libjpeg_turbo,
  python,
  snappy,
  zlib,

  # CUDA flags:
  # GPU targets, joined with "," into TF_CUDA_COMPUTE_CAPABILITIES for `--config=cuda`.
  cudaCapabilities ? [ "sm_35" "sm_50" "sm_60" "sm_70" "sm_75" "compute_80" ],
  # Build the CUDA-enabled jaxlib.
  cudaSupport ? false,
  # Source of cudatoolkit, cudnn and nccl; only dereferenced behind cudaSupport guards.
  cudaPackages ? {},

  # MKL:
  # Adds `--config=mkl_open_source_only` to the Bazel build flags.
  mklSupport ? true,
}:
45
let
  # Package identity shared by the Bazel build and the final Python package.
  version = "0.3.0";
  pname = "jaxlib";

  # CUDA components. Nix evaluates these lazily, and every use in this file sits
  # behind a `cudaSupport` guard, so the `{}` default for cudaPackages is harmless
  # when CUDA is disabled.
  cudatoolkit = cudaPackages.cudatoolkit;
  cudnn = cudaPackages.cudnn;
  nccl = cudaPackages.nccl;
52
53 meta = with lib; {
54 description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
55 homepage = "https://github.com/google/jax";
56 license = licenses.asl20;
57 maintainers = with maintainers; [ ndl ];
58 platforms = [ "x86_64-linux" "aarch64-darwin" "x86_64-darwin"];
59 hydraPlatforms = ["x86_64-linux" ]; # Don't think anybody is checking the darwin builds
60 };
61
62 cudatoolkit_joined = symlinkJoin {
63 name = "${cudatoolkit.name}-merged";
64 paths = [
65 cudatoolkit.lib
66 cudatoolkit.out
67 ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
68 # for some reason some of the required libs are in the targets/x86_64-linux
69 # directory; not sure why but this works around it
70 "${cudatoolkit}/targets/${stdenv.system}"
71 ];
72 };
73
74 cudatoolkit_cc_joined = symlinkJoin {
75 name = "${cudatoolkit.cc.name}-merged";
76 paths = [
77 cudatoolkit.cc
78 binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
79 ];
80 };
81
82 bazel-build = buildBazelPackage {
83 name = "bazel-build-${pname}-${version}";
84
85 bazel = bazel_5;
86
87 src = fetchFromGitHub {
88 owner = "google";
89 repo = "jax";
90 rev = "${pname}-v${version}";
91 sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
92 };
93
94 nativeBuildInputs = [
95 cython
96 pkgs.flatbuffers
97 git
98 setuptools
99 wheel
100 which
101 ];
102
103 buildInputs = [
104 double-conversion
105 giflib
106 grpc
107 jsoncpp
108 libjpeg_turbo
109 numpy
110 pkgs.flatbuffers
111 pkgs.protobuf
112 pybind11
113 scipy
114 six
115 snappy
116 zlib
117 ] ++ lib.optionals cudaSupport [
118 cudatoolkit
119 cudnn
120 ];
121
122 postPatch = ''
123 rm -f .bazelversion
124 '';
125
126 bazelTarget = "//build:build_wheel";
127
128 removeRulesCC = false;
129
130 GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
131 GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";
132
133 preConfigure = ''
134 # dummy ldconfig
135 mkdir dummy-ldconfig
136 echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
137 chmod +x dummy-ldconfig/ldconfig
138 export PATH="$PWD/dummy-ldconfig:$PATH"
139 cat <<CFG > ./.jax_configure.bazelrc
140 build --strategy=Genrule=standalone
141 build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
142 build --action_env=PYENV_ROOT
143 build --python_path="${python}/bin/python"
144 build --distinct_host_configuration=false
145 '' + lib.optionalString cudaSupport ''
146 build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
147 build --action_env CUDNN_INSTALL_PATH="${cudnn}"
148 build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
149 build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
150 build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
151 build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${lib.concatStringsSep "," cudaCapabilities}"
152 '' + ''
153 CFG
154 '';
155
156 # Copy-paste from TF derivation.
157 # Most of these are not really used in jaxlib compilation but it's simpler to keep it
158 # 'as is' so that it's more compatible with TF derivation.
159 TF_SYSTEM_LIBS = lib.concatStringsSep "," [
160 "absl_py"
161 "astor_archive"
162 "astunparse_archive"
163 "boringssl"
164 # Not packaged in nixpkgs
165 # "com_github_googleapis_googleapis"
166 # "com_github_googlecloudplatform_google_cloud_cpp"
167 "com_github_grpc_grpc"
168 "com_google_protobuf"
169 # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
170 # "com_googlesource_code_re2"
171 "curl"
172 "cython"
173 "dill_archive"
174 "double_conversion"
175 "enum34_archive"
176 "flatbuffers"
177 "functools32_archive"
178 "gast_archive"
179 "gif"
180 "hwloc"
181 "icu"
182 "jsoncpp_git"
183 "libjpeg_turbo"
184 "lmdb"
185 "nasm"
186 # "nsync" # not packaged in nixpkgs
187 "opt_einsum_archive"
188 "org_sqlite"
189 "pasta"
190 "pcre"
191 "png"
192 "pybind11"
193 "six_archive"
194 "snappy"
195 "tblib_archive"
196 "termcolor_archive"
197 "typing_extensions_archive"
198 "wrapt"
199 "zlib"
200 ];
201
202 # Make sure Bazel knows about our configuration flags during fetching so that the
203 # relevant dependencies can be downloaded.
204 bazelFetchFlags = bazel-build.bazelBuildFlags;
205
206 bazelBuildFlags = [
207 "-c opt"
208 ] ++ lib.optional (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
209 "--config=avx_posix"
210 ] ++ lib.optional cudaSupport [
211 "--config=cuda"
212 ] ++ lib.optional mklSupport [
213 "--config=mkl_open_source_only"
214 ];
215
216 fetchAttrs = {
217 sha256 =
218 if cudaSupport then
219 "0d2rqwk9n4a6c51m4g21rxymv85kw2sdksni30cdx3pdcdbqgic7"
220 else
221 "0q540mwmh7grig0qq48ynzqi0gynimxnrq7k97wribqpkx99k39d";
222 };
223
224 buildAttrs = {
225 outputs = [ "out" ];
226
227 # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
228 # 1) Fix pybind11 include paths.
229 # 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
230 # in the same python program due to duplicate protobuf DBs.
231 # 3) Patch python path in the compiler driver.
232 # 4) Patch tensorflow sources to work with later versions of protobuf. See
233 # https://github.com/google/jax/issues/9534. Note that this should be
234 # removed on the next release after 0.3.0.
235 preBuild = ''
236 for src in ./jaxlib/*.{cc,h}; do
237 sed -i 's@include/pybind11@pybind11@g' $src
238 done
239 sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
240 sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
241 substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
242 --replace "status.message()" "std::string{status.message()}"
243 '' + lib.optionalString cudaSupport ''
244 patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
245 '';
246
247 installPhase = ''
248 ./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
249 '';
250 };
251
252 inherit meta;
253 };
254
in
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  # The wheel written by bazel-build's installPhase. Its file name encodes the CPython
  # tag and the platform tag, so both are reconstructed here. The Linux name is unchanged.
  # The darwin systems listed in meta.platforms previously pointed at a
  # "manylinux2010" file that is never produced there.
  src =
    let
      pythonTag = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
      arch = stdenv.targetPlatform.linuxArch;
      platformTag =
        if stdenv.targetPlatform.isLinux then
          "manylinux2010_${arch}"
        # NOTE(review): the macOS tags below are assumed to match what jaxlib 0.3.0's
        # build_wheel emits (macosx_10_9 on x86_64, macosx_11_0 on arm64) — confirm
        # against a darwin build.
        else if stdenv.system == "x86_64-darwin" then
          "macosx_10_9_${arch}"
        else if stdenv.system == "aarch64-darwin" then
          "macosx_11_0_${arch}"
        else
          throw "jaxlib: unsupported target platform ${stdenv.system}";
    in
    "${bazel-build}/jaxlib-${version}-${pythonTag}-none-${platformTag}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;

  propagatedBuildInputs = [
    absl-py
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}