# jaxlib, built from source.
#
# The expression works in two stages:
#   1. `bazel-build` runs the `//build:build_wheel` Bazel target from the JAX
#      repository and then invokes the resulting `build_wheel` tool so that it
#      writes a wheel into `$out`.
#   2. `buildPythonPackage` installs that wheel (`format = "wheel"`).
#
# Arguments with defaults:
#   cudaCapabilities  list of CUDA compute capabilities passed to the build
#                     through TF_CUDA_COMPUTE_CAPABILITIES.
#   cudaSupport       enable the CUDA configuration (default: false).
#   cudaPackages      package set providing `cudatoolkit`, `cudnn` and `nccl`;
#                     only dereferenced when `cudaSupport` is true.
#   mklSupport        add `--config=mkl_open_source_only` (default: true).
{ lib
, pkgs
, stdenv

  # Build-time dependencies:
, addOpenGLRunpath
, bazel_5
, binutils
, buildBazelPackage
, buildPythonPackage
, cctools
, curl
, cython
, fetchFromGitHub
, git
, IOKit
, jsoncpp
, nsync
, openssl
, pybind11
, setuptools
, symlinkJoin
, wheel
, which

  # Python dependencies:
, absl-py
, flatbuffers
, numpy
, scipy
, six

  # Runtime dependencies:
, double-conversion
, giflib
, grpc
, libjpeg_turbo
, protobuf
, python
, snappy
, zlib

  # CUDA flags:
, cudaCapabilities ? [ "sm_35" "sm_50" "sm_60" "sm_70" "sm_75" "compute_80" ]
, cudaSupport ? false
, cudaPackages ? {}

  # MKL:
, mklSupport ? true
}:

let
  inherit (cudaPackages) cudatoolkit cudnn nccl;

  pname = "jaxlib";
  version = "0.3.22";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    platforms = platforms.unix;
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
    # however even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    broken = stdenv.isAarch64;
  };

  # Single prefix containing both the `lib` and `out` outputs of cudatoolkit,
  # used below as CUDA_TOOLKIT_PATH and in TF_CUDA_PATHS.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-merged";
    paths = [
      cudatoolkit.lib
      cudatoolkit.out
    ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
      # for some reason some of the required libs are in the targets/x86_64-linux
      # directory; not sure why but this works around it
      "${cudatoolkit}/targets/${stdenv.system}"
    ];
  };

  # Host compiler used for CUDA builds, merged with bintools so that the
  # binary utilities are found next to `gcc` (see GCC_HOST_COMPILER_PREFIX).
  cudatoolkit_cc_joined = symlinkJoin {
    name = "${cudatoolkit.cc.name}-merged";
    paths = [
      cudatoolkit.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };

  # Architecture name shared by the `--cpu` argument of `build_wheel` and by the
  # wheel's platform tag, so that the two can never disagree.
  #
  # `linuxArch` is "arm64" on aarch64, but on Linux the manylinux2014 tag (and
  # hence the wheel file name) uses "aarch64"; upstream's build_wheel is also
  # expected to reject ("Linux", "arm64") -- TODO confirm against the pinned
  # JAX revision. On Darwin "arm64" is the right spelling and is kept.
  # NOTE: `meta.broken` above still disables every aarch64 platform, so this
  # only takes effect once that marker is lifted or overridden.
  arch =
    if stdenv.targetPlatform.isLinux && stdenv.targetPlatform.linuxArch == "arm64" then
      "aarch64"
    else
      stdenv.targetPlatform.linuxArch;

  # Stage 1: build the wheel with Bazel.
  bazel-build = buildBazelPackage {
    name = "bazel-build-${pname}-${version}";

    bazel = bazel_5;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      rev = "${pname}-v${version}";
      hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      which
    ] ++ lib.optionals stdenv.isDarwin [
      cctools
    ];

    buildInputs = [
      curl
      double-conversion
      giflib
      grpc
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ] ++ lib.optionals cudaSupport [
      cudatoolkit
      cudnn
    ] ++ lib.optionals stdenv.isDarwin [
      IOKit
    ] ++ lib.optionals (!stdenv.isDarwin) [
      nsync
    ];

    # Drop the upstream Bazel version pin; the build uses `bazel_5` from above.
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelTarget = "//build:build_wheel";

    removeRulesCC = false;

    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

    # Puts a no-op `ldconfig` first on PATH and writes `.jax_configure.bazelrc`.
    # The bazelrc is emitted through a shell heredoc that is assembled from
    # three Nix strings; the closing `CFG` must therefore stay in the last one.
    preConfigure = ''
      # dummy ldconfig
      mkdir dummy-ldconfig
      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
      cat <<CFG > ./.jax_configure.bazelrc
      build --strategy=Genrule=standalone
      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
    '' + lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${lib.concatStringsSep "," cudaCapabilities}"
    '' + ''
      CFG
    '';

    # Copy-paste from TF derivation.
    # Most of these are not really used in jaxlib compilation but it's simpler to keep it
    # 'as is' so that it's more compatible with TF derivation.
    TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
      "absl_py"
      "astor_archive"
      "astunparse_archive"
      "boringssl"
      # Not packaged in nixpkgs
      # "com_github_googleapis_googleapis"
      # "com_github_googlecloudplatform_google_cloud_cpp"
      "com_github_grpc_grpc"
      "com_google_protobuf"
      # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
      # "com_googlesource_code_re2"
      "curl"
      "cython"
      "dill_archive"
      "double_conversion"
      "flatbuffers"
      "functools32_archive"
      "gast_archive"
      "gif"
      "hwloc"
      "icu"
      "jsoncpp_git"
      "libjpeg_turbo"
      "lmdb"
      "nasm"
      "opt_einsum_archive"
      "org_sqlite"
      "pasta"
      "png"
      "pybind11"
      "six_archive"
      "snappy"
      "tblib_archive"
      "termcolor_archive"
      "typing_extensions_archive"
      "wrapt"
      "zlib"
    ] ++ lib.optionals (!stdenv.isDarwin) [
      "nsync" # fails to build on darwin
    ]);

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
    ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
      "--config=avx_posix"
    ] ++ lib.optionals cudaSupport [
      "--config=cuda"
    ] ++ lib.optionals mklSupport [
      "--config=mkl_open_source_only"
    ] ++ lib.optionals stdenv.cc.isClang [
      # bazel depends on the compiler frontend automatically selecting these flags based on file
      # extension but our clang doesn't.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
    ];

    # Hash of the fetched Bazel dependencies; it differs per configuration
    # (CUDA, Darwin, everything else).
    fetchAttrs = {
      sha256 =
        if cudaSupport then
          "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ="
        else if stdenv.isDarwin then
          "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM="
        else
          "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM=";
    };

    buildAttrs = {
      outputs = [ "out" ];

      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Fix pybind11 include paths.
      # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
      # 3) Patch python path in the compiler driver.
      preBuild = ''
        for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
          sed -i 's@include/pybind11@pybind11@g' $src
        done
      '' + lib.optionalString cudaSupport ''
        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
        export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '' + (if stdenv.cc.isGNU then ''
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      '' else if stdenv.cc.isClang then ''
        sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
      '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

      # Run the tool Bazel just built; it writes the wheel into $out.
      installPhase = ''
        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
      '';
    };

    inherit meta;
  };

  # Platform part of the wheel file name that `src` below looks up in
  # `bazel-build`.
  platformTag =
    if stdenv.targetPlatform.isLinux then
      "manylinux2014_${arch}"
    else if stdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if stdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    # Interpolate the system string: the platform itself is an attribute set and
    # cannot be coerced to a string, which would mask this message.
    else throw "Unsupported target platform: ${stdenv.targetPlatform.system}";

in
# Stage 2: install the wheel produced by `bazel-build`.
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  src =
    # e.g. pythonVersion "3.10" -> "cp310"
    let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
    in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;

  propagatedBuildInputs = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}