1{
2 lib,
3 addDriverRunpath,
4 buildPythonPackage,
5 cmake,
6 config,
7 cudaPackages,
8 fetchFromGitHub,
9 filelock,
10 gtest,
11 libxml2,
12 lit,
13 llvm,
14 ncurses,
15 ninja,
16 pybind11,
17 python,
18 pytestCheckHook,
19 stdenv,
20 substituteAll,
21 setuptools,
22 torchWithRocm,
23 zlib,
24 cudaSupport ? config.cudaSupport,
25 rocmSupport ? config.rocmSupport,
26 rocmPackages,
27 triton,
28}:
29
30buildPythonPackage {
31 pname = "triton";
32 version = "3.1.0";
33 pyproject = true;
34
35 src = fetchFromGitHub {
36 owner = "triton-lang";
37 repo = "triton";
38 # latest branch commit from https://github.com/triton-lang/triton/commits/release/3.1.x/
39 rev = "cf34004b8a67d290a962da166f5aa2fc66751326";
40 hash = "sha256-233fpuR7XXOaSKN+slhJbE/CMFzAqCRCE4V4rIoJZrk=";
41 };
42
43 patches =
44 [
45 ./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
46 (substituteAll {
47 src = ./0001-_build-allow-extra-cc-flags.patch;
48 ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
49 })
50 (substituteAll (
51 {
52 src = ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch;
53 }
54 // lib.optionalAttrs rocmSupport { libhipDir = "${lib.getLib rocmPackages.clr}/lib"; }
55 // lib.optionalAttrs cudaSupport {
56 libcudaStubsDir = "${lib.getLib cudaPackages.cuda_cudart}/lib/stubs";
57 ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
58 }
59 ))
60 ]
61 ++ lib.optionals cudaSupport [
62 (substituteAll {
63 src = ./0003-nvidia-cudart-a-systempath.patch;
64 cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
65 })
66 (substituteAll {
67 src = ./0004-nvidia-allow-static-ptxas-path.patch;
68 nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
69 })
70 ];
71
72 postPatch = ''
73 # Use our `cmakeFlags` instead and avoid downloading dependencies
74 # remove any downloads
75 substituteInPlace python/setup.py \
76 --replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
77 --replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
78 --replace-fail 'packages += ["triton/profiler"]' ""\
79 --replace-fail "curr_version != version" "False"
80
81 # Don't fetch googletest
82 substituteInPlace unittest/CMakeLists.txt \
83 --replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
84 --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
85 '';
86
87 build-system = [ setuptools ];
88
89 nativeBuildInputs = [
90 cmake
91 ninja
92
93 # Note for future:
94 # These *probably* should go in depsTargetTarget
95 # ...but we cannot test cross right now anyway
96 # because we only support cudaPackages on x86_64-linux atm
97 lit
98 llvm
99 ];
100
101 buildInputs = [
102 gtest
103 libxml2.dev
104 ncurses
105 pybind11
106 zlib
107 ];
108
109 dependencies = [
110 filelock
111 # triton uses setuptools at runtime:
112 # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
113 setuptools
114 ];
115
116 NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
117 # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
118 # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
119 "-Wno-stringop-overread"
120 ];
121
122 # Avoid GLIBCXX mismatch with other cuda-enabled python packages
123 preConfigure = ''
124 # Ensure that the build process uses the requested number of cores
125 export MAX_JOBS="$NIX_BUILD_CORES"
126
127 # Upstream's setup.py tries to write cache somewhere in ~/
128 export HOME=$(mktemp -d)
129
130 # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
131 echo "
132 [build_ext]
133 base-dir=$PWD" >> python/setup.cfg
134
135 # The rest (including buildPhase) is relative to ./python/
136 cd python
137 '';
138
139 env =
140 {
141 TRITON_BUILD_PROTON = "OFF";
142 TRITON_OFFLINE_BUILD = true;
143 }
144 // lib.optionalAttrs cudaSupport {
145 CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
146 CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";
147
148 # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
149 TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
150 TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
151 TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
152 TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
153 TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
154 TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
155 };
156
157 pythonRemoveDeps = [
158 # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
159 "torch"
160
161 # CLI tools without dist-info
162 "cmake"
163 "lit"
164 ];
165
166 # CMake is run by setup.py instead
167 dontUseCmakeConfigure = true;
168
169 nativeCheckInputs = [ cmake ];
170 preCheck = ''
171 # build/temp* refers to build_ext.build_temp (looked up in the build logs)
172 (cd ./build/temp* ; ctest)
173 '';
174
175 pythonImportsCheck = [
176 "triton"
177 "triton.language"
178 ];
179
180 passthru.gpuCheck = stdenv.mkDerivation {
181 pname = "triton-pytest";
182 inherit (triton) version src;
183
184 requiredSystemFeatures = [ "cuda" ];
185
186 nativeBuildInputs = [
187 (python.withPackages (ps: [
188 ps.scipy
189 ps.torchWithCuda
190 ps.triton-cuda
191 ]))
192 ];
193
194 dontBuild = true;
195 nativeCheckInputs = [ pytestCheckHook ];
196
197 doCheck = true;
198
199 preCheck = ''
200 cd python/test/unit
201 export HOME=$TMPDIR
202 '';
203 checkPhase = "pytestCheckPhase";
204
205 installPhase = "touch $out";
206 };
207
208 passthru.tests = {
209 # Ultimately, torch is our test suite:
210 inherit torchWithRocm;
211
212 # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
213 # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
214 axpy-cuda =
215 cudaPackages.writeGpuTestPython
216 {
217 libraries = ps: [
218 ps.triton
219 ps.torch-no-triton
220 ];
221 }
222 ''
223 # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html
224
225 import triton
226 import triton.language as tl
227 import torch
228 import os
229
230 @triton.jit
231 def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
232 pid = tl.program_id(axis=0)
233 block_start = pid * BLOCK_SIZE
234 offsets = block_start + tl.arange(0, BLOCK_SIZE)
235 mask = offsets < n
236 x = tl.load(x_ptr + offsets, mask=mask)
237 y = tl.load(y_ptr + offsets, mask=mask)
238 output = a * x + y
239 tl.store(out + offsets, output, mask=mask)
240
241 def axpy(a, x, y):
242 output = torch.empty_like(x)
243 assert x.is_cuda and y.is_cuda and output.is_cuda
244 n_elements = output.numel()
245
246 def grid(meta):
247 return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
248
249 axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
250 return output
251
252 if __name__ == "__main__":
253 if os.environ.get("HOME", None) == "/homeless-shelter":
254 os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
255 if "CC" not in os.environ:
256 os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
257 torch.manual_seed(0)
258 size = 12345
259 x = torch.rand(size, device='cuda')
260 y = torch.rand(size, device='cuda')
261 output_torch = 3.14 * x + y
262 output_triton = axpy(3.14, x, y)
263 assert output_torch.sub(output_triton).abs().max().item() < 1e-6
264 print("Triton axpy: OK")
265 '';
266 };
267
268 meta = {
269 description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
270 homepage = "https://github.com/triton-lang/triton";
271 platforms = lib.platforms.linux;
272 license = lib.licenses.mit;
273 maintainers = with lib.maintainers; [
274 SomeoneSerge
275 Madouura
276 derdennisop
277 ];
278 };
279}