Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver)
# Python package `triton` (openai/triton), built from source at v2.1.0.
#
# Arguments of note:
#   cudaSupport   - enables the CUDA-specific patches, linker flags and the
#                   ptxas symlink; defaults to `config.cudaSupport`.
#   cudaPackages  - provides `cuda_nvcc` (ptxas), `cuda_cudart` (stubs) and
#                   `backendStdenv` (compiler used in the CUDA branch).
#   torchWithRocm - only referenced from `passthru.tests`.
# All other arguments are build tools and libraries listed in the
# `nativeBuildInputs` / `buildInputs` / `propagatedBuildInputs` below.
{
  lib,
  config,
  addDriverRunpath,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  setuptools,
  cmake,
  ninja,
  pybind11,
  gtest,
  zlib,
  ncurses,
  libxml2,
  lit,
  llvm,
  filelock,
  torchWithRocm,
  python,

  runCommand,

  cudaPackages,
  cudaSupport ? config.cudaSupport,
}:

let
  ptxas = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
in
buildPythonPackage rec {
  pname = "triton";
  version = "2.1.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "openai";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
  };

  patches =
    [
      # fix overflow error
      (fetchpatch {
        url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
        hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
      })

      # Upstream started pinning the CUDA version and falling back to downloading from Conda
      # in https://github.com/triton-lang/triton/pull/1574/files#diff-eb8b42d9346d0a5d371facf21a8bfa2d16fb49e213ae7c21f03863accebe0fcfR120-R123
      ./0000-dont-download-ptxas.patch
    ]
    ++ lib.optionals (!cudaSupport) [
      # triton wants to get ptxas version even if ptxas is not
      # used, resulting in ptxas not found error.
      ./0001-ptxas-disable-version-key-for-non-cuda-targets.patch
    ];

  postPatch =
    let
      quote = x: ''"${x}"'';
      subs.ldFlags =
        let
          # Bash was getting weird without linting,
          # but basically upstream contains [cc, ..., "-lcuda", ...]
          # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
          old = [ "-lcuda" ];
          new = [
            "-lcuda"
            "-L${addDriverRunpath.driverLink}"
            "-L${cudaPackages.cuda_cudart}/lib/stubs/"
          ];
        in
        {
          # Both lists rendered as comma-separated, double-quoted items,
          # i.e. the form they take inside upstream's Python list literal.
          oldStr = lib.concatMapStringsSep ", " quote old;
          newStr = lib.concatMapStringsSep ", " quote new;
        };
    in
    # NOTE: `--replace-warn` is the non-deprecated spelling of the previous
    # `--replace` (a missing pattern only warns, it does not fail the build).
    ''
      # Use our `cmakeFlags` instead and avoid downloading dependencies
      substituteInPlace python/setup.py \
        --replace-warn "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"

      # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
      substituteInPlace bin/CMakeLists.txt \
        --replace-warn "add_subdirectory(FileCheck)" ""

      # Don't fetch googletest
      substituteInPlace unittest/CMakeLists.txt \
        --replace-warn "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" "" \
        --replace-warn "include(GoogleTest)" "find_package(GTest REQUIRED)"

      cat << \EOF >> python/triton/common/build.py
      def libcuda_dirs():
          return [ "${addDriverRunpath.driverLink}/lib" ]
      EOF
    ''
    + lib.optionalString cudaSupport ''
      # Use our linker flags
      substituteInPlace python/triton/common/build.py \
        --replace-warn '${subs.ldFlags.oldStr}' '${subs.ldFlags.newStr}'
    '';

  nativeBuildInputs = [
    setuptools
    # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  propagatedBuildInputs = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  # (presumably the reason the CUDA branch below pins CC/CXX to
  # cudaPackages.backendStdenv's compiler).
  preConfigure =
    ''
      # Ensure that the build process uses the requested number of cores
      export MAX_JOBS="$NIX_BUILD_CORES"

      # Upstream's setup.py tries to write cache somewhere in ~/
      export HOME=$(mktemp -d)

      # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
      echo "
      [build_ext]
      base-dir=$PWD" >> python/setup.cfg

      # The rest (including buildPhase) is relative to ./python/
      cd python
    ''
    + lib.optionalString cudaSupport ''
      export CC=${cudaPackages.backendStdenv.cc}/bin/cc
      export CXX=${cudaPackages.backendStdenv.cc}/bin/c++

      # Work around download_and_copy_ptxas()
      mkdir -p $PWD/triton/third_party/cuda/bin
      ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
    '';

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
  postFixup = lib.optionalString cudaSupport ''
    rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
    ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
  '';

  checkInputs = [ cmake ]; # ctest
  dontUseSetuptoolsCheck = true;

  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs).
    # Resolve it through $NIX_BUILD_TOP/$sourceRoot rather than a literal
    # /build/source, which only exists when the build directory is /build.
    (cd "$NIX_BUILD_TOP/$sourceRoot"/python/build/temp* ; ctest)

    # For pytestCheckHook
    cd test/unit
  '';

  # Circular dependency on torch
  # pythonImportsCheck = [
  #   "triton"
  #   "triton.language"
  # ];

  # Ultimately, torch is our test suite:
  passthru.tests = {
    inherit torchWithRocm;
    # Implemented as an alternative to pythonImportsCheck, in case the circular dependency on torch
    # occurs again and pythonImportsCheck has to be commented out again.
    import-triton =
      runCommand "import-triton"
        { nativeBuildInputs = [ (python.withPackages (ps: [ ps.triton ])) ]; }
        ''
          python << \EOF
          import triton
          import triton.language
          EOF
          touch "$out"
        '';
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/openai/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/openai/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      Madouura
    ];
  };
}