1{
2 lib,
3 config,
4 addDriverRunpath,
5 buildPythonPackage,
6 fetchFromGitHub,
7 fetchpatch,
8 setuptools,
9 cmake,
10 ninja,
11 pybind11,
12 gtest,
13 zlib,
14 ncurses,
15 libxml2,
16 lit,
17 llvm,
18 filelock,
19 torchWithRocm,
20 python,
21
22 runCommand,
23
24 cudaPackages,
25 cudaSupport ? config.cudaSupport,
26}:
27
28let
29 ptxas = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
30in
31buildPythonPackage rec {
32 pname = "triton";
33 version = "2.1.0";
34 pyproject = true;
35
36 src = fetchFromGitHub {
37 owner = "openai";
38 repo = pname;
39 rev = "v${version}";
40 hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
41 };
42
43 patches =
44 [
45 # fix overflow error
46 (fetchpatch {
47 url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
48 hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
49 })
50
51 # Upstream startded pinning CUDA version and falling back to downloading from Conda
52 # in https://github.com/triton-lang/triton/pull/1574/files#diff-eb8b42d9346d0a5d371facf21a8bfa2d16fb49e213ae7c21f03863accebe0fcfR120-R123
53 ./0000-dont-download-ptxas.patch
54 ]
55 ++ lib.optionals (!cudaSupport) [
56 # triton wants to get ptxas version even if ptxas is not
57 # used, resulting in ptxas not found error.
58 ./0001-ptxas-disable-version-key-for-non-cuda-targets.patch
59 ];
60
61 postPatch =
62 let
63 quote = x: ''"${x}"'';
64 subs.ldFlags =
65 let
66 # Bash was getting weird without linting,
67 # but basically upstream contains [cc, ..., "-lcuda", ...]
68 # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
69 old = [ "-lcuda" ];
70 new = [
71 "-lcuda"
72 "-L${addDriverRunpath.driverLink}"
73 "-L${cudaPackages.cuda_cudart}/lib/stubs/"
74 ];
75 in
76 {
77 oldStr = lib.concatMapStringsSep ", " quote old;
78 newStr = lib.concatMapStringsSep ", " quote new;
79 };
80 in
81 ''
82 # Use our `cmakeFlags` instead and avoid downloading dependencies
83 substituteInPlace python/setup.py \
84 --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"
85
86 # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
87 substituteInPlace bin/CMakeLists.txt \
88 --replace "add_subdirectory(FileCheck)" ""
89
90 # Don't fetch googletest
91 substituteInPlace unittest/CMakeLists.txt \
92 --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
93 --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
94
95 cat << \EOF >> python/triton/common/build.py
96 def libcuda_dirs():
97 return [ "${addDriverRunpath.driverLink}/lib" ]
98 EOF
99 ''
100 + lib.optionalString cudaSupport ''
101 # Use our linker flags
102 substituteInPlace python/triton/common/build.py \
103 --replace '${subs.ldFlags.oldStr}' '${subs.ldFlags.newStr}'
104 '';
105
106 nativeBuildInputs = [
107 setuptools
108 # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
109 cmake
110 ninja
111
112 # Note for future:
113 # These *probably* should go in depsTargetTarget
114 # ...but we cannot test cross right now anyway
115 # because we only support cudaPackages on x86_64-linux atm
116 lit
117 llvm
118 ];
119
120 buildInputs = [
121 gtest
122 libxml2.dev
123 ncurses
124 pybind11
125 zlib
126 ];
127
128 propagatedBuildInputs = [
129 filelock
130 # triton uses setuptools at runtime:
131 # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
132 setuptools
133 ];
134
135 NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
136 # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
137 # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
138 "-Wno-stringop-overread"
139 ];
140
141 # Avoid GLIBCXX mismatch with other cuda-enabled python packages
142 preConfigure =
143 ''
144 # Ensure that the build process uses the requested number of cores
145 export MAX_JOBS="$NIX_BUILD_CORES"
146
147 # Upstream's setup.py tries to write cache somewhere in ~/
148 export HOME=$(mktemp -d)
149
150 # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
151 echo "
152 [build_ext]
153 base-dir=$PWD" >> python/setup.cfg
154
155 # The rest (including buildPhase) is relative to ./python/
156 cd python
157 ''
158 + lib.optionalString cudaSupport ''
159 export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
160 export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;
161
162 # Work around download_and_copy_ptxas()
163 mkdir -p $PWD/triton/third_party/cuda/bin
164 ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
165 '';
166
167 # CMake is run by setup.py instead
168 dontUseCmakeConfigure = true;
169
170 # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
171 postFixup = lib.optionalString cudaSupport ''
172 rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
173 ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
174 '';
175
176 checkInputs = [ cmake ]; # ctest
177 dontUseSetuptoolsCheck = true;
178
179 preCheck = ''
180 # build/temp* refers to build_ext.build_temp (looked up in the build logs)
181 (cd /build/source/python/build/temp* ; ctest)
182
183 # For pytestCheckHook
184 cd test/unit
185 '';
186
187 # Circular dependency on torch
188 # pythonImportsCheck = [
189 # "triton"
190 # "triton.language"
191 # ];
192
193 # Ultimately, torch is our test suite:
194 passthru.tests = {
195 inherit torchWithRocm;
196 # Implemented as alternative to pythonImportsCheck, in case if circular dependency on torch occurs again,
197 # and pythonImportsCheck is commented back.
198 import-triton =
199 runCommand "import-triton"
200 { nativeBuildInputs = [ (python.withPackages (ps: [ ps.triton ])) ]; }
201 ''
202 python << \EOF
203 import triton
204 import triton.language
205 EOF
206 touch "$out"
207 '';
208 };
209
210 pythonRemoveDeps = [
211 # Circular dependency, cf. https://github.com/openai/triton/issues/1374
212 "torch"
213
214 # CLI tools without dist-info
215 "cmake"
216 "lit"
217 ];
218
219 meta = with lib; {
220 description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
221 homepage = "https://github.com/openai/triton";
222 platforms = platforms.linux;
223 license = licenses.mit;
224 maintainers = with maintainers; [
225 SomeoneSerge
226 Madouura
227 ];
228 };
229}