# Triton: language and compiler for writing custom deep-learning GPU kernels.
# Built from source against the nixpkgs-pinned LLVM (`triton-llvm`), with
# optional CUDA / ROCm driver paths wired in through the patches below.
{
  lib,
  addDriverRunpath,
  buildPythonPackage,
  cmake,
  config,
  cudaPackages,
  fetchFromGitHub,
  filelock,
  gtest,
  libxml2,
  lit,
  llvm,
  ncurses,
  ninja,
  pybind11,
  python,
  pytestCheckHook,
  writableTmpDirAsHomeHook,
  stdenv,
  replaceVars,
  setuptools,
  torchWithRocm,
  zlib,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  triton,
}:

buildPythonPackage rec {
  pname = "triton";
  version = "3.4.0";
  pyproject = true;

  # Remember to bump triton-llvm as well!
  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    tag = "v${version}";
    hash = "sha256-78s9ke6UV7Tnx3yCr0QZcVDqQELR4XoGgJY7olNJmjk=";
  };

  patches = [
    # Pass extra flags to the compiler command Triton uses at JIT time, so
    # generated binaries get an rpath to the GPU driver libraries.
    (replaceVars ./0001-_build-allow-extra-cc-flags.patch {
      ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
    })
    # Resolve libhip / CUDA stub libraries from fixed store paths instead of
    # probing the host with ldconfig; `null` disables the respective branch.
    (replaceVars ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch {
      libhipDir = if rocmSupport then "${lib.getLib rocmPackages.clr}/lib" else null;
      libcudaStubsDir =
        if cudaSupport then "${lib.getOutput "stubs" cudaPackages.cuda_cudart}/lib/stubs" else null;
    })
  ]
  ++ lib.optionals cudaSupport [
    # Make the CUDA runtime headers visible as a system include path.
    (replaceVars ./0003-nvidia-cudart-a-systempath.patch {
      cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
    })
    # Let Triton find ptxas at a fixed store path rather than via PATH lookup.
    (replaceVars ./0004-nvidia-allow-static-ptxas-path.patch {
      nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
    })
  ];

  postPatch =
    # Avoid network access: strip setup.py's vendored-dependency downloads
    # (JSON package info, LLVM tarball, proton profiler data) and disable
    # its version consistency check against the downloaded metadata.
    ''
      substituteInPlace setup.py \
        --replace-fail "[get_json_package_info()]" "[]" \
        --replace-fail "[get_llvm_package_info()]" "[]" \
        --replace-fail 'yield ("triton.profiler", "third_party/proton/proton")' 'pass' \
        --replace-fail "curr_version.group(1) != version" "False"
    ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    + ''
      substituteInPlace setup.py \
        --replace-fail \
          "cmake_args.extend(thirdparty_cmake_args)" \
          "cmake_args.extend(thirdparty_cmake_args + os.environ.get('cmakeFlags', \"\").split())"
    ''
    # Don't fetch googletest; use the gtest from buildInputs via find_package.
    # (Note the `\''${...}` below escapes Nix interpolation: the patched CMake
    # text still contains a literal ${PROJECT_SOURCE_DIR}.)
    + ''
      substituteInPlace cmake/AddTritonUnitTest.cmake \
        --replace-fail "include(\''${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)" ""\
        --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
    '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm

    # Upstream's setup.py tries to write cache somewhere in ~/
    writableTmpDirAsHomeHook
  ];

  # Consumed by the patched setup.py (see postPatch), which forwards
  # $cmakeFlags into its own cmake invocation.
  cmakeFlags = [
    (lib.cmakeFeature "LLVM_SYSPATH" "${llvm}")
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  preConfigure =
    # Ensure that the build process uses the requested number of cores
    ''
      export MAX_JOBS="$NIX_BUILD_CORES"
    '';

  env = {
    # Proton profiler is stripped in postPatch; keep its build off too.
    TRITON_BUILD_PROTON = "OFF";
    # Coerced to "1" by the derivation; tells setup.py not to download anything.
    TRITON_OFFLINE_BUILD = true;
  }
  // lib.optionalAttrs cudaSupport {
    CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
    CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

    # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
    TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
    TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
    TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
    TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
    TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
    TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  # ctest needs cmake on PATH to run the C++ unit tests built by setup.py.
  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
  '';

  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

  # GPU smoke test; only runs on builders exposing the "cuda" system feature.
  # Installs nothing — `$out` is just a marker that the pytest suite passed.
  passthru.gpuCheck = stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [
      (python.withPackages (ps: [
        ps.scipy
        ps.torchWithCuda
        ps.triton-cuda
      ]))
    ];

    dontBuild = true;
    nativeCheckInputs = [
      pytestCheckHook
      writableTmpDirAsHomeHook
    ];

    doCheck = true;

    preCheck = ''
      cd python/test/unit
    '';
    checkPhase = "pytestCheckPhase";

    installPhase = "touch $out";
  };

  passthru.tests = {
    # Ultimately, torch is our test suite:
    inherit torchWithRocm;

    # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
    # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
    axpy-cuda =
      cudaPackages.writeGpuTestPython
        {
          libraries = ps: [
            ps.triton
            ps.torch-no-triton
          ];
        }
        ''
          # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

          import triton
          import triton.language as tl
          import torch
          import os

          @triton.jit
          def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
              pid = tl.program_id(axis=0)
              block_start = pid * BLOCK_SIZE
              offsets = block_start + tl.arange(0, BLOCK_SIZE)
              mask = offsets < n
              x = tl.load(x_ptr + offsets, mask=mask)
              y = tl.load(y_ptr + offsets, mask=mask)
              output = a * x + y
              tl.store(out + offsets, output, mask=mask)

          def axpy(a, x, y):
              output = torch.empty_like(x)
              assert x.is_cuda and y.is_cuda and output.is_cuda
              n_elements = output.numel()

              def grid(meta):
                  return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

              axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
              return output

          if __name__ == "__main__":
              # Triton's JIT cache needs a writable $HOME; the sandbox sets
              # HOME=/homeless-shelter, so redirect it to TMPDIR.
              if os.environ.get("HOME", None) == "/homeless-shelter":
                  os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
              if "CC" not in os.environ:
                  os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
              torch.manual_seed(0)
              size = 12345
              x = torch.rand(size, device='cuda')
              y = torch.rand(size, device='cuda')
              output_torch = 3.14 * x + y
              output_triton = axpy(3.14, x, y)
              assert output_torch.sub(output_triton).abs().max().item() < 1e-6
              print("Triton axpy: OK")
        '';
  };

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      Madouura
      derdennisop
    ];
  };
}