Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# Source-built jaxlib.
#
# The expression works in two stages:
#   1. `bazel-build` runs the upstream `//jaxlib/tools:build_wheel` Bazel target
#      and writes the resulting wheel into its output.
#   2. `buildPythonPackage` installs that wheel (format = "wheel") and, when
#      CUDA is enabled, exposes `ptxas` and adds the CUDA libraries to the
#      RPATH of the shared objects.
{
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  addDriverRunpath,
  autoAddDriverRunpath,
  bazel_6,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  IOKit,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy,
  zlib,

  config,
  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackages,

  # MKL:
  mklSupport ? true,
}@inputs:

let
  inherit (cudaPackages)
    cudaFlags
    cudaVersion
    nccl
    ;

  pname = "jaxlib";
  version = "0.4.28";

  # It's necessary to consistently use backendStdenv when building with CUDA
  # support, otherwise we get libstdc++ errors downstream
  # Shadowing `stdenv` with a throw makes any accidental use fail loudly; the
  # real stdenv argument is still reachable as `inputs.stdenv`.
  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;

  # Shared by the Bazel build derivation and the final Python package.
  meta = {
    description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];

    # Make this platforms.unix once Darwin is supported.
    # The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
    # however even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    platforms = lib.platforms.linux;
  };

  # Bazel wants a merged cudnn at configuration time
  cudnnMerged = symlinkJoin {
    name = "cudnn-merged";
    paths = with cudaPackages; [
      (lib.getDev cudnn)
      (lib.getLib cudnn)
    ];
  };

  # These are necessary at build time and run time.
  cuda_libs_joined = symlinkJoin {
    name = "cuda-joined";
    paths = with cudaPackages; [
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib libcublas) # libcublas.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcurand) # libcurand.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
    ];
  };
  # These are only necessary at build time.
  cuda_build_deps_joined = symlinkJoin {
    name = "cuda-build-deps-joined";
    paths = with cudaPackages; [
      cuda_libs_joined

      # Binaries
      (lib.getBin cuda_nvcc) # nvcc

      # Archives
      (lib.getOutput "static" cuda_cudart) # libcudart_static.a

      # Headers
      (lib.getDev cuda_cccl) # block_load.cuh
      (lib.getDev cuda_cudart) # cuda.h
      (lib.getDev cuda_cupti) # cupti.h
      (lib.getDev cuda_nvcc) # See https://github.com/google/jax/issues/19811
      (lib.getDev cuda_nvml_dev) # nvml.h
      (lib.getDev cuda_nvtx) # nvToolsExt.h
      (lib.getDev libcublas) # cublas_api.h
      (lib.getDev libcufft) # cufft.h
      (lib.getDev libcurand) # curand.h
      (lib.getDev libcusolver) # cusolver_common.h
      (lib.getDev libcusparse) # cusparse.h
    ];
  };

  # Compiler and binutils merged into one prefix; referenced below by
  # GCC_HOST_COMPILER_PREFIX / GCC_HOST_COMPILER_PATH when CUDA is enabled.
  backend_cc_joined = symlinkJoin {
    name = "cuda-cc-joined";
    paths = [
      effectiveStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };

  # Copy-paste from TF derivation.
  # Most of these are not really used in jaxlib compilation but it's simpler to keep it
  # 'as is' so that it's more compatible with TF derivation.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
    # "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];

  # CPU name passed to the wheel builder (`--cpu`) and used in the wheel's
  # platform tag.
  arch =
    # KeyError: ('Linux', 'arm64')
    if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
      "aarch64"
    else
      effectiveStdenv.hostPlatform.linuxArch;

  # XLA sources with shebangs patched; handed to Bazel through
  # `--override_repository=xla=...` below.
  xla = effectiveStdenv.mkDerivation {
    pname = "xla-src";
    version = "unstable";

    src = fetchFromGitHub {
      owner = "openxla";
      repo = "xla";
      # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
    };

    dontBuild = true;

    # This is necessary for patchShebangs to know the right path to use.
    nativeBuildInputs = [ python ];

    # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
    postPatch = ''
      patchShebangs .
    '';

    installPhase = ''
      cp -r . $out
    '';
  };

  # Runs the upstream wheel-building Bazel target; the wheel is written to $out.
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals effectiveStdenv.isDarwin [ cctools ];

    buildInputs =
      [
        curl
        double-conversion
        giflib
        jsoncpp
        libjpeg_turbo
        numpy
        openssl
        pkgs.flatbuffers
        pkgs.protobuf
        pybind11
        scipy
        six
        snappy
        zlib
      ]
      ++ lib.optionals effectiveStdenv.isDarwin [ IOKit ]
      ++ lib.optionals (!effectiveStdenv.isDarwin) [ nsync ];

    # We don't want to be quite so picky regarding bazel version
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [
      "--output_path=$out"
      "--cpu=${arch}"
      # This has no impact whatsoever...
      "--jaxlib_git_hash='12345678'"
    ];

    removeRulesCC = false;

    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    # The script below is assembled from several string fragments: the
    # `cat <<CFG` heredoc is opened in the second fragment, the optional
    # fragments append lines to it, and the last fragment closes it.
    preConfigure =
      # Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
      ''
        mkdir dummy-ldconfig
        echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
        chmod +x dummy-ldconfig/ldconfig
        export PATH="$PWD/dummy-ldconfig:$PATH"
      ''
      +

        # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
        # for more info. We assume
        # * `cpu = None`
        # * `enable_nccl = True`
        # * `target_cpu_features = "release"`
        # * `rocm_amdgpu_targets = None`
        # * `enable_rocm = False`
        # * `build_gpu_plugin = False`
        # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
        #
        # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
        # instead of duplicating the logic here. Perhaps we can leverage the
        # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
        ''
          cat <<CFG > ./.jax_configure.bazelrc
          build --strategy=Genrule=standalone
          build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
          build --action_env=PYENV_ROOT
          build --python_path="${python}/bin/python"
          build --distinct_host_configuration=false
          build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
        ''
      + lib.optionalString cudaSupport ''
        build --config=cuda
        build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
        build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}"
        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}"
        build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}"
        build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
      ''
      +
        # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
        # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
        # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
        # for upstream's version.
        lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
          ''
            build --config=avx_posix
          ''
      + lib.optionalString mklSupport ''
        build --config=mkl_open_source_only
      ''
      + ''
        CFG
      '';

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags =
      [
        "-c opt"
        # See https://bazel.build/external/advanced#overriding-repositories for
        # information on --override_repository flag.
        "--override_repository=xla=${xla}"
      ]
      ++ lib.optionals effectiveStdenv.cc.isClang [
        # bazel depends on the compiler frontend automatically selecting these flags based on file
        # extension but our clang doesn't.
        # https://github.com/NixOS/nixpkgs/issues/150655
        "--cxxopt=-x"
        "--cxxopt=c++"
        "--host_cxxopt=-x"
        "--host_cxxopt=c++"
      ];

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # we have to force @mkl_dnn_v1 since it's not needed on darwin
      bazelTargets = [
        bazelRunTarget
        "@mkl_dnn_v1//:mkl_dnn"
      ];
      bazelFlags =
        bazelFlags
        ++ [
          "--config=avx_posix"
          "--config=mkl_open_source_only"
        ]
        ++ lib.optionals cudaSupport [
          # ideally we'd add this unconditionally too, but it doesn't work on darwin
          # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
          # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
          # have access to darwin machines
          "--config=cuda"
        ];

      # Hash of the fetched dependencies, keyed by system; systems without an
      # entry throw.
      sha256 =
        (
          if cudaSupport then
            { x86_64-linux = "sha256-Uf0VMRE0jgaWEYiuphWkWloZ5jMeqaWBl3lSvk2y1HI="; }
          else
            {
              x86_64-linux = "sha256-NzJJg6NlrPGMiR8Fn8u4+fu0m+AulfmN5Xqk63Um6sw=";
              aarch64-linux = "sha256-Ro3qzrUxSR+3TH6ROoJTq+dLSufrDN/9oEo2MRkx7wM=";
            }
        ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");

      # Non-reproducible fetch https://github.com/NixOS/nixpkgs/issues/321920#issuecomment-2184940546
      preInstall = ''
        cat << \EOF > "$bazelOut/external/go_sdk/versions.json"
        []
        EOF
      '';
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (
        tf_system_libs
        ++ lib.optionals (!effectiveStdenv.isDarwin) [
          "nsync" # fails to build on darwin
        ]
      );

      # Note: we cannot do most of this patching at `patch` phase as the deps
      # are not available yet. Framework search paths aren't added by bintools
      # hook. See https://github.com/NixOS/nixpkgs/pull/41914.
      preBuild = lib.optionalString effectiveStdenv.isDarwin ''
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };

  # Platform component of the wheel file name that `src` below points at.
  platformTag =
    if effectiveStdenv.hostPlatform.isLinux then
      "manylinux2014_${arch}"
    else if effectiveStdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if effectiveStdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    else
      # `hostPlatform` is an attribute set and cannot be interpolated into a
      # string; use its `system` string so this message is actually shown.
      throw "Unsupported target platform: ${effectiveStdenv.hostPlatform.system}";
in
buildPythonPackage {
  inherit pname version;
  format = "wheel";

  # Path of the wheel inside the `bazel-build` output; `cp` is "cp" followed by
  # the Python version with the dots removed.
  src =
    let
      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
    in
    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      patchelf --add-rpath "${
        lib.makeLibraryPath [
          cuda_libs_joined
          (lib.getLib cudaPackages.cudnn)
          nccl
        ]
      }" "$lib"
    done
  '';

  nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

  dependencies = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;

  passthru = {
    # Note "bazel.*.tar.gz" can be accessed as `jaxlib.bazel-build.deps`
    inherit bazel-build;
  };

  inherit meta;
}