···1919# https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
2020, config
2121, cudaSupport ? config.cudaSupport
2222-, cudaPackages ? { }
2323-, cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities
2222+, cudaPackagesGoogle
2323+, cudaCapabilities ? cudaPackagesGoogle.cudaFlags.cudaCapabilities
2424, mklSupport ? false, mkl
2525, tensorboardSupport ? true
2626# XLA without CUDA is broken
···5050 # __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the
5151 # translation units, so the build fails at link time
5252 stdenv =
5353- if cudaSupport then cudaPackages.backendStdenv
5353+ if cudaSupport then cudaPackagesGoogle.backendStdenv
5454 else if originalStdenv.isDarwin then llvmPackages_11.stdenv
5555 else originalStdenv;
5656- inherit (cudaPackages) cudatoolkit nccl;
5656+ inherit (cudaPackagesGoogle) cudatoolkit nccl;
5757 # use compatible cuDNN (https://www.tensorflow.org/install/source#gpu)
5858 # cudaPackages.cudnn led to this:
5959 # https://github.com/tensorflow/tensorflow/issues/60398
6060 cudnnAttribute = "cudnn_8_6";
6161- cudnn = cudaPackages.${cudnnAttribute};
6161+ cudnn = cudaPackagesGoogle.${cudnnAttribute};
6262 gentoo-patches = fetchzip {
6363 url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2";
6464 hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs=";
···486486 broken =
487487 stdenv.isDarwin
488488 || !(xlaSupport -> cudaSupport)
489489- || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages)
490490- || !(cudaSupport -> cudaPackages ? cudatoolkit);
489489+ || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackagesGoogle)
490490+ || !(cudaSupport -> cudaPackagesGoogle ? cudatoolkit);
491491 } // lib.optionalAttrs stdenv.isDarwin {
492492 timeout = 86400; # 24 hours
493493 maxSilent = 14400; # 4h, double the default of 7200s
···590590 # Regression test for #77626 removed because not more `tensorflow.contrib`.
591591592592 passthru = {
593593- inherit cudaPackages;
593593+ cudaPackages = cudaPackagesGoogle;
594594 deps = bazel-build.deps;
595595 libtensorflow = bazel-build.out;
596596 };
+4
pkgs/top-level/all-packages.nix
···73187318 cudaPackages_12_2 = callPackage ./cuda-packages.nix { cudaVersion = "12.2"; };
73197319 cudaPackages_12 = cudaPackages_12_0;
7320732073217321+ # Use the older cudaPackages for tensorflow and jax, as determined by cudnn
73227322+ # compatibility: https://www.tensorflow.org/install/source#gpu
73237323+ cudaPackagesGoogle = cudaPackages_11;
73247324+73217325 # TODO: try upgrading once there is a cuDNN release supporting CUDA 12. No
73227326 # such cuDNN release as of 2023-01-10.
73237327 cudaPackages = recurseIntoAttrs cudaPackages_11;