{ lib
, backendStdenv
, fetchFromGitHub
, python3
, which
, autoAddOpenGLRunpathHook
, cuda_cccl
, cuda_cudart
, cuda_nvcc
, cudaFlags
, cudaVersion
  # passthru.updateScript
, gitUpdater
}:
let
  # Output looks like "-gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86"
  gencode = lib.concatStringsSep " " cudaFlags.gencode;
in
backendStdenv.mkDerivation (finalAttrs: {
  pname = "nccl";
  version = "2.19.3-1";

  src = fetchFromGitHub {
    owner = "NVIDIA";
    repo = finalAttrs.pname;
    rev = "v${finalAttrs.version}";
    hash = "sha256-59FlOKM5EB5Vkm4dZBRCkn+IgIcdQehE+FyZAdTCT/A=";
  };

  outputs = [ "out" "dev" ];

  nativeBuildInputs = [
    which
    autoAddOpenGLRunpathHook
    cuda_nvcc
    python3
  ];

  buildInputs = [
    cuda_cudart
  ]
  # NOTE: CUDA versions in Nixpkgs only use a major and minor version. When we do comparisons
  # against other versions, like below, it's important that we use the same format. Otherwise,
  # we'll get incorrect results.
  # For example, lib.versionAtLeast "12.0" "12.0.0" == false.
  ++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [
    cuda_cccl
  ];

  # Rewrite the interpreter line of the source-tree script generate.py to a Nix
  # store path (python3 is in nativeBuildInputs). This modifies the sources, so
  # it belongs in the patch phase rather than in preConfigure.
  postPatch = ''
    patchShebangs ./src/device/generate.py
  '';

  # `gencode` is several space-separated -gencode flags. Entries in `makeFlags`
  # are split on whitespace (no structured attrs here), so NVCC_GENCODE is
  # appended to makeFlagsArray from shell to keep it a single make argument.
  preConfigure = ''
    makeFlagsArray+=(
      "NVCC_GENCODE=${gencode}"
    )
  '';

  # Point the upstream Makefile at the split CUDA outputs. `$(out)` is a make
  # variable reference that make resolves from the `out` environment variable.
  makeFlags = [
    "CUDA_HOME=${cuda_nvcc}"
    "CUDA_LIB=${lib.getLib cuda_cudart}/lib"
    "CUDA_INC=${lib.getDev cuda_cudart}/include"
    "PREFIX=$(out)"
  ];

  # Move the static archive out of the default `out` output into `dev`.
  postFixup = ''
    moveToOutput lib/libnccl_static.a $dev
  '';

  env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ];

  # Run the update script with: `nix-shell maintainers/scripts/update.nix --argstr package cudaPackages.nccl`
  passthru.updateScript = gitUpdater {
    inherit (finalAttrs) pname version;
    rev-prefix = "v";
  };

  enableParallelBuilding = true;

  meta = {
    description = "Multi-GPU and multi-node collective communication primitives for NVIDIA GPUs";
    homepage = "https://developer.nvidia.com/nccl";
    license = lib.licenses.bsd3;
    platforms = lib.platforms.linux;
    maintainers = with lib.maintainers; [ mdaiter orivej ] ++ lib.teams.cuda.members;
  };
})