# Builds ROCm's composable_kernel by splitting its huge set of generated
# "device_*_instance" libraries into independent part-derivations (built in
# parallel, potentially on different builders), then reassembling the compiled
# object files into one final derivation. See tensorOpBuilder / the final
# overrideAttrs below for the copy-out / copy-back mechanics.
{
  lib, # NOTE(review): `lib` appears unused in this file — confirm before removing
  clr,
  composable_kernel_base,
}:

let
  # One attrset per build "part". Each value is the argument set passed to
  # tensorOpBuilder: the make targets to build, plus optional
  # requiredSystemFeatures / extraCmakeFlags overrides.
  parts = {
    _mha = {
      # mha takes ~3hrs on 64 cores on an EPYC milan system at ~2.5GHz
      # big-parallel builders are one gen newer and clocked ~30% higher but only 24 cores
      # Should be <10h timeout but might be cutting it close
      # TODO: work out how to split this into smaller chunks instead of all 3k mha instances together
      # mha_0,1,2, search ninja target file for the individual instances, split by the index?
      # TODO: can we prune the generated instances down to only what in practice are used with popular models
      # when using flash-attention + MHA kernels?
      targets = [
        "device_mha_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
    };
    gemm_multiply_multiply = {
      targets = [
        "device_gemm_multiply_multiply_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
    };
    grouped_conv = {
      targets = [
        "device_grouped_conv1d_bwd_weight_instance"
        "device_grouped_conv2d_bwd_data_instance"
        "device_grouped_conv2d_bwd_weight_instance"
        "device_grouped_conv1d_fwd_instance"
        "device_grouped_conv2d_fwd_instance"
        "device_grouped_conv2d_fwd_dynamic_op_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    grouped_conv_bwd_3d = {
      targets = [
        "device_grouped_conv3d_bwd_data_instance"
        "device_grouped_conv3d_bwd_data_bilinear_instance"
        "device_grouped_conv3d_bwd_data_scale_instance"
        "device_grouped_conv3d_bwd_weight_instance"
        "device_grouped_conv3d_bwd_weight_bilinear_instance"
        "device_grouped_conv3d_bwd_weight_scale_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    grouped_conv_fwd_3d = {
      targets = [
        "device_grouped_conv3d_fwd_instance"
        "device_grouped_conv3d_fwd_bilinear_instance"
        "device_grouped_conv3d_fwd_convinvscale_instance"
        "device_grouped_conv3d_fwd_convscale_instance"
        "device_grouped_conv3d_fwd_convscale_add_instance"
        "device_grouped_conv3d_fwd_convscale_relu_instance"
        "device_grouped_conv3d_fwd_dynamic_op_instance"
        "device_grouped_conv3d_fwd_scale_instance"
        "device_grouped_conv3d_fwd_scaleadd_ab_instance"
        "device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    batched_gemm = {
      targets = [
        "device_batched_gemm_instance"
        "device_batched_gemm_add_relu_gemm_add_instance"
        "device_batched_gemm_bias_permute_instance"
        "device_batched_gemm_gemm_instance"
        "device_batched_gemm_reduce_instance"
        "device_batched_gemm_softmax_gemm_instance"
        "device_batched_gemm_softmax_gemm_permute_instance"
        "device_grouped_gemm_instance"
        "device_grouped_gemm_bias_instance"
        "device_grouped_gemm_fastgelu_instance"
        "device_grouped_gemm_fixed_nk_instance"
        "device_grouped_gemm_fixed_nk_multi_abd_instance"
        "device_grouped_gemm_tile_loop_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    gemm_universal = {
      targets = [
        "device_gemm_universal_instance"
        "device_gemm_universal_batched_instance"
        "device_gemm_universal_reduce_instance"
        "device_gemm_universal_streamk_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
    };
    gemm_other = {
      targets = [
        "device_gemm_instance"
        "device_gemm_ab_scale_instance"
        "device_gemm_add_instance"
        "device_gemm_add_add_fastgelu_instance"
        "device_gemm_add_fastgelu_instance"
        "device_gemm_add_multiply_instance"
        "device_gemm_add_relu_instance"
        "device_gemm_add_relu_add_layernorm_instance"
        "device_gemm_add_silu_instance"
        "device_gemm_bias_add_reduce_instance"
        "device_gemm_bilinear_instance"
        "device_gemm_fastgelu_instance"
        "device_gemm_multi_abd_instance"
        "device_gemm_multiply_add_instance"
        "device_gemm_reduce_instance"
        "device_gemm_splitk_instance"
        "device_gemm_streamk_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    conv = {
      targets = [
        "device_conv1d_bwd_data_instance"
        "device_conv2d_bwd_data_instance"
        "device_conv2d_fwd_instance"
        "device_conv2d_fwd_bias_relu_instance"
        "device_conv2d_fwd_bias_relu_add_instance"
        "device_conv3d_bwd_data_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    # NOTE(review): unlike every other part, `pool` does not request
    # big-parallel — presumably because it is small; confirm this is intended.
    pool = {
      targets = [
        "device_avg_pool2d_bwd_instance"
        "device_avg_pool3d_bwd_instance"
        "device_pool2d_fwd_instance"
        "device_pool3d_fwd_instance"
        "device_max_pool_bwd_instance"
      ];
    };
    other1 = {
      targets = [
        "device_batchnorm_instance"
        "device_contraction_bilinear_instance"
        "device_contraction_scale_instance"
        "device_elementwise_instance"
        "device_elementwise_normalization_instance"
        "device_normalization_bwd_data_instance"
        "device_normalization_bwd_gamma_beta_instance"
        "device_normalization_fwd_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
    other2 = {
      targets = [
        "device_column_to_image_instance"
        "device_image_to_column_instance"
        "device_permute_scale_instance"
        "device_quantization_instance"
        "device_reduce_instance"
        "device_softmax_instance"
        "device_transpose_instance"
      ];
      requiredSystemFeatures = [ "big-parallel" ];
    };
  };
  # Builds one "part" of composable_kernel: overrides the base derivation to
  # build only the given make targets, then (via `exit 0` in postBuild) stops
  # after the build phase, copying the produced .o files into $out so the
  # final derivation below can reassemble them.
  tensorOpBuilder =
    {
      part,
      targets,
      extraCmakeFlags ? [ ],
      requiredSystemFeatures ?
        [ ],
    }:
    composable_kernel_base.overrideAttrs (old: {
      inherit requiredSystemFeatures;
      pname = "composable_kernel${clr.gpuArchSuffix}-${part}";
      makeTargets = targets;
      # Restrict the build to this part's targets, and strip .NOTPARALLEL so
      # make may parallelize within the part.
      preBuild = ''
        echo "Building ${part}"
        makeFlagsArray+=($makeTargets)
        substituteInPlace Makefile \
          --replace-fail '.NOTPARALLEL:' ""
      '';

      # Compile parallelism adjusted based on available RAM
      # Never uses less than NIX_BUILD_CORES/3, never uses more than NIX_BUILD_CORES
      # CK uses an unusually high amount of memory per core in the build step
      # Nix/nixpkgs doesn't really have any infra to tell it that this build is unusually memory hungry
      # So, bodge. Otherwise you end up having to build all of ROCm with a low core limit when
      # it's only this package that has trouble.
      preConfigure = old.preConfigure or "" + ''
        MEM_GB_TOTAL=$(awk '/MemTotal/ { printf "%d \n", $2/1024/1024 }' /proc/meminfo)
        MEM_GB_AVAILABLE=$(awk '/MemAvailable/ { printf "%d \n", $2/1024/1024 }' /proc/meminfo)
        APPX_GB=$((MEM_GB_AVAILABLE > MEM_GB_TOTAL ? MEM_GB_TOTAL : MEM_GB_AVAILABLE))
        MAX_CORES=$((1 + APPX_GB/3))
        MAX_CORES=$((MAX_CORES < NIX_BUILD_CORES/3 ? NIX_BUILD_CORES/3 : MAX_CORES))
        export NIX_BUILD_CORES="$((NIX_BUILD_CORES > MAX_CORES ? MAX_CORES : NIX_BUILD_CORES))"
        echo "Picked new core limit NIX_BUILD_CORES=$NIX_BUILD_CORES based on available mem: $APPX_GB GB"
        cmakeFlagsArray+=(
          "-DCK_PARALLEL_COMPILE_JOBS=$NIX_BUILD_CORES"
        )
      '';
      cmakeFlags = old.cmakeFlags ++ extraCmakeFlags;
      # Early exit after build phase with success, skips fixups etc
      # Will get copied back into /build of the final CK
      postBuild = ''
        find . -name "*.o" -type f | while read -r file; do
          mkdir -p "$out/$(dirname "$file")"
          cp --reflink=auto "$file" "$out/$file"
        done
        exit 0
      '';
      meta = old.meta // {
        # The part derivations are always buildable; brokenness is decided on
        # the final reassembled derivation below.
        broken = false;
      };
    });
  # One derivation per entry in `parts`.
  # NOTE(review): the lambda parameter named `targets` is actually the whole
  # per-part attrset (targets + optional flags), not just the target list —
  # consider renaming for clarity.
  composable_kernel_parts = builtins.mapAttrs (
    part: targets: tensorOpBuilder (targets // { inherit part; })
  ) parts;
in

# Final derivation: seeds its build tree with the .o files produced by every
# part derivation so the base build only has to do the remaining (cheap) work.
composable_kernel_base.overrideAttrs (
  finalAttrs: old: {
    pname = "composable_kernel${clr.gpuArchSuffix}";
    parts_dirs = builtins.attrValues composable_kernel_parts;
    # The parts are build intermediates only; they must not leak into the
    # final closure.
    disallowedReferences = builtins.attrValues composable_kernel_parts;
    preBuild = ''
      for dir in $parts_dirs; do
        find "$dir" -type f -name "*.o" | while read -r file; do
          # Extract the relative path by removing the output directory prefix
          rel_path="''${file#"$dir/"}"

          # Create parent directory if it doesn't exist
          mkdir -p "$(dirname "$rel_path")"

          # Copy the file back to its original location, give it a future timestamp
          # so make treats it as up to date
          cp --reflink=auto --no-preserve=all "$file" "$rel_path"
          touch -d "now +10 hours" "$rel_path"
        done
      done
    '';
    passthru = old.passthru // {
      parts = composable_kernel_parts;
    };
    meta = old.meta // {
      # Builds which don't target any gfx9 cause cmake errors in dependent projects
      broken = !finalAttrs.passthru.anyGfx9Target;
    };
  }
)