# Prebuilt (wheel) distribution of Triton.
#
# The wheel is selected per host system and Python version from
# ./binary-hashes.nix, its ELF files are patched with autoPatchelfHook, and the
# linker flags Triton uses at runtime are rewritten so that libcuda can be
# found on NixOS.
{
  lib,
  stdenv,
  addDriverRunpath,
  cudaPackages,
  buildPythonPackage,
  fetchurl,
  isPy38,
  isPy39,
  isPy310,
  isPy311,
  python,
  autoPatchelfHook,
  filelock,
  lit,
  zlib,
}:

buildPythonPackage rec {
  pname = "triton";
  version = "2.1.0";
  format = "wheel";

  src =
    let
      # "3.11" -> "311"; combined with the system it forms the lookup key,
      # e.g. "x86_64-linux-311".
      pyVerNoDot = lib.replaceStrings [ "." ] [ "" ] python.pythonVersion;
      # The wheel contains native code that runs on the host, so select it by
      # the host platform (stdenv.system is a deprecated alias).
      system = stdenv.hostPlatform.system;
      key = "${system}-${pyVerNoDot}";
      unsupported = throw "triton-bin: unsupported system/Python combination '${key}'";
      srcs = (import ./binary-hashes.nix version).${key} or unsupported;
    in
    fetchurl srcs;

  # Only the CPython versions this expression knows wheels for.
  disabled = !(isPy38 || isPy39 || isPy310 || isPy311);

  pythonRemoveDeps = [
    "cmake"
    # torch and triton depend on each other; dropping torch from triton's
    # declared requirements breaks the cycle.
    "torch"
  ];

  # Needed by autoPatchelfHook to resolve libz for the bundled binaries.
  buildInputs = [ zlib ];

  nativeBuildInputs = [
    autoPatchelfHook
  ];

  propagatedBuildInputs = [
    filelock
    lit
    zlib
  ];

  dontStrip = true;

  postFixup =
    let
      # Upstream's triton/common/build.py links the JIT-compiled launcher with
      # [cc, ..., "-lcuda", ...]. Keep "-lcuda" and append library search
      # paths, producing
      # [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...].
      old = [ "-lcuda" ];
      new = [
        "-lcuda"
        # driverLink is the driver root (e.g. /run/opengl-driver); the shared
        # libraries themselves live in its lib/ subdirectory.
        "-L${addDriverRunpath.driverLink}/lib"
        # Fallback stub libcuda from cudart so linking works without a driver.
        "-L${cudaPackages.cuda_cudart}/lib/stubs/"
      ];

      # Render a flag list the way it is spelled in Python source:
      # "a", "b", "c"
      quote = x: ''"${x}"'';
      oldStr = lib.concatMapStringsSep ", " quote old;
      newStr = lib.concatMapStringsSep ", " quote new;
    in
    # Make the bundled NVIDIA ptxas executable.
    # If this breaks, consider replacing it with "${cuda_nvcc}/bin/ptxas".
    # --replace-fail aborts the build if upstream's link command changes,
    # instead of silently shipping an unpatched build.py.
    ''
      chmod +x "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas"
      substituteInPlace "$out/${python.sitePackages}/triton/common/build.py" \
        --replace-fail '${oldStr}' '${newStr}'
    '';

  meta = {
    description = "Language and compiler for custom Deep Learning operations";
    homepage = "https://github.com/triton-lang/triton/";
    changelog = "https://github.com/triton-lang/triton/releases/tag/v${version}";
    # Triton itself is MIT-licensed, but this wheel also bundles NVIDIA's
    # ptxas binary, which the CUDA EULA allows to be redistributed; hence
    # unfreeRedistributable in addition to MIT.
    # https://docs.nvidia.com/cuda/eula/index.html
    license = with lib.licenses; [
      unfreeRedistributable
      mit
    ];
    sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
    platforms = [ "x86_64-linux" ];
    maintainers = with lib.maintainers; [ junjihashimoto ];
  };
}