nixpkgs mirror (for testing)
Repository: github.com/NixOS/nixpkgs
Language: nix
# Triton: a language and compiler for custom deep-learning kernels
# (https://github.com/triton-lang/triton). Built from source against the
# nixpkgs-pinned LLVM (`triton-llvm`); CUDA toolchain paths are wired in
# only when `config.cudaSupport` is enabled.
{
  lib,
  stdenv,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # patches
  replaceVars,
  addDriverRunpath,
  cudaPackages,

  # build-system
  setuptools,

  # nativeBuildInputs
  cmake,
  ninja,
  lit,
  llvm,
  writableTmpDirAsHomeHook,

  # buildInputs
  gtest,
  libxml2,
  ncurses,
  pybind11,
  zlib,

  # dependencies
  filelock,

  # passthru
  python,
  pytestCheckHook,
  torchWithRocm,
  runCommand,
  triton,
  rocmPackages,

  # Whether to build with CUDA support; defaults to the global nixpkgs setting.
  cudaSupport ? config.cudaSupport,
}:

buildPythonPackage rec {
  pname = "triton";
  version = "3.5.1";
  pyproject = true;

  # Remember to bump triton-llvm as well!
  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    tag = "v${version}";
    hash = "sha256-dyNRtS1qtU8C/iAf0Udt/1VgtKGSvng1+r2BtvT9RB4=";
  };

  # `replaceVars` substitutes concrete Nix store paths into the patch files,
  # so the patched sources point at our driver/CUDA locations instead of
  # probing the system at runtime.
  patches = [
    (replaceVars ./0001-_build-allow-extra-cc-flags.patch {
      ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
    })
    (replaceVars ./0002-nvidia-driver-short-circuit-before-ldconfig.patch {
      # null (no stubs dir) when CUDA support is disabled
      libcudaStubsDir =
        if cudaSupport then "${lib.getOutput "stubs" cudaPackages.cuda_cudart}/lib/stubs" else null;
    })
  ]
  ++ lib.optionals cudaSupport [
    (replaceVars ./0003-nvidia-cudart-a-systempath.patch {
      cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
    })
    (replaceVars ./0004-nvidia-allow-static-ptxas-path.patch {
      nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
    })
  ];

  postPatch =
    # Allow CMake 4
    # Upstream issue: https://github.com/triton-lang/triton/issues/8245
    ''
      substituteInPlace pyproject.toml \
        --replace-fail "cmake>=3.20,<4.0" "cmake>=3.20"
    ''
    # Avoid network access during the build: strip setup.py's download steps
    # and its version check against remote package metadata.
    + ''
      substituteInPlace setup.py \
        --replace-fail "[get_json_package_info()]" "[]" \
        --replace-fail "[get_llvm_package_info()]" "[]" \
        --replace-fail 'yield ("triton.profiler", "third_party/proton/proton")' 'pass' \
        --replace-fail "curr_version.group(1) != version" "False"
    ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    + ''
      substituteInPlace setup.py \
        --replace-fail \
          "cmake_args.extend(thirdparty_cmake_args)" \
          "cmake_args.extend(thirdparty_cmake_args + os.environ.get('cmakeFlags', \"\").split())"
    ''
    # Don't fetch googletest; use the gtest from buildInputs instead.
    # (`\''$` keeps the literal `${PROJECT_SOURCE_DIR}` for CMake.)
    + ''
      substituteInPlace cmake/AddTritonUnitTest.cmake \
        --replace-fail "include(\''${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)" ""\
        --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
    '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm

    # Upstream's setup.py tries to write cache somewhere in ~/
    writableTmpDirAsHomeHook
  ];

  # Passed through to setup.py's CMake invocation via the substitution in
  # postPatch above (cmake is driven by setup.py, see dontUseCmakeConfigure).
  cmakeFlags = [
    (lib.cmakeFeature "LLVM_SYSPATH" "${llvm}")

    # `find_package` is called with `NO_DEFAULT_PATH`
    # https://cmake.org/cmake/help/latest/command/find_package.html
    # https://github.com/triton-lang/triton/blob/c3c476f357f1e9768ea4e45aa5c17528449ab9ef/third_party/amd/CMakeLists.txt#L6
    (lib.cmakeFeature "LLD_DIR" "${lib.getLib llvm}")
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  preConfigure =
    # Ensure that the build process uses the requested number of cores
    ''
      export MAX_JOBS="$NIX_BUILD_CORES"
    '';

  env = {
    # Proton profiler is not built; presumably it would pull in extra
    # dependencies — NOTE(review): confirm before enabling.
    TRITON_BUILD_PROTON = "OFF";
    TRITON_OFFLINE_BUILD = true;
  }
  // lib.optionalAttrs cudaSupport {
    # Compile with the CUDA-compatible stdenv's toolchain.
    CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
    CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

    # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
    TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
    TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
    TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
    TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
    TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
    TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
  '';

  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

  passthru = {
    # Runs upstream's Python unit tests; requires a machine exposing the
    # "cuda" system feature (real GPU), so it is not part of the normal build.
    gpuCheck = stdenv.mkDerivation {
      pname = "triton-pytest";
      inherit (triton) version src;

      requiredSystemFeatures = [ "cuda" ];

      nativeBuildInputs = [
        (python.withPackages (ps: [
          ps.scipy
          ps.torchWithCuda
          ps.triton-cuda
        ]))
      ];

      dontBuild = true;
      nativeCheckInputs = [
        pytestCheckHook
        writableTmpDirAsHomeHook
      ];

      doCheck = true;

      preCheck = ''
        cd python/test/unit
      '';
      checkPhase = "pytestCheckPhase";

      # Only the test run matters; produce an empty output on success.
      installPhase = "touch $out";
    };

    tests = {
      # Ultimately, torch is our test suite:
      inherit torchWithRocm;

      # Test that _get_path_to_hip_runtime_dylib works when ROCm is available at runtime
      rocm-libamdhip64-path =
        runCommand "triton-rocm-libamdhip64-path-test"
          {
            buildInputs = [
              triton
              python
              rocmPackages.clr
            ];
          }
          ''
            python -c "
            import os
            import triton
            path = triton.backends.amd.driver._get_path_to_hip_runtime_dylib()
            print(f'libamdhip64 path: {path}')
            assert os.path.exists(path)
            " && touch $out
          '';

      # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
      # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
      axpy-cuda =
        cudaPackages.writeGpuTestPython
          {
            libraries = ps: [
              ps.triton
              ps.torch-no-triton
            ];

            gpuCheckArgs.nativeBuildInputs = [
              # PermissionError: [Errno 13] Permission denied: '/homeless-shelter'
              writableTmpDirAsHomeHook
            ];
          }
          ''
            # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

            import triton
            import triton.language as tl
            import torch
            import os

            @triton.jit
            def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n
                x = tl.load(x_ptr + offsets, mask=mask)
                y = tl.load(y_ptr + offsets, mask=mask)
                output = a * x + y
                tl.store(out + offsets, output, mask=mask)

            def axpy(a, x, y):
                output = torch.empty_like(x)
                assert x.is_cuda and y.is_cuda and output.is_cuda
                n_elements = output.numel()

                def grid(meta):
                    return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

                axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
                return output

            if __name__ == "__main__":
                if os.environ.get("HOME", None) == "/homeless-shelter":
                    os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
                if "CC" not in os.environ:
                    os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
                torch.manual_seed(0)
                size = 12345
                x = torch.rand(size, device='cuda')
                y = torch.rand(size, device='cuda')
                output_torch = 3.14 * x + y
                output_triton = axpy(3.14, x, y)
                assert output_torch.sub(output_triton).abs().max().item() < 1e-6
                print("Triton axpy: OK")
          '';
    };
  };

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      derdennisop
    ];
  };
}