nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
{
  # nixpkgs helpers
  lib,
  stdenv,
  fetchFromGitHub,
  # Not referenced in the visible expression; kept so the accepted argument
  # set (and therefore any existing `.override`) stays unchanged.
  fetchpatch,

  # build tooling (all listed in nativeBuildInputs)
  cmake,
  ninja,
  pkg-config,
  python3,
  rocm-cmake,
  writableTmpDirAsHomeHook,

  # library inputs; clr is additionally used as a native input and as ROCM_PATH
  clr,
  xz,

  # GPU architectures requested by the caller. Defaults to clr.localGpuTargets
  # when clr has that attribute, otherwise to clr.gpuTargets.
  gpuTargets ? clr.localGpuTargets or clr.gpuTargets,

  # for passthru.tests
  aotriton,
  hello,
}:
let
  # aotriton GPU support list:
  # https://github.com/ROCm/aotriton/blob/main/v2python/gpu_targets.py
  aotritonGpuTargets = [
    "gfx90a"
    "gfx942"
    "gfx950"
    # some gfx1100 kernels fail with error: branch size exceeds simm16
    # but build proceeds and those ops will fallback so it's ok
    "gfx1100"
    "gfx1151"
    "gfx1150"
    "gfx1201"
    "gfx1200"
  ];
  # The requested gpuTargets that also appear on aotriton's support list
  supportedTargets = lib.lists.intersectLists aotritonGpuTargets gpuTargets;
  # True when at least one requested target can actually be built
  anySupportedTargets = supportedTargets != [ ];
  # ";"-separated target list handed to the build. When nothing requested is
  # supported, pick a single arbitrary target to speed up the shim build.
  supportedTargets' =
    if !anySupportedTargets then "gfx1200" else lib.concatStringsSep ";" supportedTargets;
in
# Builds aotriton from the ROCm/aotriton GitHub repository.
# When none of the requested gpuTargets are supported, the package is named
# "aotriton-shim" and is configured with AOTRITON_NOIMAGE_MODE, i.e. without
# building any kernels.
stdenv.mkDerivation (finalAttrs: {
  pname = "aotriton${lib.optionalString (!anySupportedTargets) "-shim"}";
  version = "0.11.1b";

  src = fetchFromGitHub {
    owner = "ROCm";
    repo = "aotriton";
    tag = finalAttrs.version;
    hash = "sha256-F7JjyS+6gMdCpOFLldTsNJdVzzVwd6lwW7+V8ZOZfig=";
    # .git is kept so postFetch can run git commands; every .git entry is
    # deleted again at the end of postFetch.
    leaveDotGit = true;
    # fetch all submodules except unused triton submodule that is ~500MB
    postFetch = ''
      cd $out
      git reset --hard HEAD
      for submodule in $(git config --file .gitmodules --get-regexp path | awk '{print $2}' | grep '^third_party/' | grep -v '^third_party/triton$'); do
        git submodule update --init --recursive "$submodule"
      done
      find "$out" -name .git -print0 | xargs -0 rm -rf
    '';
  };

  cmakeBuildType = "RelWithDebInfo";
  separateDebugInfo = true;
  __structuredAttrs = true;
  strictDeps = true;
  # Only set big-parallel when we are building kernels, no-image mode build is faster
  requiredSystemFeatures = lib.optionals anySupportedTargets [ "big-parallel" ];

  env = {
    # NOTE(review): presumably lets the build identify its revision without a
    # .git directory (which postFetch removes) — confirm against upstream.
    AOTRITON_CI_SUPPLIED_SHA1 = finalAttrs.version;
    ROCM_PATH = "${clr}";
    CFLAGS = "-w -g1 -gz -Wno-c++11-narrowing";
    # C++ is compiled with exactly the same flags as C
    CXXFLAGS = finalAttrs.env.CFLAGS;
    TRITON_STORE_BINARY_ONLY = 1; # reduce triton disk space usage
  };

  nativeBuildInputs = [
    cmake
    rocm-cmake
    pkg-config
    python3
    ninja
    clr
    writableTmpDirAsHomeHook # venv wants to cache in ~
  ];

  buildInputs = [
    clr
    xz
  ]
  ++ (with python3.pkgs; [
    wheel
    packaging
    pyyaml
    numpy
    filelock
    iniconfig
    pluggy
    pybind11
    pandas
    triton
  ]);

  # Excerpt from README:
  # Note: do not run ninja separately, due to the limit of the current build system,
  # ninja install will run the whole build process unconditionally.
  dontBuild = true;
  # This builds+installs.
  # Because the whole build happens here, pass -j explicitly: without it ninja
  # chooses its own job count from the machine's CPU count and ignores the
  # core limit Nix assigned to this build (NIX_BUILD_CORES).
  installPhase = ''
    runHook preInstall
    ninja -v -j"$NIX_BUILD_CORES" install
    runHook postInstall
  '';
  # tests are intended to be run manually as test/ python scripts and need an accelerator
  doCheck = false;
  doInstallCheck = false;

  # Need to set absolute paths to VENV and its PYTHON or
  # build fails with "AOTRITON_INHERIT_SYSTEM_SITE_TRITON is enabled
  # but triton is not available … no such file or directory"
  # Set via a preConfigure hook so a valid absolute path can be
  # picked if nix-shell is used against this package
  preConfigure = ''
    cmakeFlagsArray+=(
      "-DVENV_DIR=$(pwd)/build/venv/"
      "-DVENV_BIN_PYTHON=$(pwd)/build/venv/bin/python"
    )
  '';

  cmakeFlags = [
    # Disable building kernels if no supported targets are enabled
    (lib.cmakeBool "AOTRITON_NOIMAGE_MODE" (!anySupportedTargets))
    # Use preinstalled triton from our python's site-packages
    (lib.cmakeBool "AOTRITON_INHERIT_SYSTEM_SITE_TRITON" true)
    # Circular dependency
    (lib.cmakeBool "AOTRITON_USE_TORCH" false)
    # FP32 kernels are optional, turn them off to speed up builds and save space
    # Perf sensitive code should be using BF16 or F16
    (lib.cmakeBool "AOTRITON_ENABLE_FP32_INPUTS" false)
    # Avoid kernels being skipped if build host is overloaded
    (lib.cmakeFeature "AOTRITON_GPU_BUILD_TIMEOUT" "0")
    # Manually define CMAKE_INSTALL_<DIR>
    # See: https://github.com/NixOS/nixpkgs/pull/197838
    (lib.cmakeFeature "CMAKE_INSTALL_BINDIR" "bin")
    (lib.cmakeFeature "CMAKE_INSTALL_LIBDIR" "lib")
    (lib.cmakeFeature "CMAKE_INSTALL_INCLUDEDIR" "include")
    # Note: build will warn "AMDGPU_TARGETS was not set, and system GPU detection was unsuccsesful."
    # but this can safely be ignored, aotriton uses a different approach to pass targets
    (lib.cmakeFeature "AOTRITON_TARGET_ARCH" supportedTargets')
  ];

  passthru.tests = {
    # Regression test that the aotriton .so doesn't crash in a static constructor.
    # It preloads the library of the no-kernel variant (gpuTargets = [ ]) into `hello`.
    # currently known to fail on rocm toolchain but fine with default stdenv
    ld-preload-into-hello = stdenv.mkDerivation {
      name = "aotriton-basic-load-test";
      nativeBuildInputs = [ hello ];
      buildCommand = ''
        set -e
        LD_PRELOAD=${
          aotriton.override {
            gpuTargets = [ ];
          }
        }/lib/libaotriton_v2.so ${hello}/bin/hello > /dev/null
        echo "ld-preload-into-hello" > $out
      '';
    };
  };

  meta = {
    description = "ROCm Ahead of Time (AOT) Triton Math Library";
    homepage = "https://github.com/ROCm/aotriton";
    license = lib.licenses.mit;
    teams = [ lib.teams.rocm ];
    platforms = lib.platforms.linux;
    # ld: error: unable to insert .comment after .comment
    broken = stdenv.cc.isClang;
  };
})