# Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver),
# taken at the gcc-offload branch; original file: 279 lines, 8.3 kB.
{
  lib,
  addDriverRunpath,
  buildPythonPackage,
  cmake,
  config,
  cudaPackages,
  fetchFromGitHub,
  filelock,
  gtest,
  libxml2,
  lit,
  llvm,
  ncurses,
  ninja,
  pybind11,
  python,
  pytestCheckHook,
  stdenv,
  substituteAll,
  setuptools,
  torchWithRocm,
  zlib,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  triton,
}:

buildPythonPackage {
  pname = "triton";
  version = "3.1.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    # latest branch commit from https://github.com/triton-lang/triton/commits/release/3.1.x/
    rev = "cf34004b8a67d290a962da166f5aa2fc66751326";
    hash = "sha256-233fpuR7XXOaSKN+slhJbE/CMFzAqCRCE4V4rIoJZrk=";
  };

  patches =
    [
      ./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
      (substituteAll {
        src = ./0001-_build-allow-extra-cc-flags.patch;
        ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
      })
      (substituteAll (
        {
          src = ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch;
        }
        // lib.optionalAttrs rocmSupport { libhipDir = "${lib.getLib rocmPackages.clr}/lib"; }
        // lib.optionalAttrs cudaSupport {
          libcudaStubsDir = "${lib.getLib cudaPackages.cuda_cudart}/lib/stubs";
          ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
        }
      ))
    ]
    ++ lib.optionals cudaSupport [
      (substituteAll {
        src = ./0003-nvidia-cudart-a-systempath.patch;
        cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
      })
      (substituteAll {
        src = ./0004-nvidia-allow-static-ptxas-path.patch;
        nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
      })
    ];

  postPatch = ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    # remove any downloads
    substituteInPlace python/setup.py \
      --replace-fail "get_json_package_info(), get_pybind11_package_info()" "" \
      --replace-fail "get_pybind11_package_info(), get_llvm_package_info()" "" \
      --replace-fail 'packages += ["triton/profiler"]' "" \
      --replace-fail "curr_version != version" "False"

    # Don't fetch googletest
    substituteInPlace unittest/CMakeLists.txt \
      --replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" "" \
      --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
  '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  preConfigure = ''
    # Ensure that the build process uses the requested number of cores
    export MAX_JOBS="$NIX_BUILD_CORES"

    # Upstream's setup.py tries to write cache somewhere in ~/
    export HOME=$(mktemp -d)

    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
    echo "
    [build_ext]
    base-dir=$PWD" >> python/setup.cfg

    # The rest (including buildPhase) is relative to ./python/
    cd python
  '';

  env =
    {
      TRITON_BUILD_PROTON = "OFF";
      TRITON_OFFLINE_BUILD = true;
    }
    // lib.optionalAttrs cudaSupport {
      CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
      CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

      # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
      TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
      TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
      TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
      TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
      TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
      TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
    };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
  '';

  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

  passthru.gpuCheck = stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [
      (python.withPackages (ps: [
        ps.scipy
        ps.torchWithCuda
        ps.triton-cuda
      ]))
    ];

    dontBuild = true;
    nativeCheckInputs = [ pytestCheckHook ];

    doCheck = true;

    preCheck = ''
      cd python/test/unit
      export HOME=$TMPDIR
    '';
    checkPhase = "pytestCheckPhase";

    installPhase = "touch $out";
  };

  passthru.tests = {
    # Ultimately, torch is our test suite:
    inherit torchWithRocm;

    # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
    # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
    axpy-cuda =
      cudaPackages.writeGpuTestPython
        {
          libraries = ps: [
            ps.triton
            ps.torch-no-triton
          ];
        }
        ''
          # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

          import triton
          import triton.language as tl
          import torch
          import os

          @triton.jit
          def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
              pid = tl.program_id(axis=0)
              block_start = pid * BLOCK_SIZE
              offsets = block_start + tl.arange(0, BLOCK_SIZE)
              mask = offsets < n
              x = tl.load(x_ptr + offsets, mask=mask)
              y = tl.load(y_ptr + offsets, mask=mask)
              output = a * x + y
              tl.store(out + offsets, output, mask=mask)

          def axpy(a, x, y):
              output = torch.empty_like(x)
              assert x.is_cuda and y.is_cuda and output.is_cuda
              n_elements = output.numel()

              def grid(meta):
                  return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

              axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
              return output

          if __name__ == "__main__":
              if os.environ.get("HOME", None) == "/homeless-shelter":
                  os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
              if "CC" not in os.environ:
                  os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
              torch.manual_seed(0)
              size = 12345
              x = torch.rand(size, device='cuda')
              y = torch.rand(size, device='cuda')
              output_torch = 3.14 * x + y
              output_triton = axpy(3.14, x, y)
              assert output_torch.sub(output_triton).abs().max().item() < 1e-6
              print("Triton axpy: OK")
        '';
  };

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      Madouura
      derdennisop
    ];
  };
}