nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at 22.05 296 lines 8.6 kB view raw
# jaxlib, built from the google/jax sources (tag "jaxlib-v${version}").
#
# The build happens in two stages:
#   1. `bazel-build` runs Bazel (via buildBazelPackage) on target
#      //build:build_wheel and writes the resulting wheel into its $out.
#   2. The final derivation installs that wheel with buildPythonPackage
#      (format = "wheel") and, when cudaSupport is enabled, adds the CUDA
#      library directories to the RPATH of every installed shared object.
{ lib
, pkgs
, stdenv

  # Build-time dependencies:
, addOpenGLRunpath
, bazel_5
, binutils
, buildBazelPackage
, buildPythonPackage
, cython
, fetchFromGitHub
, git
, jsoncpp
, pybind11
, setuptools
, symlinkJoin
, wheel
, which

  # Python dependencies:
, absl-py
, flatbuffers
, numpy
, scipy
, six

  # Runtime dependencies:
, double-conversion
, giflib
, grpc
, libjpeg_turbo
, python
, snappy
, zlib

  # CUDA flags:
, cudaCapabilities ? [ "sm_35" "sm_50" "sm_60" "sm_70" "sm_75" "compute_80" ]
, cudaSupport ? false
, cudaPackages ? {}

  # MKL:
, mklSupport ? true
}:

let

  # Only forced when cudaSupport is true; with the default `cudaPackages = {}`
  # these names are never evaluated in a CPU-only build.
  inherit (cudaPackages) cudatoolkit cudnn nccl;

  pname = "jaxlib";
  version = "0.3.0";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    platforms = [ "x86_64-linux" "aarch64-darwin" "x86_64-darwin"];
    hydraPlatforms = ["x86_64-linux" ]; # Don't think anybody is checking the darwin builds
  };

  # Single directory tree exposing both the `lib` and `out` outputs of the CUDA
  # toolkit; it is passed to Bazel as CUDA_TOOLKIT_PATH / TF_CUDA_PATHS below.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-merged";
    paths = [
      cudatoolkit.lib
      cudatoolkit.out
    ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
      # For CUDA toolkits older than 11, some of the required libraries live
      # under targets/<system> (e.g. targets/x86_64-linux). It is unclear why,
      # but linking that directory in as well works around it.
      "${cudatoolkit}/targets/${stdenv.system}"
    ];
  };

  # Host compiler matching the CUDA toolkit, plus binutils, used as the GCC
  # host compiler for CUDA builds (see GCC_HOST_COMPILER_* below).
  cudatoolkit_cc_joined = symlinkJoin {
    name = "${cudatoolkit.cc.name}-merged";
    paths = [
      cudatoolkit.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };

  # Stage 1: run Bazel and produce the jaxlib wheel in $out.
  bazel-build = buildBazelPackage {
    name = "bazel-build-${pname}-${version}";

    bazel = bazel_5;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      rev = "${pname}-v${version}";
      sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      which
    ];

    buildInputs = [
      double-conversion
      giflib
      grpc
      jsoncpp
      libjpeg_turbo
      numpy
      pkgs.flatbuffers
      pkgs.protobuf
      pybind11
      scipy
      six
      snappy
      zlib
    ] ++ lib.optionals cudaSupport [
      cudatoolkit
      cudnn
    ];

    # Upstream pins a Bazel version in .bazelversion; drop it so the bazel_5
    # provided above is accepted.
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelTarget = "//build:build_wheel";

    removeRulesCC = false;

    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

    # Puts a no-op `ldconfig` first on PATH, then writes .jax_configure.bazelrc.
    # The heredoc is assembled from three indented strings; Nix strips each
    # string's common indentation, so the closing `CFG` lands in column 0 as
    # an unquoted `<<CFG` heredoc requires.
    preConfigure = ''
      # dummy ldconfig
      mkdir dummy-ldconfig
      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
      chmod +x dummy-ldconfig/ldconfig
      export PATH="$PWD/dummy-ldconfig:$PATH"
      cat <<CFG > ./.jax_configure.bazelrc
      build --strategy=Genrule=standalone
      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
    '' + lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${lib.concatStringsSep "," cudaCapabilities}"
    '' + ''
      CFG
    '';

    # Copy-paste from TF derivation.
    # Most of these are not really used in jaxlib compilation but it's simpler to keep it
    # 'as is' so that it's more compatible with TF derivation.
    TF_SYSTEM_LIBS = lib.concatStringsSep "," [
      "absl_py"
      "astor_archive"
      "astunparse_archive"
      "boringssl"
      # Not packaged in nixpkgs
      # "com_github_googleapis_googleapis"
      # "com_github_googlecloudplatform_google_cloud_cpp"
      "com_github_grpc_grpc"
      "com_google_protobuf"
      # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
      # "com_googlesource_code_re2"
      "curl"
      "cython"
      "dill_archive"
      "double_conversion"
      "enum34_archive"
      "flatbuffers"
      "functools32_archive"
      "gast_archive"
      "gif"
      "hwloc"
      "icu"
      "jsoncpp_git"
      "libjpeg_turbo"
      "lmdb"
      "nasm"
      # "nsync" # not packaged in nixpkgs
      "opt_einsum_archive"
      "org_sqlite"
      "pasta"
      "pcre"
      "png"
      "pybind11"
      "six_archive"
      "snappy"
      "tblib_archive"
      "termcolor_archive"
      "typing_extensions_archive"
      "wrapt"
      "zlib"
    ];

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFetchFlags = bazel-build.bazelBuildFlags;

    # Each conditional contributes a *list* of flags, so `lib.optionals` (not
    # `lib.optional`) is required. `lib.optional` would wrap the list again
    # and yield a nested list such as [ [ "--config=cuda" ] ]. That only
    # happens to work when the attribute is flattened by derivation-env
    # coercion, and breaks with concatStringsSep/escapeShellArgs.
    # "-c opt" stays a single element and is assumed to be word-split by the
    # shell when the flags are expanded. TODO confirm against the
    # buildBazelPackage in use.
    bazelBuildFlags = [
      "-c opt"
    ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
      "--config=avx_posix"
    ] ++ lib.optionals cudaSupport [
      "--config=cuda"
    ] ++ lib.optionals mklSupport [
      "--config=mkl_open_source_only"
    ];

    # Fixed-output hash of the Bazel dependency fetch. It differs with
    # cudaSupport because the CUDA config pulls in additional dependencies.
    fetchAttrs = {
      sha256 =
        if cudaSupport then
          "0d2rqwk9n4a6c51m4g21rxymv85kw2sdksni30cdx3pdcdbqgic7"
        else
          "0q540mwmh7grig0qq48ynzqi0gynimxnrq7k97wribqpkx99k39d";
    };

    buildAttrs = {
      outputs = [ "out" ];

      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Fix pybind11 include paths.
      # 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
      #    in the same python program due to duplicate protobuf DBs.
      # 3) Patch python path in the compiler driver.
      # 4) Patch tensorflow sources to work with later versions of protobuf. See
      #    https://github.com/google/jax/issues/9534. Note that this should be
      #    removed on the next release after 0.3.0.
      preBuild = ''
        for src in ./jaxlib/*.{cc,h}; do
          sed -i 's@include/pybind11@pybind11@g' $src
        done
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
          --replace "status.message()" "std::string{status.message()}"
      '' + lib.optionalString cudaSupport ''
        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '';

      # Writes the built wheel directly into $out; the final derivation below
      # picks it up by file name.
      installPhase = ''
        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
      '';
    };

    inherit meta;
  };

in
# Stage 2: install the wheel produced by `bazel-build`.
buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  # NOTE(review): the file name hard-codes the `manylinux2010` platform tag,
  # yet meta.platforms also lists darwin systems. Confirm what build_wheel
  # actually names the wheel on darwin.
  src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
  # more info.
  # Every installed *.so / *.so.* gets addOpenGLRunpath, then the cudatoolkit,
  # cudnn and nccl lib directories are prepended to its existing RPATH.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;

  propagatedBuildInputs = [
    absl-py
    double-conversion
    flatbuffers
    giflib
    grpc
    jsoncpp
    libjpeg_turbo
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
  dontPatchELF = cudaSupport;
}