{
  lib,
  stdenv,
  addDriverRunpath,
  cudaPackages,
  buildPythonPackage,
  fetchurl,
  python,
  pythonOlder,
  autoPatchelfHook,
  filelock,
  lit,
  zlib,
}:

# Packages the prebuilt Triton wheel: fetches the wheel matching the current
# system and Python version, patches its ELF files with autoPatchelfHook, and
# fixes up the bundled ptxas binary and the linker flags in triton's build.py.
buildPythonPackage rec {
  pname = "triton";
  version = "3.1.0";
  format = "wheel";

  # The fetchurl arguments come from ./binary-hashes.nix, which is called with
  # `version` and indexed by "<system>-<python version with the dots removed>".
  src =
    let
      pyVerNoDot = lib.replaceStrings [ "." ] [ "" ] python.pythonVersion;
      wheelKey = "${stdenv.system}-${pyVerNoDot}";
      # Name the missing key so an evaluation failure is actionable.
      unsupported = throw "Unsupported system: no triton ${version} wheel for ${wheelKey}";
      srcs = (import ./binary-hashes.nix version).${wheelKey} or unsupported;
    in
    fetchurl srcs;

  disabled = pythonOlder "3.8";

  pythonRemoveDeps = [
    "cmake"
    # torch and triton depend on each other; dropping torch here breaks the cycle.
    "torch"
  ];

  # Made available so autoPatchelfHook can resolve the wheel's zlib dependency.
  buildInputs = [ zlib ];

  nativeBuildInputs = [
    autoPatchelfHook
  ];

  # NOTE(review): zlib is already in buildInputs; propagating it as well looks
  # redundant, but it is kept in case a dependent relies on it — confirm.
  propagatedBuildInputs = [
    filelock
    lit
    zlib
  ];

  dontStrip = true;

  # If this breaks, consider replacing the bundled binary with "${cuda_nvcc}/bin/ptxas".
  # NOTE(review): both paths below are hard-coded to the wheel's internal
  # layout; confirm they still exist in the ${version} wheel.
  postFixup =
    let
      sitePackages = "$out/${python.sitePackages}";

      # Upstream's build.py contains a compiler argument list of the form
      # [cc, ..., "-lcuda", ...]. We rewrite it to
      # [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
      # so that libcuda can be found at link time.
      old = [ "-lcuda" ];
      new = [
        "-lcuda"
        # driverLink is the driver link root; the libraries live in its lib/ subdirectory.
        "-L${addDriverRunpath.driverLink}/lib"
        "-L${cudaPackages.cuda_cudart}/lib/stubs/"
      ];

      # Render a list of flags as the body of a Python list literal: "a", "b", ...
      quote = x: ''"${x}"'';
      oldStr = lib.concatMapStringsSep ", " quote old;
      newStr = lib.concatMapStringsSep ", " quote new;
    in
    ''
      chmod +x "${sitePackages}/triton/third_party/cuda/bin/ptxas"

      # --replace-warn keeps the non-fatal behaviour of the deprecated --replace:
      # a missing pattern only prints a warning.
      substituteInPlace "${sitePackages}/triton/common/build.py" \
        --replace-warn '${oldStr}' '${newStr}'
    '';

  meta = {
    description = "Language and compiler for custom Deep Learning operations";
    homepage = "https://github.com/triton-lang/triton/";
    changelog = "https://github.com/triton-lang/triton/releases/tag/v${version}";
    # Triton itself is MIT-licensed. The wheel also bundles NVIDIA's ptxas
    # binary, whose redistribution is permitted under the CUDA EULA
    # (https://docs.nvidia.com/cuda/eula/index.html), hence
    # unfreeRedistributable is listed alongside mit.
    license = [
      lib.licenses.unfreeRedistributable
      lib.licenses.mit
    ];
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    platforms = [ "x86_64-linux" ];
    maintainers = [ lib.maintainers.junjihashimoto ];
  };
}