# Base derivation for ROCm composable_kernel.
#
# This derivation is not meant to be built as-is: preBuild aborts on purpose
# and meta.broken is set. It only exists to be overridden and built in chunks.
#
# Arguments:
#   buildTests    - keep the upstream `test` subdirectory and add a `test` output
#   buildExamples - keep the upstream `example` subdirectory and add an `example` output
#   gpuTargets    - architectures passed to cmake as GPU_ARCHS; falls back to
#                   clr.localGpuTargets when that attribute exists, otherwise
#                   to the list below
{
  lib,
  stdenv,
  fetchFromGitHub,
  rocmUpdateScript,
  cmake,
  rocm-cmake,
  rocm-merged-llvm,
  clr,
  rocminfo,
  python3,
  hipify,
  gitMinimal,
  gtest,
  zstd,
  buildTests ? false,
  buildExamples ? false,
  gpuTargets ? (
    clr.localGpuTargets or [
      "gfx900"
      "gfx906"
      "gfx908"
      "gfx90a"
      "gfx942"
      "gfx1030"
      "gfx1100"
      "gfx1101"
      "gfx1102"
      "gfx1200"
      "gfx1201"
    ]
  ),
}:

let
  inherit (lib) optionals optionalString;

  # Upstream offers no flags for selective builds, so the CMake files are
  # edited instead.

  # Always applied: comment out the clang-tidy checks (they add thousands of
  # targets that are never invoked and only slow down configure) and never
  # build the profiler.
  dropClangTidyAndProfiler = ''
    substituteInPlace library/src/utility/CMakeLists.txt library/src/tensor_operation_instance/gpu/CMakeLists.txt \
      --replace-fail clang_tidy_check '#clang_tidy_check'
    substituteInPlace CMakeLists.txt \
      --replace-fail "add_subdirectory(profiler)" ""
  '';

  # Applied only when tests are disabled.
  dropTests = ''
    substituteInPlace CMakeLists.txt \
      --replace-fail "add_subdirectory(test)" ""
    substituteInPlace codegen/CMakeLists.txt \
      --replace-fail "include(ROCMTest)" ""
  '';

  # Applied only when examples are disabled.
  dropExamples = ''
    substituteInPlace CMakeLists.txt \
      --replace-fail "add_subdirectory(example)" ""
  '';
in
stdenv.mkDerivation (finalAttrs: {
  pname = "composable_kernel_base";
  # Chosen instead of 6.3: it is much easier to get building
  # and it is the version torch 2.6 wants.
  version = "6.4.0-unstable-2024-12-20";

  src = fetchFromGitHub {
    owner = "ROCm";
    repo = "composable_kernel";
    rev = "07339c738396ebeae57374771ded4dcf11bddf1e";
    hash = "sha256-EvEBxlOpQ71BF57VW79WBo/cdxAwTKFXFMiYKyGyyEs=";
  };

  # Tests and examples get their own outputs, but only when they are built.
  outputs = [ "out" ] ++ optionals buildTests [ "test" ] ++ optionals buildExamples [ "example" ];

  # Guard against building this base derivation directly.
  preBuild = ''
    echo "This derivation isn't intended to be built directly and only exists to be overridden and built in chunks";
    exit 1
  '';

  postPatch =
    dropClangTidyAndProfiler
    + optionalString (!buildTests) dropTests
    + optionalString (!buildExamples) dropExamples;

  nativeBuildInputs = [
    # ninja is left out on purpose: build outputs are jankily composed from
    # multiple derivations and ninja would not believe they are up to date.
    gitMinimal
    cmake
    rocminfo
    clr
    hipify
    zstd
    python3
  ];

  buildInputs = [
    rocm-cmake
    clr
    zstd
  ];

  strictDeps = true;
  enableParallelBuilding = true;

  env = {
    ROCM_PATH = clr;
    HIP_CLANG_PATH = "${rocm-merged-llvm}/bin";
  };

  cmakeFlags = [
    "-DCMAKE_MODULE_PATH=${clr}/hip/cmake"
    "-DCMAKE_BUILD_TYPE=Release"
    "-DCMAKE_POLICY_DEFAULT_CMP0069=NEW"
    # Left disabled: not needed and slow to build.
    # "-DDL_KERNELS=ON"
    # CK_USE_CODEGEN is required for migraphx (it uses device_gemm_multiple_d.hpp),
    # but migraphx requires an incompatible fork of CK and fails anyway.
    # "-DCK_USE_CODEGEN=ON"
    # Skipping fp64 might be worthwhile in the future:
    # "-DDTYPES=fp32;fp16;fp8;bf16;int8"
    # CMAKE_INSTALL_<DIR> is defined manually,
    # see https://github.com/NixOS/nixpkgs/pull/197838
    "-DCMAKE_INSTALL_BINDIR=bin"
    "-DCMAKE_INSTALL_LIBDIR=lib"
    "-DCMAKE_INSTALL_INCLUDEDIR=include"
    "-DBUILD_DEV=OFF"
    "-DROCM_PATH=${clr}"
    "-DCMAKE_HIP_COMPILER_ROCM_ROOT=${clr}"

    # FP8 can be built for 908/90a, but the build is very slow and the
    # resulting kernels are huge and unusably slow.
    "-DCK_USE_FP8_ON_UNSUPPORTED_ARCH=OFF"
  ]
  ++ optionals (gpuTargets != [ ]) [
    # GPU_ARCHS is set intentionally rather than AMD/GPU_TARGETS: per the
    # readme this is required when the archs are dissimilar. With rocm-6.3.x,
    # setting no arch flag worked but dissimilar arches always failed.
    "-DGPU_ARCHS=${lib.concatStringsSep ";" gpuTargets}"
  ]
  ++ optionals buildTests [
    "-DGOOGLETEST_DIR=${gtest.src}" # Custom linker names
  ];

  # Relocate test and example binaries from $out into their dedicated outputs.
  postInstall =
    optionalString buildTests ''
      mkdir -p $test/bin
      mv $out/bin/test_* $test/bin
    ''
    + optionalString buildExamples ''
      mkdir -p $example/bin
      mv $out/bin/example_* $example/bin
    '';

  passthru = {
    updateScript = rocmUpdateScript {
      name = finalAttrs.pname;
      inherit (finalAttrs.src) owner repo;
    };

    # True when at least one requested target name starts with "gfx9".
    anyGfx9Target = lib.any (lib.hasPrefix "gfx9") gpuTargets;
  };

  meta = {
    description = "Performance portable programming model for machine learning tensor operators";
    homepage = "https://github.com/ROCm/composable_kernel";
    license = [ lib.licenses.mit ];
    teams = [ lib.teams.rocm ];
    platforms = lib.platforms.linux;
    # Marked broken on purpose, matching the preBuild guard above.
    broken = true;
  };
})