{ lib
, config
, buildPythonPackage
, fetchFromGitHub
, addOpenGLRunpath
, pytestCheckHook
, pythonRelaxDepsHook
, pkgsTargetTarget
, cmake
, ninja
, pybind11
, gtest
, zlib
, ncurses
, libxml2
, lit
, llvm
, filelock
, torchWithRocm
, python
, cudaPackages
, cudaSupport ? config.cudaSupport
}:

let
  # A time may come we'll want to be cross-friendly
  #
  # Short explanation: we need pkgsTargetTarget, because we use string
  # interpolation instead of buildInputs.
  #
  # Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
  # ptxas compiler. We're not running this ptxas on the build machine, but on
  # the user's machine, i.e. our Target platform. The second "Target" in
  # pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
  # be executed on the GPU.
  # Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
  ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
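  # For reference: interpolating the derivation here yields an absolute store
  # path, roughly "/nix/store/<hash>-cuda_nvcc-<version>/bin/ptxas" (a sketch;
  # the hash and version depend on the pinned cudaPackages).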
in
buildPythonPackage rec {
  pname = "triton";
  version = "2.0.0";
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "openai";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
  };

  patches = [
    # TODO: there have been commits upstream aimed at removing the "torch"
    # circular dependency, but the patches fail to apply on the release
    # revision. Keeping the link for future reference
    # Also cf. https://github.com/openai/triton/issues/1374

    # (fetchpatch {
    #   url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
    #   hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
    # })
  ] ++ lib.optionals (!cudaSupport) [
    ./0000-dont-download-ptxas.patch
  ];
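  # The local 0000-dont-download-ptxas.patch only applies to the non-CUDA
  # build: judging by its name, it keeps setup.py from fetching a prebuilt
  # ptxas. With cudaSupport we instead symlink the nixpkgs-provided ptxas in
  # preConfigure below.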

  nativeBuildInputs = [
    pythonRelaxDepsHook
    # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  propagatedBuildInputs = [ filelock ];

  postPatch = let
    # Bash was getting weird without linting,
    # but basically upstream contains [cc, ..., "-lcuda", ...]
    # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
    old = [ "-lcuda" ];
    new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cudaPackages.cuda_cudart}/lib/stubs/" ];

    quote = x: ''"${x}"'';
    oldStr = lib.concatMapStringsSep ", " quote old;
    newStr = lib.concatMapStringsSep ", " quote new;
  in ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    substituteInPlace python/setup.py \
      --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"
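    # (The cmake setup hook exports $cmakeFlags into the build environment,
    # e.g. -DCMAKE_INCLUDE_PATH=... entries pointing at our buildInputs, so
    # setup.py forwards those to its internal cmake run instead of paths to
    # freshly downloaded third-party tarballs. A sketch of the mechanism, not
    # an exhaustive list of flags.)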

    # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
    substituteInPlace bin/CMakeLists.txt \
      --replace "add_subdirectory(FileCheck)" ""

    # Don't fetch googletest
    substituteInPlace unittest/CMakeLists.txt \
      --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" "" \
      --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
  '' + lib.optionalString cudaSupport ''
    # Use our linker flags
    substituteInPlace python/triton/compiler.py \
      --replace '${oldStr}' '${newStr}'
  '';

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  preConfigure = ''
    # Upstream's setup.py tries to write cache somewhere in ~/
    export HOME=$(mktemp -d)

    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
    echo "
    [build_ext]
    base-dir=$PWD" >> python/setup.cfg

    # The rest (including buildPhase) is relative to ./python/
    cd python
  '' + lib.optionalString cudaSupport ''
    export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
    export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;

    # Work around download_and_copy_ptxas()
    mkdir -p $PWD/triton/third_party/cuda/bin
    ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
  '';
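  # After preConfigure (with cudaSupport) the tree should contain, roughly:
  #   triton/third_party/cuda/bin/ptxas -> /nix/store/<hash>-cuda_nvcc-<ver>/bin/ptxas
  # which is the location download_and_copy_ptxas() would otherwise populate,
  # so setup.py should skip the network fetch (a sketch; not verified against
  # every setup.py revision).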

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
  postFixup = lib.optionalString cudaSupport ''
    rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
    ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
  '';
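  # Quick sanity check of the result (hypothetical invocation):
  #   readlink $out/lib/python*/site-packages/triton/third_party/cuda/bin/ptxas
  # should print the cuda_nvcc store path rather than a stripped or dangling file.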

  checkInputs = [ cmake ]; # ctest
  dontUseSetuptoolsCheck = true;

  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd /build/source/python/build/temp* ; ctest)

    # For pytestCheckHook
    cd test/unit
  '';
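  # Note: pytestCheckHook is commented out in nativeBuildInputs above, so in
  # practice only the ctest run in preCheck executes; the `cd test/unit` is
  # kept for whenever the hook can be re-enabled.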

  # Circular dependency on torch
  # pythonImportsCheck = [
  #   "triton"
  #   "triton.language"
  # ];

  # Ultimately, torch is our test suite:
  passthru.tests = { inherit torchWithRocm; };
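  # e.g. `nix-build -A python3Packages.openai-triton.tests.torchWithRocm`
  # (hypothetical attribute path; adjust to wherever this expression is exposed)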

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/openai/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];
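  # (pythonRemoveDeps is consumed by pythonRelaxDepsHook from nativeBuildInputs,
  # which drops these entries from the wheel's Requires-Dist metadata after the
  # build, so pip doesn't demand torch/cmake/lit at install time.)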

  meta = with lib; {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/openai/triton";
    platforms = platforms.unix;
    license = licenses.mit;
    maintainers = with maintainers; [ SomeoneSerge Madouura ];
  };
}