{
  lib,
  addDriverRunpath,
  buildPythonPackage,
  cmake,
  config,
  cudaPackages,
  fetchFromGitHub,
  filelock,
  gtest,
  libxml2,
  lit,
  llvm,
  ncurses,
  ninja,
  pybind11,
  python,
  pytestCheckHook,
  writableTmpDirAsHomeHook,
  stdenv,
  replaceVars,
  setuptools,
  torchWithRocm,
  zlib,
  cudaSupport ? config.cudaSupport,
  runCommand,
  rocmPackages,
  triton,
}:

buildPythonPackage rec {
  pname = "triton";
  version = "3.4.0";
  pyproject = true;

  # Remember to bump triton-llvm as well!
  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    tag = "v${version}";
    hash = "sha256-78s9ke6UV7Tnx3yCr0QZcVDqQELR4XoGgJY7olNJmjk=";
  };

  patches = [
    (replaceVars ./0001-_build-allow-extra-cc-flags.patch {
      ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
    })
    (replaceVars ./0002-nvidia-driver-short-circuit-before-ldconfig.patch {
      libcudaStubsDir =
        if cudaSupport then "${lib.getOutput "stubs" cudaPackages.cuda_cudart}/lib/stubs" else null;
    })
    # Upstream PR: https://github.com/triton-lang/triton/pull/7959
    ./0005-amd-search-env-paths.patch
  ]
  ++ lib.optionals cudaSupport [
    (replaceVars ./0003-nvidia-cudart-a-systempath.patch {
      cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
    })
    (replaceVars ./0004-nvidia-allow-static-ptxas-path.patch {
      nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
    })
  ];
63
64 postPatch =
65 # Allow CMake 4
66 # Upstream issue: https://github.com/triton-lang/triton/issues/8245
67 ''
68 substituteInPlace pyproject.toml \
69 --replace-fail "cmake>=3.20,<4.0" "cmake>=3.20"
70 ''
    # Avoid downloading dependencies; also drop the proton profiler package and the version check
    + ''
      substituteInPlace setup.py \
        --replace-fail "[get_json_package_info()]" "[]" \
        --replace-fail "[get_llvm_package_info()]" "[]" \
        --replace-fail 'yield ("triton.profiler", "third_party/proton/proton")' 'pass' \
        --replace-fail "curr_version.group(1) != version" "False"
    ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    + ''
      substituteInPlace setup.py \
        --replace-fail \
          "cmake_args.extend(thirdparty_cmake_args)" \
          "cmake_args.extend(thirdparty_cmake_args + os.environ.get('cmakeFlags', \"\").split())"
    ''
    # Don't fetch googletest
    + ''
      substituteInPlace cmake/AddTritonUnitTest.cmake \
        --replace-fail "include(\''${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)" "" \
        --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
    ''
    # Don't use FHS path for ROCm LLD
    # Remove this after `[AMD] Use lld library API #7548` makes it into a release
    + ''
      substituteInPlace third_party/amd/backend/compiler.py \
        --replace-fail 'lld = Path("/opt/rocm/llvm/bin/ld.lld")' \
          "import os;lld = Path(os.getenv('HIP_PATH', '/opt/rocm/')"' + "/llvm/bin/ld.lld")'
    '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm

    # Upstream's setup.py tries to write cache somewhere in ~/
    writableTmpDirAsHomeHook
  ];

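  # LLVM_SYSPATH reaches triton's CMake through the `cmakeFlags` pass-through
  # patched into setup.py above, pointing the build at our prebuilt LLVM
  # instead of a downloaded toolchain.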
  cmakeFlags = [
    (lib.cmakeFeature "LLVM_SYSPATH" "${llvm}")
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started triggering stringop-overread errors with Python 3.12; observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  preConfigure =
    # Ensure that the build process uses the requested number of cores
    ''
      export MAX_JOBS="$NIX_BUILD_CORES"
    '';

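  # TRITON_OFFLINE_BUILD keeps setup.py from fetching anything at build time,
  # so every tool has to come from the build inputs above; proton (triton's
  # profiler) stays off to match its removal in postPatch.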
  env = {
    TRITON_BUILD_PROTON = "OFF";
    TRITON_OFFLINE_BUILD = true;
  }
  // lib.optionalAttrs cudaSupport {
    CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
    CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

    # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
    TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
    TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
    TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
    TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
    TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
    TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

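  # cmake is in nativeCheckInputs because it provides the `ctest` binary used in preCheck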
  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
  '';

  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

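  # Runs upstream's pytest suite against a real GPU; only schedulable on
  # builders advertising the "cuda" system feature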
  passthru.gpuCheck = stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [
      (python.withPackages (ps: [
        ps.scipy
        ps.torchWithCuda
        ps.triton-cuda
      ]))
    ];

    dontBuild = true;
    nativeCheckInputs = [
      pytestCheckHook
      writableTmpDirAsHomeHook
    ];

    doCheck = true;

    preCheck = ''
      cd python/test/unit
    '';
    checkPhase = "pytestCheckPhase";

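    # The check phase is the payload; record success with an empty output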
    installPhase = "touch $out";
  };

  passthru.tests = {
    # Ultimately, torch is our test suite:
    inherit torchWithRocm;

    # Test that _get_path_to_hip_runtime_dylib works when ROCm is available at runtime
    rocm-libamdhip64-path =
      runCommand "triton-rocm-libamdhip64-path-test"
        {
          buildInputs = [
            triton
            python
            rocmPackages.clr
          ];
        }
        ''
          python -c "
          import os
          import triton
          path = triton.backends.amd.driver._get_path_to_hip_runtime_dylib()
          print(f'libamdhip64 path: {path}')
          assert os.path.exists(path)
          " && touch $out
        '';

    # Test that path_to_rocm_lld works when ROCm is available at runtime
    # Remove this after `[AMD] Use lld library API #7548` makes it into a release
    rocm-lld-path =
      runCommand "triton-rocm-lld-test"
        {
          buildInputs = [
            triton
            python
            rocmPackages.clr
          ];
        }
        ''
          python -c "
          import os
          import triton
          path = triton.backends.backends['amd'].compiler.path_to_rocm_lld()
          print(f'ROCm LLD path: {path}')
          assert os.path.exists(path)
          " && touch $out
        '';

    # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
    # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
    axpy-cuda =
      cudaPackages.writeGpuTestPython
        {
          libraries = ps: [
            ps.triton
            ps.torch-no-triton
          ];
        }
        ''
          # Adapted from Philippe Tillet, https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

          import triton
          import triton.language as tl
          import torch
          import os

          @triton.jit
          def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
              # Each program instance handles one BLOCK_SIZE-wide slice of the vectors
              pid = tl.program_id(axis=0)
              block_start = pid * BLOCK_SIZE
              offsets = block_start + tl.arange(0, BLOCK_SIZE)
              # Mask out-of-range lanes so the last block doesn't read or write past n
              mask = offsets < n
              x = tl.load(x_ptr + offsets, mask=mask)
              y = tl.load(y_ptr + offsets, mask=mask)
              output = a * x + y
              tl.store(out + offsets, output, mask=mask)

          def axpy(a, x, y):
              output = torch.empty_like(x)
              assert x.is_cuda and y.is_cuda and output.is_cuda
              n_elements = output.numel()

              def grid(meta):
                  return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

              axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
              return output

          if __name__ == "__main__":
              if os.environ.get("HOME", None) == "/homeless-shelter":
                  os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
              if "CC" not in os.environ:
                  os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
              torch.manual_seed(0)
              size = 12345
              x = torch.rand(size, device='cuda')
              y = torch.rand(size, device='cuda')
              output_torch = 3.14 * x + y
              output_triton = axpy(3.14, x, y)
              assert output_torch.sub(output_triton).abs().max().item() < 1e-6
              print("Triton axpy: OK")
        '';
  };

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      derdennisop
    ];
  };
}