1{ config
2, lib
3, cudaVersion
4}:
5
6# Type aliases
7# Gpu :: AttrSet
8# - See the documentation in ./gpus.nix.
9
10let
11 inherit (lib) attrsets lists strings trivial versions;
12
13 # Flags are determined based on your CUDA toolkit by default. You may benefit
14 # from improved performance, reduced file size, or greater hardware support by
15 # passing a configuration based on your specific GPU environment.
16 #
17 # config.cudaCapabilities :: List Capability
18 # List of hardware generations to build.
19 # E.g. [ "8.0" ]
20 # Currently, the last item is considered the optional forward-compatibility arch,
21 # but this may change in the future.
22 #
23 # config.cudaForwardCompat :: Bool
24 # Whether to include the forward compatibility gencode (+PTX)
25 # to support future GPU generations.
26 # E.g. true
27 #
28 # Please see the accompanying documentation or https://github.com/NixOS/nixpkgs/pull/205351
29
30 # gpus :: List Gpu
31 gpus = builtins.import ./gpus.nix;
32
33 # isSupported :: Gpu -> Bool
34 isSupported = gpu:
35 let
36 inherit (gpu) minCudaVersion maxCudaVersion;
37 lowerBoundSatisfied = strings.versionAtLeast cudaVersion minCudaVersion;
38 upperBoundSatisfied = (maxCudaVersion == null)
39 || !(strings.versionOlder maxCudaVersion cudaVersion);
40 in
41 lowerBoundSatisfied && upperBoundSatisfied;
42
43 # isDefault :: Gpu -> Bool
44 isDefault = gpu:
45 let
46 inherit (gpu) dontDefaultAfter;
47 newGpu = dontDefaultAfter == null;
48 recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaVersion;
49 in
50 recentGpu;
51
52 # supportedGpus :: List Gpu
53 # GPUs which are supported by the provided CUDA version.
54 supportedGpus = builtins.filter isSupported gpus;
55
56 # defaultGpus :: List Gpu
57 # GPUs which are supported by the provided CUDA version and we want to build for by default.
58 defaultGpus = builtins.filter isDefault supportedGpus;
59
60 # supportedCapabilities :: List Capability
61 supportedCapabilities = lists.map (gpu: gpu.computeCapability) supportedGpus;
62
63 # defaultCapabilities :: List Capability
64 # The default capabilities to target, if not overridden by the user.
65 defaultCapabilities = lists.map (gpu: gpu.computeCapability) defaultGpus;
66
67 # cudaArchNameToVersions :: AttrSet String (List String)
68 # Maps the name of a GPU architecture to different versions of that architecture.
69 # For example, "Ampere" maps to [ "8.0" "8.6" "8.7" ].
70 cudaArchNameToVersions =
71 lists.groupBy'
72 (versions: gpu: versions ++ [ gpu.computeCapability ])
73 [ ]
74 (gpu: gpu.archName)
75 supportedGpus;
76
77 # cudaComputeCapabilityToName :: AttrSet String String
78 # Maps the version of a GPU architecture to the name of that architecture.
79 # For example, "8.0" maps to "Ampere".
80 cudaComputeCapabilityToName = builtins.listToAttrs (
81 lists.map
82 (gpu: {
83 name = gpu.computeCapability;
84 value = gpu.archName;
85 })
86 supportedGpus
87 );
88
89 # dropDot :: String -> String
90 dropDot = ver: builtins.replaceStrings [ "." ] [ "" ] ver;
91
92 # archMapper :: String -> List String -> List String
93 # Maps a feature across a list of architecture versions to produce a list of architectures.
94 # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "sm_80" "sm_86" "sm_87" ].
95 archMapper = feat: lists.map (computeCapability: "${feat}_${dropDot computeCapability}");
96
97 # gencodeMapper :: String -> List String -> List String
98 # Maps a feature across a list of architecture versions to produce a list of gencode arguments.
99 # For example, "sm" and [ "8.0" "8.6" "8.7" ] produces [ "-gencode=arch=compute_80,code=sm_80"
100 # "-gencode=arch=compute_86,code=sm_86" "-gencode=arch=compute_87,code=sm_87" ].
101 gencodeMapper = feat: lists.map (
102 computeCapability:
103 "-gencode=arch=compute_${dropDot computeCapability},code=${feat}_${dropDot computeCapability}"
104 );
105
106 formatCapabilities = { cudaCapabilities, enableForwardCompat ? true }: rec {
107 inherit cudaCapabilities enableForwardCompat;
108
109 # archNames :: List String
110 # E.g. [ "Turing" "Ampere" ]
111 archNames = lists.unique (builtins.map (cap: cudaComputeCapabilityToName.${cap}) cudaCapabilities);
112
113 # realArches :: List String
114 # The real architectures are physical architectures supported by the CUDA version.
115 # E.g. [ "sm_75" "sm_86" ]
116 realArches = archMapper "sm" cudaCapabilities;
117
118 # virtualArches :: List String
119 # The virtual architectures are typically used for forward compatibility, when trying to support
120 # an architecture newer than the CUDA version allows.
121 # E.g. [ "compute_75" "compute_86" ]
122 virtualArches = archMapper "compute" cudaCapabilities;
123
124 # arches :: List String
125 # By default, build for all supported architectures and forward compatibility via a virtual
126 # architecture for the newest supported architecture.
127 # E.g. [ "sm_75" "sm_86" "compute_86" ]
128 arches = realArches ++
129 lists.optional enableForwardCompat (lists.last virtualArches);
130
131 # gencode :: List String
132 # A list of CUDA gencode arguments to pass to NVCC.
133 # E.g. [ "-gencode=arch=compute_75,code=sm_75" ... "-gencode=arch=compute_86,code=compute_86" ]
134 gencode =
135 let
136 base = gencodeMapper "sm" cudaCapabilities;
137 forward = gencodeMapper "compute" [ (lists.last cudaCapabilities) ];
138 in
139 base ++ lib.optionals enableForwardCompat forward;
140 };
141
in
# Regression guard for the output format of formatCapabilities.
# When changing names or formats: pause, validate, and update the expected value.
assert
  let
    actual = formatCapabilities { cudaCapabilities = [ "7.5" "8.6" ]; };
    expected = {
      cudaCapabilities = [ "7.5" "8.6" ];
      enableForwardCompat = true;

      archNames = [ "Turing" "Ampere" ];
      realArches = [ "sm_75" "sm_86" ];
      virtualArches = [ "compute_75" "compute_86" ];
      arches = [ "sm_75" "sm_86" "compute_86" ];

      gencode = [
        "-gencode=arch=compute_75,code=sm_75"
        "-gencode=arch=compute_86,code=sm_86"
        "-gencode=arch=compute_86,code=compute_86"
      ];
    };
  in
  actual == expected;
let
  # Helper functions and lookup tables exposed to callers.
  helpers = {
    # formatCapabilities :: { cudaCapabilities: List Capability, enableForwardCompat: Bool } -> { ... }
    inherit formatCapabilities;

    # cudaArchNameToVersions :: AttrSet String (List String)
    inherit cudaArchNameToVersions;

    # cudaComputeCapabilityToName :: AttrSet String String
    inherit cudaComputeCapabilityToName;

    # dropDot :: String -> String
    inherit dropDot;
  };

  # The formatted flags for the user's configuration. config.cudaCapabilities
  # falls back to defaultCapabilities and config.cudaForwardCompat to true
  # when the respective attribute is absent from config.
  configuredFlags = formatCapabilities {
    cudaCapabilities = config.cudaCapabilities or defaultCapabilities;
    enableForwardCompat = config.cudaForwardCompat or true;
  };
in
helpers // configuredFlags