{
  lib,
  clr,
  composable_kernel_base,
}:

let
  parts = {
    _mha = {
      # mha takes ~3hrs on 64 cores on an EPYC Milan system at ~2.5GHz.
      # big-parallel builders are one gen newer and clocked ~30% higher, but have only 24 cores.
      # Should fit within the 10h timeout, but it might be cutting it close.
      # TODO: work out how to split this into smaller chunks instead of building all 3k mha instances together.
      # mha_0,1,2: search the ninja target file for the individual instances and split by index?
      # TODO: can we prune the generated instances down to only those used in practice
      # by popular models when using flash-attention + MHA kernels?
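      # Back-of-envelope from the numbers above (illustrative, not a measurement):
      # 3h * (64/24 cores) / 1.3 (clock) comes to roughly 6h expected on a
      # big-parallel builder, hence "cutting it close" if scaling is imperfect.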
      targets = [
        "device_mha_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
    };
    gemm_multiply_multiply = {
      targets = [
        "device_gemm_multiply_multiply_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
    };
    grouped_conv = {
      targets = [
        "device_grouped_conv1d_bwd_weight_instance"
        "device_grouped_conv2d_bwd_data_instance"
        "device_grouped_conv2d_bwd_weight_instance"
        "device_grouped_conv1d_fwd_instance"
        "device_grouped_conv2d_fwd_instance"
        "device_grouped_conv2d_fwd_dynamic_op_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    grouped_conv_bwd_3d = {
      targets = [
        "device_grouped_conv3d_bwd_data_instance"
        "device_grouped_conv3d_bwd_data_bilinear_instance"
        "device_grouped_conv3d_bwd_data_scale_instance"
        "device_grouped_conv3d_bwd_weight_instance"
        "device_grouped_conv3d_bwd_weight_bilinear_instance"
        "device_grouped_conv3d_bwd_weight_scale_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    grouped_conv_fwd_3d = {
      targets = [
        "device_grouped_conv3d_fwd_instance"
        "device_grouped_conv3d_fwd_bilinear_instance"
        "device_grouped_conv3d_fwd_convinvscale_instance"
        "device_grouped_conv3d_fwd_convscale_instance"
        "device_grouped_conv3d_fwd_convscale_add_instance"
        "device_grouped_conv3d_fwd_convscale_relu_instance"
        "device_grouped_conv3d_fwd_dynamic_op_instance"
        "device_grouped_conv3d_fwd_scale_instance"
        "device_grouped_conv3d_fwd_scaleadd_ab_instance"
        "device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    batched_gemm = {
      targets = [
        "device_batched_gemm_instance"
        "device_batched_gemm_add_relu_gemm_add_instance"
        "device_batched_gemm_bias_permute_instance"
        "device_batched_gemm_gemm_instance"
        "device_batched_gemm_reduce_instance"
        "device_batched_gemm_softmax_gemm_instance"
        "device_batched_gemm_softmax_gemm_permute_instance"
        "device_grouped_gemm_instance"
        "device_grouped_gemm_bias_instance"
        "device_grouped_gemm_fastgelu_instance"
        "device_grouped_gemm_fixed_nk_instance"
        "device_grouped_gemm_fixed_nk_multi_abd_instance"
        "device_grouped_gemm_tile_loop_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    gemm_universal = {
      targets = [
        "device_gemm_universal_instance"
        "device_gemm_universal_batched_instance"
        "device_gemm_universal_reduce_instance"
        "device_gemm_universal_streamk_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
    };
    gemm_other = {
      targets = [
        "device_gemm_instance"
        "device_gemm_ab_scale_instance"
        "device_gemm_add_instance"
        "device_gemm_add_add_fastgelu_instance"
        "device_gemm_add_fastgelu_instance"
        "device_gemm_add_multiply_instance"
        "device_gemm_add_relu_instance"
        "device_gemm_add_relu_add_layernorm_instance"
        "device_gemm_add_silu_instance"
        "device_gemm_bias_add_reduce_instance"
        "device_gemm_bilinear_instance"
        "device_gemm_fastgelu_instance"
        "device_gemm_multi_abd_instance"
        "device_gemm_multiply_add_instance"
        "device_gemm_reduce_instance"
        "device_gemm_splitk_instance"
        "device_gemm_streamk_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    conv = {
      targets = [
        "device_conv1d_bwd_data_instance"
        "device_conv2d_bwd_data_instance"
        "device_conv2d_fwd_instance"
        "device_conv2d_fwd_bias_relu_instance"
        "device_conv2d_fwd_bias_relu_add_instance"
        "device_conv3d_bwd_data_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    pool = {
      targets = [
        "device_avg_pool2d_bwd_instance"
        "device_avg_pool3d_bwd_instance"
        "device_pool2d_fwd_instance"
        "device_pool3d_fwd_instance"
        "device_max_pool_bwd_instance"
      ];
    };
    other1 = {
      targets = [
        "device_batchnorm_instance"
        "device_contraction_bilinear_instance"
        "device_contraction_scale_instance"
        "device_elementwise_instance"
        "device_elementwise_normalization_instance"
        "device_normalization_bwd_data_instance"
        "device_normalization_bwd_gamma_beta_instance"
        "device_normalization_fwd_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    other2 = {
      targets = [
        "device_column_to_image_instance"
        "device_image_to_column_instance"
        "device_permute_scale_instance"
        "device_quantization_instance"
        "device_reduce_instance"
        "device_softmax_instance"
        "device_transpose_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
  };
  tensorOpBuilder =
    {
      part,
      targets,
      extraCmakeFlags ? [ ],
      requiredSystemFeatures ? [ ],
    }:
    composable_kernel_base.overrideAttrs (old: {
      inherit requiredSystemFeatures;
      pname = "composable_kernel${clr.gpuArchSuffix}-${part}";
      makeTargets = targets;
      preBuild = ''
        echo "Building ${part}"
        makeFlagsArray+=($makeTargets)
        substituteInPlace Makefile \
          --replace-fail '.NOTPARALLEL:' ""
      '';

      # Compile parallelism is adjusted based on available RAM.
      # Never uses fewer than NIX_BUILD_CORES/3 cores, and never more than NIX_BUILD_CORES.
      # CK uses an unusually high amount of memory per core in the build step, and
      # Nix/nixpkgs doesn't really have any infra for declaring that a build is
      # unusually memory hungry. So: bodge. Otherwise you end up having to build all
      # of ROCm with a low core limit when it's only this package that has trouble.
      preConfigure =
        old.preConfigure or ""
        + ''
          MEM_GB_TOTAL=$(awk '/MemTotal/ { printf "%d\n", $2/1024/1024 }' /proc/meminfo)
          MEM_GB_AVAILABLE=$(awk '/MemAvailable/ { printf "%d\n", $2/1024/1024 }' /proc/meminfo)
          APPX_GB=$((MEM_GB_AVAILABLE > MEM_GB_TOTAL ? MEM_GB_TOTAL : MEM_GB_AVAILABLE))
          MAX_CORES=$((1 + APPX_GB/3))
          MAX_CORES=$((MAX_CORES < NIX_BUILD_CORES/3 ? NIX_BUILD_CORES/3 : MAX_CORES))
          export NIX_BUILD_CORES="$((NIX_BUILD_CORES > MAX_CORES ? MAX_CORES : NIX_BUILD_CORES))"
          echo "Picked new core limit NIX_BUILD_CORES=$NIX_BUILD_CORES based on available mem: $APPX_GB GB"
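
          # Worked example (illustrative numbers, not a measurement): on a 64-core
          # builder with ~100 GB available, MAX_CORES = 1 + 100/3 = 34; the floor
          # NIX_BUILD_CORES/3 = 21 is lower, so the cap of 34 wins and the build
          # proceeds with NIX_BUILD_CORES=34 instead of 64.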
          cmakeFlagsArray+=(
            "-DCK_PARALLEL_COMPILE_JOBS=$NIX_BUILD_CORES"
          )
        '';
      cmakeFlags = old.cmakeFlags ++ extraCmakeFlags;
      # Early-exit after the build phase with success; skips install, fixup, etc.
      # The object files will get copied back into /build of the final CK derivation.
      postBuild = ''
        find . -name "*.o" -type f | while read -r file; do
          mkdir -p "$out/$(dirname "$file")"
          cp --reflink=auto "$file" "$out/$file"
        done
        exit 0
      '';
      meta = old.meta // {
        broken = false;
      };
    });
  composable_kernel_parts = builtins.mapAttrs (
    part: partAttrs: tensorOpBuilder (partAttrs // { inherit part; })
  ) parts;
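  # One derivation per entry in `parts`: each builds only its own make targets
  # against the base package and exports the resulting .o files.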
in

composable_kernel_base.overrideAttrs (
  finalAttrs: old: {
    pname = "composable_kernel${clr.gpuArchSuffix}";
    parts_dirs = builtins.attrValues composable_kernel_parts;
    # The part outputs only feed the build; the final output must not retain
    # store references to them.
    disallowedReferences = builtins.attrValues composable_kernel_parts;
    preBuild = ''
      for dir in $parts_dirs; do
        find "$dir" -type f -name "*.o" | while read -r file; do
          # Extract the relative path by removing the output directory prefix
          rel_path="''${file#"$dir/"}"

          # Create the parent directory if it doesn't exist
          mkdir -p "$(dirname "$rel_path")"

          # Copy the file back to its original location and give it a timestamp
          # in the future so make treats it as up to date
          cp --reflink=auto --no-preserve=all "$file" "$rel_path"
          touch -d "now +10 hours" "$rel_path"
        done
      done
    '';
    passthru = old.passthru // {
      parts = composable_kernel_parts;
    };
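    # passthru.parts makes each chunk addressable on its own, so (assuming this
    # derivation is exposed as `composable_kernel` in a package set) something
    # like `nix-build -A composable_kernel.parts.gemm_universal` can rebuild a
    # single part when debugging.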
    meta = old.meta // {
      # Builds that don't target any gfx9 arch cause CMake errors in dependent projects
      broken = !finalAttrs.passthru.anyGfx9Target;
    };
  }
)