Merge pull request #225661 from SomeoneSerge/jax-libstdcxx

python3Packages.jax: fix libstdc++ mismatch when built with CUDA

authored by

Samuel Ainsworth and committed by
GitHub
929a328d 0fac1123

+39 -16
+8 -2
pkgs/development/compilers/cudatoolkit/extension.nix
··· 18 18 # E.g. for cudaPackages_11_8 we use gcc11 with gcc12's libstdc++ 19 19 # Cf. https://github.com/NixOS/nixpkgs/pull/218265 for context 20 20 backendStdenv = final.callPackage ./stdenv.nix { 21 - nixpkgsStdenv = prev.pkgs.stdenv; 22 - nvccCompatibleStdenv = prev.pkgs.buildPackages."${finalVersion.gcc}Stdenv"; 21 + # We use buildPackages (= pkgsBuildHost) because we look for a gcc that 22 + # runs on our build platform, and that produces executables for the host 23 + # platform (= platform on which we deploy and run the downstream packages). 24 + # The target platform of buildPackages.gcc is our host platform, so its 25 + # .lib output should be the libstdc++ we want to be writing in the runpaths 26 + # Cf. https://github.com/NixOS/nixpkgs/pull/225661#discussion_r1164564576 27 + nixpkgsCompatibleLibstdcxx = final.pkgs.buildPackages.gcc.cc.lib; 28 + nvccCompatibleCC = final.pkgs.buildPackages."${finalVersion.gcc}".cc; 23 29 }; 24 30 25 31 ### Add classic cudatoolkit package
+28 -12
pkgs/development/compilers/cudatoolkit/stdenv.nix
··· 1 - { nixpkgsStdenv 2 - , nvccCompatibleStdenv 1 + { lib 2 + , nixpkgsCompatibleLibstdcxx 3 + , nvccCompatibleCC 3 4 , overrideCC 5 + , stdenv 4 6 , wrapCCWith 5 7 }: 6 8 7 - overrideCC nixpkgsStdenv (wrapCCWith { 8 - cc = nvccCompatibleStdenv.cc.cc; 9 + let 10 + cc = wrapCCWith 11 + { 12 + cc = nvccCompatibleCC; 9 13 10 - # This option is for clang's libcxx, but we (ab)use it for gcc's libstdc++. 11 - # Note that libstdc++ maintains forward-compatibility: if we load a newer 12 - # libstdc++ into the process, we can still use libraries built against an 13 - # older libstdc++. This, in practice, means that we should use libstdc++ from 14 - # the same stdenv that the rest of nixpkgs uses. 15 - # We currently do not try to support anything other than gcc and linux. 16 - libcxx = nixpkgsStdenv.cc.cc.lib; 17 - }) 14 + # This option is for clang's libcxx, but we (ab)use it for gcc's libstdc++. 15 + # Note that libstdc++ maintains forward-compatibility: if we load a newer 16 + # libstdc++ into the process, we can still use libraries built against an 17 + # older libstdc++. This, in practice, means that we should use libstdc++ from 18 + # the same stdenv that the rest of nixpkgs uses. 19 + # We currently do not try to support anything other than gcc and linux. 20 + libcxx = nixpkgsCompatibleLibstdcxx; 21 + }; 22 + cudaStdenv = overrideCC stdenv cc; 23 + passthruExtra = { 24 + inherit nixpkgsCompatibleLibstdcxx; 25 + # cc already exposed 26 + }; 27 + assertCondition = true; 28 + in 29 + lib.extendDerivation 30 + assertCondition 31 + passthruExtra 32 + cudaStdenv 33 +
+3 -2
pkgs/development/python-modules/jaxlib/default.nix
··· 49 49 }: 50 50 51 51 let 52 - inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl; 52 + inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; 53 53 54 54 pname = "jaxlib"; 55 55 version = "0.3.22"; ··· 81 81 cudatoolkit_cc_joined = symlinkJoin { 82 82 name = "${cudatoolkit.cc.name}-merged"; 83 83 paths = [ 84 - cudatoolkit.cc 84 + backendStdenv.cc 85 85 binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip 86 86 ]; 87 87 }; ··· 271 271 sed -i 's@include/pybind11@pybind11@g' $src 272 272 done 273 273 '' + lib.optionalString cudaSupport '' 274 + export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" 274 275 patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl 275 276 '' + lib.optionalString stdenv.isDarwin '' 276 277 # Framework search paths aren't added by bintools hook