# Triton 2.0.0 (openai/triton), packaged as a setuptools-based Python module.
#
# Overview of what this expression does:
#   * fetches the v2.0.0 tag and applies the LLVM 15 compatibility patches;
#   * rewrites upstream's CMake/setup.py so that LLVM/MLIR come from
#     `llvmPackages` instead of being downloaded at build time;
#   * lets setup.py drive CMake (the cmake configure hook is disabled);
#   * replaces the vendored `ptxas` with a symlink to the one in `cuda_nvcc`.
{ lib
, buildPythonPackage
, python
, fetchpatch
, fetchFromGitHub
, addOpenGLRunpath
, cmake
, cudaPackages
, llvmPackages
, pybind11
, gtest
, zlib
, ncurses
, libxml2
, lit
, filelock
, torchWithRocm
, pytest
, pytestCheckHook
, pythonRelaxDepsHook
, pkgsTargetTarget
}:

let
  pname = "triton";
  version = "2.0.0";

  inherit (cudaPackages) cuda_cudart backendStdenv;

  # A time may come we'll want to be cross-friendly
  #
  # Short explanation: we need pkgsTargetTarget, because we use string
  # interpolation instead of buildInputs.
  #
  # Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
  # ptxas compiler. We're not running this ptxas on the build machine, but on
  # the user's machine, i.e. our Target platform. The second "Target" in
  # pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
  # be executed on the GPU.
  # Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
  ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas";

  # LLVM restricted to the native and NVPTX targets, with the LLVM utilities
  # installed (see the FileCheck note in postPatch).
  llvm = (llvmPackages.llvm.override {
    llvmTargetsToBuild = [ "NATIVE" "NVPTX" ];
    # Upstream CI sets these too:
    # targetProjects = [ "mlir" ];
    extraCMakeFlags = [
      "-DLLVM_INSTALL_UTILS=ON"
    ];
  });
in
buildPythonPackage {
  inherit pname version;

  format = "setuptools";

  src = fetchFromGitHub {
    owner = "openai";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
  };

  patches = [
    # Prerequisite for llvm15 patch
    (fetchpatch {
      url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch";
      hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk=";
    })
    (fetchpatch {
      url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch";
      hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg=";
    })

    # Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch
    # The original patch adds ptxas binary, so we include our own clean copy
    # Drop with the next update
    ./llvm15.patch

    # TODO: there have been commits upstream aimed at removing the "torch"
    # circular dependency, but the patches fail to apply on the release
    # revision. Keeping the link for future reference
    # Also cf. https://github.com/openai/triton/issues/1374

    # (fetchpatch {
    #   url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
    #   hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
    # })
  ];

  # Make setup.py take its third-party package list from $cmakeFlags instead
  # of calling get_thirdparty_packages().
  postPatch = ''
    substituteInPlace python/setup.py \
      --replace \
        '= get_thirdparty_packages(triton_cache_path)' \
        '= os.environ["cmakeFlags"].split()'
  ''
  # Wiring triton=2.0.0 with llvmPackages_rocm.llvm=5.4.3
  # Revisit when updating either triton or llvm
  + ''
    substituteInPlace CMakeLists.txt \
      --replace "nvptx" "NVPTX" \
      --replace "LLVM 11" "LLVM"
    sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt
    sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt
    find -iname '*.td' -exec \
      sed -i \
        -e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \
        -e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \
        '{}' ';'
    substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
    sed -i 's/^include.*$//' unittest/CMakeLists.txt
    sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt
    sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt
  ''
  # TritonMLIRIR already links MLIRIR. Not transitive?
  # + ''
  #   echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt
  # ''
  # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
  + ''
    substituteInPlace bin/CMakeLists.txt \
      --replace "add_subdirectory(FileCheck)" ""

    rm cmake/FindLLVM.cmake
  ''
  +
  (
    let
      # Bash was getting weird without linting,
      # but basically upstream contains [cc, ..., "-lcuda", ...]
      # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
      old = [ "-lcuda" ];
      new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ];

      # Render a list of flags as comma-separated, double-quoted items, i.e.
      # the form they have inside the Python list literal being patched.
      quote = x: ''"${x}"'';
      oldStr = lib.concatMapStringsSep ", " quote old;
      newStr = lib.concatMapStringsSep ", " quote new;
    in
    ''
      substituteInPlace python/triton/compiler.py \
        --replace '${oldStr}' '${newStr}'
    ''
  )
  # Triton seems to be looking up cuda.h
  + ''
    sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py
  '';

  nativeBuildInputs = [
    cmake
    pythonRelaxDepsHook

    # Requires torch (circular dependency) and probably needs GPUs:
    # pytestCheckHook

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
    llvmPackages.mlir
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  propagatedBuildInputs = [
    filelock
  ];

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  preConfigure = ''
    export CC="${backendStdenv.cc}/bin/cc";
    export CXX="${backendStdenv.cc}/bin/c++";

    # Upstream's setup.py tries to write cache somewhere in ~/
    export HOME=$TMPDIR

    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
    echo "
    [build_ext]
    base-dir=$PWD" >> python/setup.cfg

    # The rest (including buildPhase) is relative to ./python/
    cd python/

    # Work around download_and_copy_ptxas()
    dst_cuda="$PWD/triton/third_party/cuda/bin"
    mkdir -p "$dst_cuda"
    ln -s "${ptxas}" "$dst_cuda/"
  '';

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;
  # Not consumed by a cmake configure hook here: postPatch makes setup.py read
  # these from the `cmakeFlags` environment variable.
  cmakeFlags = [
    "-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir"
  ];

  postFixup =
    let
      ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas";
    in
    # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
    ''
      rm -f "${ptxasDestination}"
      ln -s "${ptxas}" "${ptxasDestination}"
    '';

  checkInputs = [
    cmake # ctest
  ];
  dontUseSetuptoolsCheck = true;
  preCheck =
    # build/temp* refers to build_ext.build_temp (looked up in the build logs).
    # $NIX_BUILD_TOP/$sourceRoot is used rather than a literal /build/source so
    # the path does not depend on where the build directory happens to live.
    ''
      (cd "$NIX_BUILD_TOP/$sourceRoot"/python/build/temp* && ctest)
    '' # For pytestCheckHook
    + ''
      cd test/unit
    '';
  pythonImportsCheck = [
    # Circular dependency on torch
    # "triton"
    # "triton.language"
  ];

  # Ultimately, torch is our test suite:
  passthru.tests = {
    inherit torchWithRocm;
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/openai/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  meta = with lib; {
    description = "Development repository for the Triton language and compiler";
    homepage = "https://github.com/openai/triton/";
    platforms = lib.platforms.unix;
    license = licenses.mit;
    maintainers = with maintainers; [ SomeoneSerge ];
  };
}