{
  # nixpkgs plumbing
  lib,
  stdenv,
  fetchFromGitHub,
  rocmUpdateScript,

  # Package dependencies
  clr,
  cmake,
  gitMinimal,
  gtest,
  hipify,
  rocm-cmake,
  rocm-merged-llvm,
  rocminfo,
  zstd,

  # Build the upstream test suite into a separate `test` output.
  buildTests ? false,
  # Build the upstream examples into a separate `example` output.
  buildExamples ? false,
  # GPU architectures to compile kernels for. Uses `clr.localGpuTargets`
  # when clr provides that attribute, otherwise falls back to this list.
  gpuTargets ?
    clr.localGpuTargets or [
      "gfx900"
      "gfx906"
      "gfx908"
      "gfx90a"
      "gfx942"
      "gfx1030"
      "gfx1100"
      "gfx1101"
      "gfx1102"
    ],
}:
31
stdenv.mkDerivation (finalAttrs: {
  # Abort unconditionally before the build phase: this base derivation is only
  # meant to be overridden and built in chunks, never built as-is.
  preBuild = ''
    echo "This derivation isn't intended to be built directly and only exists to be overridden and built in chunks";
    exit 1
  '';

  pname = "composable_kernel_base";
  # Picked this snapshot over 6.3 because it is much easier to get building,
  # and it matches the version torch 2.6 wants.
  version = "6.4.0-unstable-20241220";

  # Optional extra outputs; postInstall moves the test/example binaries there.
  outputs = [
    "out"
  ]
  ++ lib.optionals buildTests [
    "test"
  ]
  ++ lib.optionals buildExamples [
    "example"
  ];

  src = fetchFromGitHub {
    owner = "ROCm";
    repo = "composable_kernel";
    rev = "07339c738396ebeae57374771ded4dcf11bddf1e";
    hash = "sha256-EvEBxlOpQ71BF57VW79WBo/cdxAwTKFXFMiYKyGyyEs=";
  };

  nativeBuildInputs = [
    # Deliberately not using ninja: build outputs are composed from multiple
    # derivations, and ninja would not believe they are up to date.
    gitMinimal
    cmake
    rocminfo
    clr
    hipify
    zstd
  ];

  buildInputs = [
    rocm-cmake
    clr
    zstd
  ];

  strictDeps = true;
  enableParallelBuilding = true;
  env.ROCM_PATH = clr;
  env.HIP_CLANG_PATH = "${rocm-merged-llvm}/bin";

  cmakeFlags = [
    "-DCMAKE_MODULE_PATH=${clr}/hip/cmake"
    "-DCMAKE_BUILD_TYPE=Release"
    # Policy CMP0069 NEW: INTERPROCEDURAL_OPTIMIZATION is enforced when enabled.
    "-DCMAKE_POLICY_DEFAULT_CMP0069=NEW"
    # "-DDL_KERNELS=ON" # Not needed, slow to build
    # CK_USE_CODEGEN is required for migraphx, which uses device_gemm_multiple_d.hpp,
    # but migraphx requires an incompatible fork of CK and fails anyway.
    # "-DCK_USE_CODEGEN=ON"
    # It might be worth skipping fp64 in future with this:
    # "-DDTYPES=fp32;fp16;fp8;bf16;int8"
    # Manually define CMAKE_INSTALL_<DIR>
    # See: https://github.com/NixOS/nixpkgs/pull/197838
    "-DCMAKE_INSTALL_BINDIR=bin"
    "-DCMAKE_INSTALL_LIBDIR=lib"
    "-DCMAKE_INSTALL_INCLUDEDIR=include"
    "-DBUILD_DEV=OFF"
    "-DROCM_PATH=${clr}"
    "-DCMAKE_HIP_COMPILER_ROCM_ROOT=${clr}"

    # FP8 can build for 908/90a, but the build is very slow
    # and produces huge, unusably slow kernels.
    "-DCK_USE_FP8_ON_UNSUPPORTED_ARCH=OFF"
  ]
  ++ lib.optionals (gpuTargets != [ ]) [
    # We intentionally set GPU_ARCHS and not AMDGPU_TARGETS/GPU_TARGETS:
    # per the upstream README this is required if the archs are dissimilar.
    # In rocm-6.3.x not setting any arch flag worked,
    # but setting dissimilar archs always failed.
    "-DGPU_ARCHS=${lib.concatStringsSep ";" gpuTargets}"
  ]
  ++ lib.optionals buildTests [
    # NOTE(review): presumably points CK's CMake at a local GoogleTest source
    # tree so it does not try to download one inside the sandbox — confirm.
    "-DGOOGLETEST_DIR=${gtest.src}"
  ];

  # Upstream offers no CMake switches to skip these parts selectively,
  # so they are patched out of the CMake files instead.
  postPatch =
    # Comment out clang_tidy_check calls to cut configure time: they would add
    # thousands of clang-tidy targets that are never invoked.
    # Also drop the profiler subdirectory, which is never built.
    ''
      substituteInPlace library/src/utility/CMakeLists.txt library/src/tensor_operation_instance/gpu/CMakeLists.txt \
        --replace-fail clang_tidy_check '#clang_tidy_check'
      substituteInPlace CMakeLists.txt \
        --replace-fail "add_subdirectory(profiler)" ""
    ''
    # Remove the tests unless they were requested.
    + lib.optionalString (!buildTests) ''
      substituteInPlace CMakeLists.txt \
        --replace-fail "add_subdirectory(test)" ""
      substituteInPlace codegen/CMakeLists.txt \
        --replace-fail "include(ROCMTest)" ""
    ''
    # Remove the examples unless they were requested.
    + lib.optionalString (!buildExamples) ''
      substituteInPlace CMakeLists.txt \
        --replace-fail "add_subdirectory(example)" ""
    '';

  # Move test and example binaries out of $out into their dedicated outputs.
  postInstall =
    lib.optionalString buildTests ''
      mkdir -p $test/bin
      mv $out/bin/test_* $test/bin
    ''
    + lib.optionalString buildExamples ''
      mkdir -p $example/bin
      mv $out/bin/example_* $example/bin
    '';

  passthru = {
    updateScript = rocmUpdateScript {
      name = finalAttrs.pname;
      inherit (finalAttrs.src) owner repo;
    };

    # True when at least one requested GPU target name starts with "gfx9".
    anyGfx9Target = lib.lists.any (lib.strings.hasPrefix "gfx9") gpuTargets;
  };

  meta = {
    description = "Performance portable programming model for machine learning tensor operators";
    homepage = "https://github.com/ROCm/composable_kernel";
    license = lib.licenses.mit;
    teams = [ lib.teams.rocm ];
    platforms = lib.platforms.linux;
    # NOTE(review): presumably marked broken because this base is never built
    # directly (preBuild aborts); overrides are expected to reset this — confirm.
    broken = true;
  };
})