# NOTE(review): the line below replaced GitHub "view raw" UI residue that was
# pasted into the file; it was not part of the Nix expression.
# Build expression for the Triton compiler's Python package (release/3.1.x),
# with optional CUDA/ROCm backends selected via `config.cudaSupport` /
# `config.rocmSupport`.
{
  lib,
  addDriverRunpath,
  buildPythonPackage,
  cmake,
  config,
  cudaPackages,
  fetchFromGitHub,
  filelock,
  gtest,
  libxml2,
  lit,
  llvm,
  ncurses,
  ninja,
  pybind11,
  python,
  pytestCheckHook,
  stdenv,
  replaceVars,
  setuptools,
  torchWithRocm,
  zlib,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  triton,
}:

buildPythonPackage {
  pname = "triton";
  version = "3.1.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    # latest branch commit from https://github.com/triton-lang/triton/commits/release/3.1.x/
    rev = "cf34004b8a67d290a962da166f5aa2fc66751326";
    hash = "sha256-233fpuR7XXOaSKN+slhJbE/CMFzAqCRCE4V4rIoJZrk=";
  };

  # Driver/toolkit paths are baked into the patches via replaceVars; the
  # CUDA-only patches are appended conditionally.
  patches =
    [
      ./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
      (replaceVars ./0001-_build-allow-extra-cc-flags.patch {
        ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
      })
      (replaceVars ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch {
        libhipDir = if rocmSupport then "${lib.getLib rocmPackages.clr}/lib" else null;
        libcudaStubsDir =
          if cudaSupport then "${lib.getOutput "stubs" cudaPackages.cuda_cudart}/lib/stubs" else null;
      })
    ]
    ++ lib.optionals cudaSupport [
      (replaceVars ./0003-nvidia-cudart-a-systempath.patch {
        cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
      })
      (replaceVars ./0004-nvidia-allow-static-ptxas-path.patch {
        nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
      })
    ];

  postPatch = ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    # remove any downloads
    substituteInPlace python/setup.py \
      --replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
      --replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
      --replace-fail 'packages += ["triton/profiler"]' ""\
      --replace-fail "curr_version != version" "False"

    # Don't fetch googletest
    substituteInPlace unittest/CMakeLists.txt \
      --replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
      --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"

    # Patch the source code to make sure it doesn't specify a non-existent PTXAS version.
    # CUDA 12.6 (the current default/max) tops out at PTXAS version 8.5.
    # NOTE: This is fixed in `master`:
    # https://github.com/triton-lang/triton/commit/f48dbc1b106c93144c198fbf3c4f30b2aab9d242
    substituteInPlace "$NIX_BUILD_TOP/$sourceRoot/third_party/nvidia/backend/compiler.py" \
      --replace-fail \
        'return 80 + minor' \
        'return 80 + min(minor, 5)'
  '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  preConfigure = ''
    # Ensure that the build process uses the requested number of cores
    export MAX_JOBS="$NIX_BUILD_CORES"

    # Upstream's setup.py tries to write cache somewhere in ~/
    export HOME=$(mktemp -d)

    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
    echo "
    [build_ext]
    base-dir=$PWD" >> python/setup.cfg

    # The rest (including buildPhase) is relative to ./python/
    cd python
  '';

  # CUDA-specific tool paths are exported only when cudaSupport is enabled.
  env =
    {
      TRITON_BUILD_PROTON = "OFF";
      TRITON_OFFLINE_BUILD = true;
    }
    // lib.optionalAttrs cudaSupport {
      CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
      CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

      # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
      TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
      TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
      TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
      TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
      TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
      TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
    };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
  '';

  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

  # GPU-requiring pytest run, kept out of the default checkPhase; needs the
  # "cuda" system feature (see `programs.nix-required-mounts`).
  passthru.gpuCheck = stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [
      (python.withPackages (ps: [
        ps.scipy
        ps.torchWithCuda
        ps.triton-cuda
      ]))
    ];

    dontBuild = true;
    nativeCheckInputs = [ pytestCheckHook ];

    doCheck = true;

    preCheck = ''
      cd python/test/unit
      export HOME=$TMPDIR
    '';
    checkPhase = "pytestCheckPhase";

    installPhase = "touch $out";
  };

  passthru.tests = {
    # Ultimately, torch is our test suite:
    inherit torchWithRocm;

    # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
    # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
    axpy-cuda =
      cudaPackages.writeGpuTestPython
        {
          libraries = ps: [
            ps.triton
            ps.torch-no-triton
          ];
        }
        ''
          # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

          import triton
          import triton.language as tl
          import torch
          import os

          @triton.jit
          def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
              pid = tl.program_id(axis=0)
              block_start = pid * BLOCK_SIZE
              offsets = block_start + tl.arange(0, BLOCK_SIZE)
              mask = offsets < n
              x = tl.load(x_ptr + offsets, mask=mask)
              y = tl.load(y_ptr + offsets, mask=mask)
              output = a * x + y
              tl.store(out + offsets, output, mask=mask)

          def axpy(a, x, y):
              output = torch.empty_like(x)
              assert x.is_cuda and y.is_cuda and output.is_cuda
              n_elements = output.numel()

              def grid(meta):
                  return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

              axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
              return output

          if __name__ == "__main__":
              if os.environ.get("HOME", None) == "/homeless-shelter":
                  os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
              if "CC" not in os.environ:
                  os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
              torch.manual_seed(0)
              size = 12345
              x = torch.rand(size, device='cuda')
              y = torch.rand(size, device='cuda')
              output_torch = 3.14 * x + y
              output_triton = axpy(3.14, x, y)
              assert output_torch.sub(output_triton).abs().max().item() < 1e-6
              print("Triton axpy: OK")
        '';
  };

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      Madouura
      derdennisop
    ];
  };
}