# Clone of https://github.com/NixOS/nixpkgs.git (to stress-test knotserver),
# taken from branch r-updates; original file: 275 lines, 8.0 kB.
# Build recipe for Triton (the OpenAI/triton-lang Deep-Learning compiler),
# optionally wired up against CUDA or ROCm depending on nixpkgs `config`.
{
  lib,
  addDriverRunpath,
  buildPythonPackage,
  cmake,
  config,
  cudaPackages,
  fetchFromGitHub,
  filelock,
  gtest,
  libxml2,
  lit,
  llvm,
  ncurses,
  ninja,
  pybind11,
  python,
  pytestCheckHook,
  writableTmpDirAsHomeHook,
  stdenv,
  replaceVars,
  setuptools,
  torchWithRocm,
  zlib,
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
  rocmPackages,
  triton,
}:

buildPythonPackage rec {
  pname = "triton";
  version = "3.4.0";
  pyproject = true;

  # Remember to bump triton-llvm as well!
  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    tag = "v${version}";
    hash = "sha256-78s9ke6UV7Tnx3yCr0QZcVDqQELR4XoGgJY7olNJmjk=";
  };

  patches = [
    # Let the wheel link against the host's GPU driver libraries at runtime.
    (replaceVars ./0001-_build-allow-extra-cc-flags.patch {
      ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
    })
    (replaceVars ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch {
      libhipDir = if rocmSupport then "${lib.getLib rocmPackages.clr}/lib" else null;
      libcudaStubsDir =
        if cudaSupport then "${lib.getOutput "stubs" cudaPackages.cuda_cudart}/lib/stubs" else null;
    })
  ]
  ++ lib.optionals cudaSupport [
    (replaceVars ./0003-nvidia-cudart-a-systempath.patch {
      cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
    })
    (replaceVars ./0004-nvidia-allow-static-ptxas-path.patch {
      nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
    })
  ];

  postPatch =
    # Avoid downloading dependencies: remove any downloads
    ''
      substituteInPlace setup.py \
        --replace-fail "[get_json_package_info()]" "[]" \
        --replace-fail "[get_llvm_package_info()]" "[]" \
        --replace-fail 'yield ("triton.profiler", "third_party/proton/proton")' 'pass' \
        --replace-fail "curr_version.group(1) != version" "False"
    ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    + ''
      substituteInPlace setup.py \
        --replace-fail \
          "cmake_args.extend(thirdparty_cmake_args)" \
          "cmake_args.extend(thirdparty_cmake_args + os.environ.get('cmakeFlags', \"\").split())"
    ''
    # Don't fetch googletest
    + ''
      substituteInPlace cmake/AddTritonUnitTest.cmake \
        --replace-fail "include(\''${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)" "" \
        --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
    '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm

    # Upstream's setup.py tries to write cache somewhere in ~/
    writableTmpDirAsHomeHook
  ];

  cmakeFlags = [
    (lib.cmakeFeature "LLVM_SYSPATH" "${llvm}")
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  preConfigure =
    # Ensure that the build process uses the requested number of cores
    ''
      export MAX_JOBS="$NIX_BUILD_CORES"
    '';

  env = {
    TRITON_BUILD_PROTON = "OFF";
    TRITON_OFFLINE_BUILD = true;
  }
  // lib.optionalAttrs cudaSupport {
    CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
    CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

    # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
    TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
    TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
    TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
    TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
    TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
    TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
  '';

  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

  passthru.gpuCheck = stdenv.mkDerivation {
    pname = "triton-pytest";
    inherit (triton) version src;

    requiredSystemFeatures = [ "cuda" ];

    nativeBuildInputs = [
      (python.withPackages (ps: [
        ps.scipy
        ps.torchWithCuda
        ps.triton-cuda
      ]))
    ];

    dontBuild = true;
    nativeCheckInputs = [
      pytestCheckHook
      writableTmpDirAsHomeHook
    ];

    doCheck = true;

    preCheck = ''
      cd python/test/unit
    '';
    checkPhase = "pytestCheckPhase";

    installPhase = "touch $out";
  };

  passthru.tests = {
    # Ultimately, torch is our test suite:
    inherit torchWithRocm;

    # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
    # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
    axpy-cuda =
      cudaPackages.writeGpuTestPython
        {
          libraries = ps: [
            ps.triton
            ps.torch-no-triton
          ];
        }
        ''
          # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

          import triton
          import triton.language as tl
          import torch
          import os

          @triton.jit
          def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
              pid = tl.program_id(axis=0)
              block_start = pid * BLOCK_SIZE
              offsets = block_start + tl.arange(0, BLOCK_SIZE)
              mask = offsets < n
              x = tl.load(x_ptr + offsets, mask=mask)
              y = tl.load(y_ptr + offsets, mask=mask)
              output = a * x + y
              tl.store(out + offsets, output, mask=mask)

          def axpy(a, x, y):
              output = torch.empty_like(x)
              assert x.is_cuda and y.is_cuda and output.is_cuda
              n_elements = output.numel()

              def grid(meta):
                  return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

              axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
              return output

          if __name__ == "__main__":
              if os.environ.get("HOME", None) == "/homeless-shelter":
                  os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
              if "CC" not in os.environ:
                  os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
              torch.manual_seed(0)
              size = 12345
              x = torch.rand(size, device='cuda')
              y = torch.rand(size, device='cuda')
              output_torch = 3.14 * x + y
              output_triton = axpy(3.14, x, y)
              assert output_torch.sub(output_triton).abs().max().item() < 1e-6
              print("Triton axpy: OK")
        '';
  };

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      Madouura
      derdennisop
    ];
  };
}