nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at r-updates 319 lines 11 kB view raw
# Splits the very large composable_kernel (CK) build into many independent
# "part" derivations, each compiling a subset of the device_*_instance
# libraries in parallel on separate builders. The final derivation then
# copies every part's object files back into its own build tree (timestamped
# in the future so make considers them up to date) and finishes the build.
{
  lib,
  clr,
  composable_kernel_base,
}:

let
  inherit (composable_kernel_base) miOpenReqLibsOnly;
  # One attrset per part, consumed by tensorOpBuilder below. Recognized keys:
  #   targets         - make targets (instance libraries) built by this part
  #   enabled         - when false the part is skipped entirely (maps to null)
  #   extraCmakeFlags - appended to the base derivation's cmakeFlags
  #   onlyFor         - only build when one of these gfx targets is enabled
  parts = {
    _mha = {
      enabled = !miOpenReqLibsOnly;
      # mha takes ~3hrs on 64 cores on an EPYC milan system at ~2.5GHz
      # big-parallel builders are one gen newer and clocked ~30% higher but only 24 cores
      # Should be <10h timeout but might be cutting it close
      # TODO: work out how to split this into smaller chunks instead of all 3k mha instances together
      # mha_0,1,2, search ninja target file for the individual instances, split by the index?
      # TODO: can we prune the generated instances down to only what in practice are used with popular models
      # when using flash-attention + MHA kernels?
      targets = [
        "device_mha_instance"
      ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
    };
    gemm_multiply_multiply = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_gemm_multiply_multiply_instance"
      ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
      onlyFor = [
        "gfx942"
        "gfx950"
      ];
    };
    gemm_multiply_multiply_wp = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_gemm_multiply_multiply_wp_instance"
      ];
      extraCmakeFlags = [ "-DHIP_CLANG_NUM_PARALLEL_JOBS=2" ];
      onlyFor = [
        "gfx942"
        "gfx950"
      ];
    };
    grouped_conv_bwd = {
      targets = [
        "device_grouped_conv1d_bwd_weight_instance"
        "device_grouped_conv2d_bwd_data_instance"
        "device_grouped_conv2d_bwd_weight_instance"
      ];
    };
    grouped_conv_fwd = {
      targets = [
        "device_grouped_conv1d_fwd_instance"
        "device_grouped_conv2d_fwd_instance"
        "device_grouped_conv2d_fwd_bias_bnorm_clamp_instance"
        "device_grouped_conv2d_fwd_bias_clamp_instance"
        "device_grouped_conv2d_fwd_clamp_instance"
        "device_grouped_conv2d_fwd_dynamic_op_instance"
      ];
    };
    grouped_conv_bwd_3d1 = {
      targets = [
        "device_grouped_conv3d_bwd_data_instance"
        "device_grouped_conv3d_bwd_data_bilinear_instance"
        "device_grouped_conv3d_bwd_data_scale_instance"
      ];
    };
    grouped_conv_bwd_3d2 = {
      targets = [
        "device_grouped_conv3d_bwd_weight_instance"
        "device_grouped_conv3d_bwd_weight_bilinear_instance"
        "device_grouped_conv3d_bwd_weight_scale_instance"
      ];
    };
    grouped_conv_fwd_3d1 = {
      targets = [
        "device_grouped_conv3d_fwd_instance"
        "device_grouped_conv3d_fwd_clamp_instance"
        "device_grouped_conv3d_fwd_bias_bnorm_clamp_instance"
        "device_grouped_conv3d_fwd_bias_clamp_instance"
        "device_grouped_conv3d_fwd_bilinear_instance"
        "device_grouped_conv3d_fwd_convinvscale_instance"
        "device_grouped_conv3d_fwd_convscale_instance"
        "device_grouped_conv3d_fwd_convscale_add_instance"
      ];
    };
    grouped_conv_fwd_3d2 = {
      targets = [
        "device_grouped_conv3d_fwd_convscale_relu_instance"
        "device_grouped_conv3d_fwd_dynamic_op_instance"
        "device_grouped_conv3d_fwd_scale_instance"
        "device_grouped_conv3d_fwd_scaleadd_ab_instance"
        "device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance"
      ];
    };
    grouped_conv_fwd_nd = {
      targets = [
        "device_grouped_convnd_bwd_weight_instance"
      ];
    };
    batched_gemm1 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_batched_gemm_instance"
        "device_batched_gemm_b_scale_instance"
        "device_batched_gemm_multi_d_instance"
        "device_batched_gemm_add_relu_gemm_add_instance"
        "device_batched_gemm_bias_permute_instance"
        "device_batched_gemm_gemm_instance"
        "device_batched_gemm_reduce_instance"
        "device_batched_gemm_softmax_gemm_instance"
      ];
    };
    batched_gemm2 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_batched_gemm_softmax_gemm_permute_instance"
        "device_grouped_gemm_instance"
        "device_grouped_gemm_bias_instance"
        "device_grouped_gemm_fastgelu_instance"
        "device_grouped_gemm_fixed_nk_instance"
        "device_grouped_gemm_fixed_nk_multi_abd_instance"
        "device_grouped_gemm_tile_loop_instance"
      ];
    };
    gemm_universal1 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_gemm_universal_instance"
        "device_gemm_universal_batched_instance"
      ];
    };
    gemm_universal2 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_gemm_universal_reduce_instance"
        "device_gemm_universal_streamk_instance"
      ];
    };
    gemm_other1 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_gemm_instance"
        "device_gemm_b_scale_instance"
        "device_gemm_ab_scale_instance"
        "device_gemm_add_instance"
        "device_gemm_add_add_fastgelu_instance"
        "device_gemm_add_fastgelu_instance"
        "device_gemm_add_multiply_instance"
        "device_gemm_add_relu_instance"
      ];
    };
    gemm_other2 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_gemm_add_relu_add_layernorm_instance"
        "device_gemm_add_silu_instance"
        "device_gemm_bias_add_reduce_instance"
        "device_gemm_bilinear_instance"
        "device_gemm_fastgelu_instance"
        "device_gemm_multi_abd_instance"
        "device_gemm_multiply_add_instance"
        "device_gemm_reduce_instance"
        "device_gemm_splitk_instance"
        "device_gemm_streamk_instance"
      ];
    };
    conv = {
      targets = [
        "device_conv1d_bwd_data_instance"
        "device_conv2d_bwd_data_instance"
        "device_conv2d_fwd_instance"
        "device_conv2d_fwd_bias_relu_instance"
        "device_conv2d_fwd_bias_relu_add_instance"
        "device_conv3d_bwd_data_instance"
      ];
    };
    pool = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_avg_pool2d_bwd_instance"
        "device_avg_pool3d_bwd_instance"
        "device_pool2d_fwd_instance"
        "device_pool3d_fwd_instance"
        "device_max_pool_bwd_instance"
      ];
    };
    other0 = {
      targets = [
        "device_quantization_instance"
      ];
    };
    other1 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_batchnorm_instance"
        "device_contraction_bilinear_instance"
        "device_contraction_scale_instance"
        "device_elementwise_instance"
        "device_elementwise_normalization_instance"
      ];
    };
    other2 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_column_to_image_instance"
        "device_image_to_column_instance"
        "device_permute_scale_instance"
        "device_reduce_instance"
      ];
    };
    other3 = {
      enabled = !miOpenReqLibsOnly;
      targets = [
        "device_normalization_bwd_data_instance"
        "device_normalization_bwd_gamma_beta_instance"
        "device_normalization_fwd_instance"
        "device_softmax_instance"
        "device_transpose_instance"
      ];
    };
  };
  # Builds one part as an override of composable_kernel_base restricted to
  # the given make targets. Returns null when the part is disabled or when
  # onlyFor is non-empty and shares no element with the enabled gpuTargets.
  tensorOpBuilder =
    {
      part,
      targets,
      extraCmakeFlags ? [ ],
      requiredSystemFeatures ? [ "big-parallel" ],
      enabled ? true,
      onlyFor ? [ ],
    }:
    let
      supported =
        enabled
        && (onlyFor == [ ] || (lib.lists.intersectLists composable_kernel_base.gpuTargets onlyFor) != [ ]);
    in
    if supported then
      (composable_kernel_base.overrideAttrs (old: {
        inherit requiredSystemFeatures;
        pname = "composable_kernel${clr.gpuArchSuffix}-${part}";
        # Exported as an env var; picked up by preBuild below to restrict the
        # build to this part's targets only.
        makeTargets = targets;
        preBuild = ''
          echo "Building ${part}"
          makeFlagsArray+=($makeTargets)
          # Rename make's .NOTPARALLEL special target so the generated
          # Makefiles no longer force a serial build.
          substituteInPlace $(find ./ -name "Makefile" -type f) \
            --replace-fail '.NOTPARALLEL:' '.UNUSED_NOTPARALLEL:'
        '';

        # Compile parallelism adjusted based on available RAM
        # Never uses less than NIX_BUILD_CORES/3, never uses more than NIX_BUILD_CORES
        # CK uses an unusually high amount of memory per core in the build step
        # Nix/nixpkgs doesn't really have any infra to tell it that this build is unusually memory hungry
        # So, bodge. Otherwise you end up having to build all of ROCm with a low core limit when
        # it's only this package that has trouble.
        # NB: `or` binds tighter than `+`, so this appends to the base hook.
        preConfigure = old.preConfigure or "" + ''
          MEM_GB_TOTAL=$(awk '/MemTotal/ { printf "%d \n", $2/1024/1024 }' /proc/meminfo)
          MEM_GB_AVAILABLE=$(awk '/MemAvailable/ { printf "%d \n", $2/1024/1024 }' /proc/meminfo)
          # Use the smaller of total/available; budget roughly 3 GB per core.
          APPX_GB=$((MEM_GB_AVAILABLE > MEM_GB_TOTAL ? MEM_GB_TOTAL : MEM_GB_AVAILABLE))
          MAX_CORES=$((1 + APPX_GB/3))
          MAX_CORES=$((MAX_CORES < NIX_BUILD_CORES/3 ? NIX_BUILD_CORES/3 : MAX_CORES))
          export NIX_BUILD_CORES="$((NIX_BUILD_CORES > MAX_CORES ? MAX_CORES : NIX_BUILD_CORES))"
          echo "Picked new core limit NIX_BUILD_CORES=$NIX_BUILD_CORES based on available mem: $APPX_GB GB"
          cmakeFlagsArray+=(
            "-DCK_PARALLEL_COMPILE_JOBS=$NIX_BUILD_CORES"
          )
        '';
        cmakeFlags = old.cmakeFlags ++ extraCmakeFlags;
        # Early exit after build phase with success, skips fixups etc
        # Will get copied back into /build of the final CK
        postBuild = ''
          find . -name "*.o" -type f | while read -r file; do
            mkdir -p "$out/$(dirname "$file")"
            cp --reflink=auto "$file" "$out/$file"
          done
          exit 0
        '';
        meta = old.meta // {
          broken = false;
        };
      }))
    else
      null;
  # Part name -> derivation (or null when the part is not supported).
  # NB: the lambda arg named `targets` is actually the whole part attrset
  # (targets, enabled, extraCmakeFlags, ...), not just the targets list.
  composable_kernel_parts = builtins.mapAttrs (
    part: targets: tensorOpBuilder (targets // { inherit part; })
  ) parts;
in

composable_kernel_base.overrideAttrs (
  finalAttrs: old: {
    pname = "composable_kernel${clr.gpuArchSuffix}";
    # Exported as an env var; word-splitting of $parts_dirs in preBuild
    # iterates over the store paths of all built parts.
    parts_dirs = builtins.filter (x: x != null) (builtins.attrValues composable_kernel_parts);
    # The parts are build-time inputs only; make sure none of them ends up
    # in the final package's closure.
    disallowedReferences = builtins.filter (x: x != null) (builtins.attrValues composable_kernel_parts);
    preBuild = ''
      for dir in $parts_dirs; do
        find "$dir" -type f -name "*.o" | while read -r file; do
          # Extract the relative path by removing the output directory prefix
          rel_path="''${file#"$dir/"}"

          # Create parent directory if it doesn't exist
          mkdir -p "$(dirname "$rel_path")"

          # Copy the file back to its original location, give it a future timestamp
          # so make treats it as up to date
          cp --reflink=auto --no-preserve=all "$file" "$rel_path"
          touch -d "now +10 hours" "$rel_path"
        done
      done
    '';
    passthru = old.passthru // {
      parts = composable_kernel_parts;
    };
    meta = old.meta // {
      # Builds without any gfx9 fail
      broken = !finalAttrs.passthru.anyGfx9Target;
    };
  }
)