{ lib
, pkgs
, stdenv

  # Build-time dependencies:
, addOpenGLRunpath
, bazel_6
, binutils
, buildBazelPackage
, buildPythonPackage
, cctools
, curl
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, nsync
, openssl
, pybind11
, setuptools
, symlinkJoin
, wheel
, build
, which

  # Python dependencies:
, absl-py
, flatbuffers
, ml-dtypes
, numpy
, scipy
, six

  # Runtime dependencies:
, double-conversion
, giflib
, grpc
, libjpeg_turbo
, python
, snappy
, zlib

, config
  # CUDA flags:
, cudaSupport ? config.cudaSupport
, cudaPackages ? {}

  # MKL:
, mklSupport ? true
}:

let
  # Nix evaluates these lazily: with the `{}` default for cudaPackages a lookup
  # only fails if the name is actually used, so every use must stay behind
  # `cudaSupport`.
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  pname = "jaxlib";
  version = "0.4.20";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    platforms = platforms.unix;
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
    # however even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    broken = stdenv.isDarwin;
  };

  # The CUDA toolkit's `lib` and `out` outputs merged into one tree, so a single
  # path can be handed to the Bazel CUDA configuration as the toolkit location.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-merged";
    paths = [
      cudatoolkit.lib
      cudatoolkit.out
    ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
      # For CUDA < 11 some of the required libs are in the targets/<system>
      # directory (e.g. targets/x86_64-linux); not sure why, but this works
      # around it.
      "${cudatoolkit}/targets/${stdenv.system}"
    ];
  };

  # Host C/C++ toolchain for CUDA builds: the CUDA backend stdenv's compiler
  # plus binutils. The derivation is named after the compiler it actually
  # contains (backendStdenv.cc) instead of depending on cudatoolkit exposing a
  # matching `cc` attribute.
  cudatoolkit_cc_joined = symlinkJoin {
    name = "${backendStdenv.cc.name}-merged";
    paths = [
      backendStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };

  # Copy-paste from TF derivation.
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
  # 'as is' so that it's more compatible with TF derivation.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];

  # CPU name passed to build_wheel (--cpu) and used in the wheel platform tag.
  arch =
    # KeyError: ('Linux', 'arm64')
    if stdenv.hostPlatform.isLinux && stdenv.hostPlatform.linuxArch == "arm64" then "aarch64"
    else stdenv.hostPlatform.linuxArch;

  # Builds the jaxlib wheel with Bazel; `bazel run` of build_wheel writes the
  # .whl into $out, which the Python package below installs.
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-WLYXUtchOaA6SGnKuVhN9CmV06xMCLQTEuEtL13ttZU=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals stdenv.isDarwin [
      cctools
    ];

    buildInputs = [
      curl
      double-conversion
      giflib
      grpc
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      pkgs.protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ] ++ lib.optionals cudaSupport [
      cudatoolkit
      cudnn
    ] ++ lib.optionals stdenv.isDarwin [
      IOKit
    ] ++ lib.optionals (!stdenv.isDarwin) [
      nsync
    ];

    # Drop upstream's pinned Bazel version so the `bazel` chosen above is used
    # (presumably to avoid a version-mismatch error; confirm if bumping bazel).
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];

    removeRulesCC = false;

    # Host compiler for the CUDA crosstool; empty strings when CUDA is disabled.
    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    # Writes ./.jax_configure.bazelrc with our build settings. With cudaSupport
    # we must enable `--config=cuda` here: Bazel only applies `build:cuda` lines
    # (including TF_CUDA_COMPUTE_CAPABILITIES below) when that config is active,
    # and the top-level bazelFlags used for the build do not pass it. fetchAttrs
    # already fetches with --config=cuda, so the fetch output is unchanged.
    preConfigure = ''
      # dummy ldconfig
      mkdir dummy-ldconfig
      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
      cat <<CFG > ./.jax_configure.bazelrc
      build --strategy=Genrule=standalone
      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
    '' + lib.optionalString (stdenv.hostPlatform.avxSupport && stdenv.hostPlatform.isUnix) ''
      build --config=avx_posix
    '' + lib.optionalString mklSupport ''
      build --config=mkl_open_source_only
    '' + lib.optionalString cudaSupport ''
      build --config=cuda
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
    '' + ''
      CFG
    '';

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
    ] ++ lib.optionals stdenv.cc.isClang [
      # bazel depends on the compiler frontend automatically selecting these flags based on file
      # extension but our clang doesn't.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
    ];

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # Explicitly request @mkl_dnn_v1: it is not needed on darwin, so without
      # this it could be missing from the shared (overfetched) dependency set.
      bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
      bazelFlags = bazelFlags ++ [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        # ideally we'd add this unconditionally too, but it doesn't work on darwin
        # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
        # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
        # have access to darwin machines
        "--config=cuda"
      ] ++ [
        "--config=mkl_open_source_only"
      ];

      sha256 = (if cudaSupport then {
        x86_64-linux = "sha256-QczClHxHElLZCqIZlHc3z3DXJ7rZQJaMs2XIb+lxarI=";
      } else {
        x86_64-linux = "sha256-mqiJe4u0NYh1PKCbQfbo0U2e9/kYiBqj98d+BPHFSxQ=";
        aarch64-linux = "sha256-EuLqamVBJ+qoVMCFIYUT846AghltZolfLGdtO9UeXSM=";
      }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!stdenv.isDarwin) [
        "nsync" # fails to build on darwin
      ]);

      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
      # 2) Patch python path in the compiler driver.
      preBuild = lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };

  # Platform tag in the wheel filename produced by build_wheel; it must match
  # the file name used as `src` below.
  platformTag =
    if stdenv.hostPlatform.isLinux then
      "manylinux2014_${arch}"
    else if stdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if stdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    # hostPlatform is an attribute set and cannot be interpolated into a
    # string; use its `system` name so this error message can be produced.
    else throw "Unsupported target platform: ${stdenv.hostPlatform.system}";

in
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  # The wheel built by bazel-build, e.g. jaxlib-<version>-cp311-cp311-<platformTag>.whl.
  src =
    let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
    in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  # Each installed shared object also gets the CUDA/cuDNN/NCCL lib dirs
  # prepended to its RPATH. `IFS= read -r` keeps paths byte-exact.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while IFS= read -r lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}