{
  lib,
  config,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  addOpenGLRunpath,
  setuptools,
  pytestCheckHook,
  pythonRelaxDepsHook,
  cmake,
  ninja,
  pybind11,
  gtest,
  zlib,
  ncurses,
  libxml2,
  lit,
  llvm,
  filelock,
  torchWithRocm,
  python,

  runCommand,

  cudaPackages,
  cudaSupport ? config.cudaSupport,
}:

let
  # Make sure cudaPackages is the right version each update (see python/setup.py)
  ptxas = "${cudaPackages.cuda_nvcc}/bin/ptxas";
in
buildPythonPackage rec {
  pname = "triton";
  version = "2.1.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "openai";
    # Literal repo name (not `repo = pname`) so overriding pname does not
    # silently change which repository is fetched.
    repo = "triton";
    rev = "v${version}";
    hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
  };

  patches =
    [
      # fix overflow error
      (fetchpatch {
        url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
        hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
      })
    ]
    ++ lib.optionals (!cudaSupport) [
      ./0000-dont-download-ptxas.patch
      # openai-triton wants to get ptxas version even if ptxas is not
      # used, resulting in ptxas not found error.
      ./0001-ptxas-disable-version-key-for-non-cuda-targets.patch
    ];

  nativeBuildInputs = [
    setuptools
    pythonRelaxDepsHook
    # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  propagatedBuildInputs = [
    filelock
    # openai-triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  postPatch =
    let
      # Upstream's build.py contains [cc, ..., "-lcuda", ...];
      # we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
      # so the linker also finds libcuda via the driver runpath and the CUDA stubs.
      old = [ "-lcuda" ];
      new = [
        "-lcuda"
        "-L${addOpenGLRunpath.driverLink}"
        "-L${cudaPackages.cuda_cudart}/lib/stubs/"
      ];

      quote = x: ''"${x}"'';
      oldStr = lib.concatMapStringsSep ", " quote old;
      newStr = lib.concatMapStringsSep ", " quote new;
    in
    # NOTE: --replace-fail (rather than the deprecated --replace) so the build
    # aborts loudly if upstream changes these files and a substitution no
    # longer matches, instead of silently doing nothing.
    ''
      # Use our `cmakeFlags` instead and avoid downloading dependencies
      substituteInPlace python/setup.py \
        --replace-fail "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"

      # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
      substituteInPlace bin/CMakeLists.txt \
        --replace-fail "add_subdirectory(FileCheck)" ""

      # Don't fetch googletest
      substituteInPlace unittest/CMakeLists.txt \
        --replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" "" \
        --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
    ''
    + lib.optionalString cudaSupport ''
      # Use our linker flags
      substituteInPlace python/triton/common/build.py \
        --replace-fail '${oldStr}' '${newStr}'
    '';

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  preConfigure =
    ''
      # Ensure that the build process uses the requested number of cores
      export MAX_JOBS="$NIX_BUILD_CORES"

      # Upstream's setup.py tries to write cache somewhere in ~/
      export HOME=$(mktemp -d)

      # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
      echo "
      [build_ext]
      base-dir=$PWD" >> python/setup.cfg

      # The rest (including buildPhase) is relative to ./python/
      cd python
    ''
    + lib.optionalString cudaSupport ''
      export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
      export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;

      # Work around download_and_copy_ptxas()
      mkdir -p $PWD/triton/third_party/cuda/bin
      ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
    '';

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
  postFixup = lib.optionalString cudaSupport ''
    rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
    ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
  '';

  checkInputs = [ cmake ]; # ctest
  dontUseSetuptoolsCheck = true;

  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs).
    # Use $NIX_BUILD_TOP rather than hard-coding /build: the sandbox build
    # directory is not guaranteed to be /build.
    (cd "$NIX_BUILD_TOP"/source/python/build/temp* ; ctest)

    # For pytestCheckHook
    cd test/unit
  '';

  # Circular dependency on torch
  # pythonImportsCheck = [
  #   "triton"
  #   "triton.language"
  # ];

  # Ultimately, torch is our test suite:
  passthru.tests = {
    inherit torchWithRocm;
    # Implemented as alternative to pythonImportsCheck, in case if circular dependency on torch occurs again,
    # and pythonImportsCheck is commented back.
    import-triton =
      runCommand "import-triton"
        { nativeBuildInputs = [ (python.withPackages (ps: [ ps.openai-triton ])) ]; }
        ''
          python << \EOF
          import triton
          import triton.language
          EOF
          touch "$out"
        '';
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/openai/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  meta = with lib; {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/openai/triton";
    platforms = platforms.linux;
    license = licenses.mit;
    maintainers = with maintainers; [
      SomeoneSerge
      Madouura
    ];
  };
}