{
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  # NOTE(review): addOpenGLRunpath appears unused in this file
  # (autoAddDriverRunpath is used instead below) — confirm before removing.
  addOpenGLRunpath,
  autoAddDriverRunpath,
  bazel_6,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  IOKit,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy,
  zlib,

  config,
  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackages,

  # MKL:
  mklSupport ? true,
}@inputs:

let
  inherit (cudaPackages)
    cudaFlags
    cudaVersion
    cudnn
    nccl
    ;

  pname = "jaxlib";
  version = "0.4.28";

  # It's necessary to consistently use backendStdenv when building with CUDA
  # support, otherwise we get libstdc++ errors downstream.
  # Shadow the `stdenv` argument so that any accidental use fails loudly at
  # evaluation time instead of silently picking the wrong toolchain.
  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;

  meta = with lib; {
    # Nixpkgs convention: the description must not restate the package/project
    # name and must not end with a period (enforced by nixpkgs lints).
    description = "XLA library for JAX";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    platforms = platforms.unix;
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
    # however even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    # NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
    broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
  };

  # These are necessary at build time and run time.
84 cuda_libs_joined = symlinkJoin { 85 name = "cuda-joined"; 86 paths = with cudaPackages; [ 87 cuda_cudart.lib # libcudart.so 88 cuda_cudart.static # libcudart_static.a 89 cuda_cupti.lib # libcupti.so 90 libcublas.lib # libcublas.so 91 libcufft.lib # libcufft.so 92 libcurand.lib # libcurand.so 93 libcusolver.lib # libcusolver.so 94 libcusparse.lib # libcusparse.so 95 ]; 96 }; 97 # These are only necessary at build time. 98 cuda_build_deps_joined = symlinkJoin { 99 name = "cuda-build-deps-joined"; 100 paths = with cudaPackages; [ 101 cuda_libs_joined 102 103 # Binaries 104 cudaPackages.cuda_nvcc.bin # nvcc 105 106 # Headers 107 cuda_cccl.dev # block_load.cuh 108 cuda_cudart.dev # cuda.h 109 cuda_cupti.dev # cupti.h 110 cuda_nvcc.dev # See https://github.com/google/jax/issues/19811 111 cuda_nvml_dev # nvml.h 112 cuda_nvtx.dev # nvToolsExt.h 113 libcublas.dev # cublas_api.h 114 libcufft.dev # cufft.h 115 libcurand.dev # curand.h 116 libcusolver.dev # cusolver_common.h 117 libcusparse.dev # cusparse.h 118 ]; 119 }; 120 121 backend_cc_joined = symlinkJoin { 122 name = "cuda-cc-joined"; 123 paths = [ 124 effectiveStdenv.cc 125 binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip 126 ]; 127 }; 128 129 # Copy-paste from TF derivation. 130 # Most of these are not really used in jaxlib compilation but it's simpler to keep it 131 # 'as is' so that it's more compatible with TF derivation. 
  # Bazel "system libs": dependencies bazel takes from nixpkgs instead of
  # fetching and building its vendored copies. Passed via TF_SYSTEM_LIBS below.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
    # "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];

  # CPU name passed to the wheel builder (--cpu flag) and used in platformTag.
  arch =
    # KeyError: ('Linux', 'arm64')
    # Upstream's build tooling doesn't know 'arm64' on Linux, so rename it to
    # the equivalent 'aarch64'.
    if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
      "aarch64"
    else
      effectiveStdenv.hostPlatform.linuxArch;

  # Pinned openxla/xla source tree (shebangs patched for the sandbox); used to
  # override bazel's @xla repository via --override_repository below.
  xla = effectiveStdenv.mkDerivation {
    pname = "xla-src";
    version = "unstable";

    src = fetchFromGitHub {
      owner = "openxla";
      repo = "xla";
      # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
    };

    dontBuild = true;

    # This is necessary for patchShebangs to know the right path to use.
    nativeBuildInputs = [ python ];

    # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
    postPatch = ''
      patchShebangs .
    '';

    # No build step: just expose the patched source tree as the output.
    installPhase = ''
      cp -r . $out
    '';
  };

  # The heavy lifting: run bazel to produce the jaxlib wheel, which the
  # buildPythonPackage below then installs.
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals effectiveStdenv.isDarwin [ cctools ];

    buildInputs =
      [
        curl
        double-conversion
        giflib
        jsoncpp
        libjpeg_turbo
        numpy
        openssl
        pkgs.flatbuffers
        pkgs.protobuf
        pybind11
        scipy
        six
        snappy
        zlib
      ]
      ++ lib.optionals effectiveStdenv.isDarwin [ IOKit ]
      ++ lib.optionals (!effectiveStdenv.isDarwin) [ nsync ];

    # We don't want to be quite so picky regarding bazel version
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [
      "--output_path=$out"
      "--cpu=${arch}"
      # This has no impact whatsoever...
      "--jaxlib_git_hash='12345678'"
    ];

    removeRulesCC = false;

    # Point bazel's GCC toolchain discovery at the joined cc+binutils tree
    # when building with CUDA; empty strings otherwise.
    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    preConfigure =
      # Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
      ''
        mkdir dummy-ldconfig
        echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
        chmod +x dummy-ldconfig/ldconfig
        export PATH="$PWD/dummy-ldconfig:$PATH"
      ''
      +

      # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
      # for more info. We assume
      # * `cpu = None`
      # * `enable_nccl = True`
      # * `target_cpu_features = "release"`
      # * `rocm_amdgpu_targets = None`
      # * `enable_rocm = False`
      # * `build_gpu_plugin = False`
      # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
      #
      # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
      # instead of duplicating the logic here. Perhaps we can leverage the
      # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
      ''
        cat <<CFG > ./.jax_configure.bazelrc
        build --strategy=Genrule=standalone
        build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
        build --action_env=PYENV_ROOT
        build --python_path="${python}/bin/python"
        build --distinct_host_configuration=false
        build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
      ''
      # NOTE(review): the TF_CUDA_COMPUTE_CAPABILITIES line below uses the
      # "build:cuda" scope while its siblings use plain "build"; it still takes
      # effect because --config=cuda is set above — confirm this is intentional.
      + lib.optionalString cudaSupport ''
        build --config=cuda
        build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
        build --action_env CUDNN_INSTALL_PATH="${cudnn}"
        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}"
        build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
        build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
      ''
      +
      # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
      # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
      # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
      # for upstream's version.
      lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
        ''
          build --config=avx_posix
        ''
      + lib.optionalString mklSupport ''
        build --config=mkl_open_source_only
      ''
      + ''
        CFG
      '';

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags =
      [
        "-c opt"
        # See https://bazel.build/external/advanced#overriding-repositories for
        # information on --override_repository flag.
        "--override_repository=xla=${xla}"
      ]
      ++ lib.optionals effectiveStdenv.cc.isClang [
        # bazel depends on the compiler frontend automatically selecting these flags based on file
        # extension but our clang doesn't.
        # https://github.com/NixOS/nixpkgs/issues/150655
        "--cxxopt=-x"
        "--cxxopt=c++"
        "--host_cxxopt=-x"
        "--host_cxxopt=c++"
      ];

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # we have to force @mkl_dnn_v1 since it's not needed on darwin
      bazelTargets = [
        bazelRunTarget
        "@mkl_dnn_v1//:mkl_dnn"
      ];
      bazelFlags =
        bazelFlags
        ++ [
          "--config=avx_posix"
          "--config=mkl_open_source_only"
        ]
        ++ lib.optionals cudaSupport [
          # ideally we'd add this unconditionally too, but it doesn't work on darwin
          # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
          # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
          # have access to darwin machines
          "--config=cuda"
        ];

      # Fixed-output hash of the bazel dependency fetch; it differs per system
      # and per cudaSupport, so update every entry when bumping version/xla rev.
      sha256 =
        (
          if cudaSupport then
            { x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; }
          else
            {
              x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
              aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
            }
        ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (
        tf_system_libs
        ++ lib.optionals (!effectiveStdenv.isDarwin) [
          "nsync" # fails to build on darwin
        ]
      );

      # Note: we cannot do most of this patching at `patch` phase as the deps
      # are not available yet. Framework search paths aren't added by bintools
      # hook. See https://github.com/NixOS/nixpkgs/pull/41914.
      # NOTE(review): `substituteInPlace --replace` is deprecated in nixpkgs in
      # favor of `--replace-fail`; consider switching on the next update.
      preBuild = lib.optionalString effectiveStdenv.isDarwin ''
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };

  # Wheel platform tag matching what //jaxlib/tools:build_wheel names its
  # output file; must agree with `arch` above.
  platformTag =
    if effectiveStdenv.hostPlatform.isLinux then
      "manylinux2014_${arch}"
    else if effectiveStdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if effectiveStdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    else
      throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
in
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  src =
    let
      # CPython ABI tag, e.g. "cp311" for Python 3.11.
      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
    in
    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      patchelf --add-rpath "${
        lib.makeLibraryPath [
          cuda_libs_joined
          cudnn
          nccl
        ]
      }" "$lib"
    done
  '';

  nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

  dependencies = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}