1{
2 lib,
3 addDriverRunpath,
4 buildPythonPackage,
5 cmake,
6 config,
7 cudaPackages,
8 fetchFromGitHub,
9 filelock,
10 gtest,
11 libxml2,
12 lit,
13 llvm,
14 ncurses,
15 ninja,
16 pybind11,
17 python,
18 pytestCheckHook,
19 stdenv,
20 replaceVars,
21 setuptools,
22 torchWithRocm,
23 zlib,
24 cudaSupport ? config.cudaSupport,
25 rocmSupport ? config.rocmSupport,
26 rocmPackages,
27 triton,
28}:
29
30buildPythonPackage {
# Package identity.
pname = "triton";
version = "3.1.0";
pyproject = true;

# Upstream sources, pinned to one exact commit instead of a tag.
src = fetchFromGitHub {
  owner = "triton-lang";
  repo = "triton";
  # latest branch commit from https://github.com/triton-lang/triton/commits/release/3.1.x/
  rev = "cf34004b8a67d290a962da166f5aa2fc66751326";
  hash = "sha256-233fpuR7XXOaSKN+slhJbE/CMFzAqCRCE4V4rIoJZrk=";
};
42
# Patches come in two groups: those applied to every build, and those that
# only make sense when CUDA support is enabled.
patches =
  let
    # Patches applied regardless of the selected GPU backend. Patch files
    # containing placeholders are instantiated through `replaceVars`.
    commonPatches = [
      ./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
      (replaceVars ./0001-_build-allow-extra-cc-flags.patch {
        ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
      })
      (replaceVars ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch {
        # Each directory is filled in only for the backend that is enabled;
        # `null` is passed for the other one.
        libhipDir = if rocmSupport then "${lib.getLib rocmPackages.clr}/lib" else null;
        libcudaStubsDir =
          if cudaSupport then "${lib.getOutput "stubs" cudaPackages.cuda_cudart}/lib/stubs" else null;
      })
    ];

    # Patches that hard-wire paths into the CUDA toolkit.
    cudaPatches = [
      (replaceVars ./0003-nvidia-cudart-a-systempath.patch {
        cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
      })
      (replaceVars ./0004-nvidia-allow-static-ptxas-path.patch {
        nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
      })
    ];
  in
  commonPatches ++ lib.optionals cudaSupport cudaPatches;
63
# Textual edits to the sources, in three steps:
#   1. python/setup.py: drop the package-info entries used for downloads,
#      the "triton/profiler" package and the version comparison;
#   2. unittest/CMakeLists.txt: use an existing GTest instead of fetching it;
#   3. NVIDIA backend compiler.py: clamp the minor number used to compute
#      the PTX version.
# Every substitution uses `--replace-fail`.
postPatch = ''
  # Use our `cmakeFlags` instead and avoid downloading dependencies
  # remove any downloads
  substituteInPlace python/setup.py \
    --replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
    --replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
    --replace-fail 'packages += ["triton/profiler"]' ""\
    --replace-fail "curr_version != version" "False"

  # Don't fetch googletest
  substituteInPlace unittest/CMakeLists.txt \
    --replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
    --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"

  # Patch the source code to make sure it doesn't specify a non-existent PTXAS version.
  # CUDA 12.6 (the current default/max) tops out at PTXAS version 8.5.
  # NOTE: This is fixed in `master`:
  # https://github.com/triton-lang/triton/commit/f48dbc1b106c93144c198fbf3c4f30b2aab9d242
  substituteInPlace "$NIX_BUILD_TOP/$sourceRoot/third_party/nvidia/backend/compiler.py" \
    --replace-fail \
      'return 80 + minor' \
      'return 80 + min(minor, 5)'
'';
87
# Python build backend.
build-system = [ setuptools ];

# Tools executed during the build.
nativeBuildInputs = [
  cmake
  ninja

  # Note for future:
  # These *probably* should go in depsTargetTarget
  # ...but we cannot test cross right now anyway
  # because we only support cudaPackages on x86_64-linux atm
  lit
  llvm
];

# Libraries and headers the native extension is built against.
buildInputs = [
  gtest
  libxml2.dev
  ncurses
  pybind11
  zlib
];

# Python packages needed when triton is used, not just when it is built.
dependencies = [
  filelock
  # triton uses setuptools at runtime:
  # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
  setuptools
];
116
# Extra compiler flags, added only for CUDA builds.
# Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
# https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
NIX_CFLAGS_COMPILE = lib.optional cudaSupport "-Wno-stringop-overread";
122
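# Prepare the environment for upstream's setup.py: job count, a writable HOME,
# the build_ext base directory, and the working directory.
# NOTE(review): the GLIBCXX mismatch with other cuda-enabled python packages is
# presumably avoided by the CC/CXX overrides in `env`, not by this hook; confirm.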
# Exports MAX_JOBS and a throwaway HOME, appends a [build_ext] section to
# python/setup.cfg, and finally changes into ./python.
preConfigure = ''
  # Ensure that the build process uses the requested number of cores
  export MAX_JOBS="$NIX_BUILD_CORES"

  # Upstream's setup.py tries to write cache somewhere in ~/
  export HOME=$(mktemp -d)

  # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
  echo "
  [build_ext]
  base-dir=$PWD" >> python/setup.cfg

  # The rest (including buildPhase) is relative to ./python/
  cd python
'';
139
# Environment variables for the build. The two TRITON_* switches are always
# set; the compiler and CUDA tool locations are added only with cudaSupport.
env =
  let
    inherit (cudaPackages)
      backendStdenv
      cuda_cudart
      cuda_cuobjdump
      cuda_cupti
      cuda_nvcc
      cuda_nvdisasm
      ;

    # Compiler selection and CUDA toolkit component paths.
    cudaEnv = {
      CC = lib.getExe' backendStdenv.cc "cc";
      CXX = lib.getExe' backendStdenv.cc "c++";

      # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
      # Make sure cudaPackages is the right version each update (See python/setup.py)
      TRITON_PTXAS_PATH = lib.getExe' cuda_nvcc "ptxas";
      TRITON_CUOBJDUMP_PATH = lib.getExe' cuda_cuobjdump "cuobjdump";
      TRITON_NVDISASM_PATH = lib.getExe' cuda_nvdisasm "nvdisasm";
      TRITON_CUDACRT_PATH = lib.getInclude cuda_nvcc;
      TRITON_CUDART_PATH = lib.getInclude cuda_cudart;
      TRITON_CUPTI_PATH = cuda_cupti;
    };
  in
  {
    TRITON_BUILD_PROTON = "OFF";
    TRITON_OFFLINE_BUILD = true;
  }
  // lib.optionalAttrs cudaSupport cudaEnv;
157
# Requirements removed from the declared Python dependencies.
pythonRemoveDeps = [
  # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
  "torch"

  # CLI tools without dist-info
  "cmake"
  "lit"
];
166
# CMake is run by setup.py instead
dontUseCmakeConfigure = true;

# `ctest` is run from the build tree before the remaining checks.
nativeCheckInputs = [ cmake ];
preCheck = ''
  # build/temp* refers to build_ext.build_temp (looked up in the build logs)
  (cd ./build/temp* ; ctest)
'';

# Modules that must be importable from the built package.
pythonImportsCheck = [
  "triton"
  "triton.language"
];
180
# Runs pytest in python/test/unit of the triton sources, using a Python
# environment with scipy, torchWithCuda and triton-cuda. The derivation
# declares the "cuda" system feature and produces only an empty $out.
passthru.gpuCheck =
  let
    testPython = python.withPackages (ps: [
      ps.scipy
      ps.torchWithCuda
      ps.triton-cuda
    ]);
  in
  stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [ testPython ];
    nativeCheckInputs = [ pytestCheckHook ];

    # Nothing is compiled here; only the check phase does real work.
    dontBuild = true;
    doCheck = true;

    preCheck = ''
      cd python/test/unit
      export HOME=$TMPDIR
    '';
    checkPhase = "pytestCheckPhase";

    installPhase = "touch $out";
  };
208
passthru.tests =
  let
    # Python program for the axpy test: computes `a * x + y` on CUDA tensors
    # with a Triton kernel and compares the result against torch.
    axpyScript = ''
      # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

      import triton
      import triton.language as tl
      import torch
      import os

      @triton.jit
      def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = a * x + y
        tl.store(out + offsets, output, mask=mask)

      def axpy(a, x, y):
        output = torch.empty_like(x)
        assert x.is_cuda and y.is_cuda and output.is_cuda
        n_elements = output.numel()

        def grid(meta):
          return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

        axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
        return output

      if __name__ == "__main__":
        if os.environ.get("HOME", None) == "/homeless-shelter":
          os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
        if "CC" not in os.environ:
          os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
        torch.manual_seed(0)
        size = 12345
        x = torch.rand(size, device='cuda')
        y = torch.rand(size, device='cuda')
        output_torch = 3.14 * x + y
        output_triton = axpy(3.14, x, y)
        assert output_torch.sub(output_triton).abs().max().item() < 1e-6
        print("Triton axpy: OK")
    '';
  in
  {
    # Ultimately, torch is our test suite:
    inherit torchWithRocm;

    # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
    # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
    axpy-cuda = cudaPackages.writeGpuTestPython {
      libraries = ps: [
        ps.triton
        ps.torch-no-triton
      ];
    } axpyScript;
  };
268
# Package metadata.
meta = {
  description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
  homepage = "https://github.com/triton-lang/triton";
  license = lib.licenses.mit;
  platforms = lib.platforms.linux;
  maintainers = [
    lib.maintainers.SomeoneSerge
    lib.maintainers.Madouura
    lib.maintainers.derdennisop
  ];
};
280}