# Source: nixpkgs mirror (for testing), github.com/NixOS/nixpkgs
# Language: Nix — branch python-updates, 324 lines, 9.5 kB
{
  lib,
  stdenv,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # patches
  replaceVars,
  addDriverRunpath,
  cudaPackages,

  # build-system
  setuptools,

  # nativeBuildInputs
  cmake,
  ninja,
  lit,
  llvm,
  writableTmpDirAsHomeHook,

  # buildInputs
  gtest,
  libxml2,
  ncurses,
  pybind11,
  zlib,

  # dependencies
  filelock,

  # passthru
  python,
  pytestCheckHook,
  torchWithRocm,
  runCommand,
  triton,
  rocmPackages,

  cudaSupport ? config.cudaSupport,
}:

buildPythonPackage rec {
  pname = "triton";
  version = "3.5.1";
  pyproject = true;

  # Remember to bump triton-llvm as well!
  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    tag = "v${version}";
    hash = "sha256-dyNRtS1qtU8C/iAf0Udt/1VgtKGSvng1+r2BtvT9RB4=";
  };

  patches = [
    (replaceVars ./0001-_build-allow-extra-cc-flags.patch {
      ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
    })
    (replaceVars ./0002-nvidia-driver-short-circuit-before-ldconfig.patch {
      libcudaStubsDir =
        if cudaSupport then "${lib.getOutput "stubs" cudaPackages.cuda_cudart}/lib/stubs" else null;
    })
  ]
  ++ lib.optionals cudaSupport [
    (replaceVars ./0003-nvidia-cudart-a-systempath.patch {
      cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
    })
    (replaceVars ./0004-nvidia-allow-static-ptxas-path.patch {
      nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
    })
  ];

  postPatch =
    # Allow CMake 4
    # Upstream issue: https://github.com/triton-lang/triton/issues/8245
    ''
      substituteInPlace pyproject.toml \
        --replace-fail "cmake>=3.20,<4.0" "cmake>=3.20"
    ''
    # Avoid downloading dependencies remove any downloads
    + ''
      substituteInPlace setup.py \
        --replace-fail "[get_json_package_info()]" "[]" \
        --replace-fail "[get_llvm_package_info()]" "[]" \
        --replace-fail 'yield ("triton.profiler", "third_party/proton/proton")' 'pass' \
        --replace-fail "curr_version.group(1) != version" "False"
    ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    + ''
      substituteInPlace setup.py \
        --replace-fail \
          "cmake_args.extend(thirdparty_cmake_args)" \
          "cmake_args.extend(thirdparty_cmake_args + os.environ.get('cmakeFlags', \"\").split())"
    ''
    # Don't fetch googletest
    + ''
      substituteInPlace cmake/AddTritonUnitTest.cmake \
        --replace-fail "include(\''${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)" "" \
        --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
    '';

  build-system = [ setuptools ];

  nativeBuildInputs = [
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm

    # Upstream's setup.py tries to write cache somewhere in ~/
    writableTmpDirAsHomeHook
  ];

  cmakeFlags = [
    (lib.cmakeFeature "LLVM_SYSPATH" "${llvm}")

    # `find_package` is called with `NO_DEFAULT_PATH`
    # https://cmake.org/cmake/help/latest/command/find_package.html
    # https://github.com/triton-lang/triton/blob/c3c476f357f1e9768ea4e45aa5c17528449ab9ef/third_party/amd/CMakeLists.txt#L6
    (lib.cmakeFeature "LLD_DIR" "${lib.getLib llvm}")
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  dependencies = [
    filelock
    # triton uses setuptools at runtime:
    # https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
    setuptools
  ];

  NIX_CFLAGS_COMPILE = lib.optionals cudaSupport [
    # Pybind11 started generating strange errors since python 3.12. Observed only in the CUDA branch.
    # https://gist.github.com/SomeoneSerge/7d390b2b1313957c378e99ed57168219#file-gistfile0-txt-L1042
    "-Wno-stringop-overread"
  ];

  preConfigure =
    # Ensure that the build process uses the requested number of cores
    ''
      export MAX_JOBS="$NIX_BUILD_CORES"
    '';

  env = {
    TRITON_BUILD_PROTON = "OFF";
    TRITON_OFFLINE_BUILD = true;
  }
  // lib.optionalAttrs cudaSupport {
    CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
    CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";

    # TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
    TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
    TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
    TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
    TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
    TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
    TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
  };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  nativeCheckInputs = [ cmake ];
  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd ./build/temp* ; ctest)
  '';

  pythonImportsCheck = [
    "triton"
    "triton.language"
  ];

  passthru = {
    gpuCheck = stdenv.mkDerivation {
      pname = "triton-pytest";
      inherit (triton) version src;

      requiredSystemFeatures = [ "cuda" ];

      nativeBuildInputs = [
        (python.withPackages (ps: [
          ps.scipy
          ps.torchWithCuda
          ps.triton-cuda
        ]))
      ];

      dontBuild = true;
      nativeCheckInputs = [
        pytestCheckHook
        writableTmpDirAsHomeHook
      ];

      doCheck = true;

      preCheck = ''
        cd python/test/unit
      '';
      checkPhase = "pytestCheckPhase";

      installPhase = "touch $out";
    };

    tests = {
      # Ultimately, torch is our test suite:
      inherit torchWithRocm;

      # Test that _get_path_to_hip_runtime_dylib works when ROCm is available at runtime
      rocm-libamdhip64-path =
        runCommand "triton-rocm-libamdhip64-path-test"
          {
            buildInputs = [
              triton
              python
              rocmPackages.clr
            ];
          }
          ''
            python -c "
            import os
            import triton
            path = triton.backends.amd.driver._get_path_to_hip_runtime_dylib()
            print(f'libamdhip64 path: {path}')
            assert os.path.exists(path)
            " && touch $out
          '';

      # Test as `nix run -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda`
      # or, using `programs.nix-required-mounts`, as `nix build -f "<nixpkgs>" python3Packages.triton.tests.axpy-cuda.gpuCheck`
      axpy-cuda =
        cudaPackages.writeGpuTestPython
          {
            libraries = ps: [
              ps.triton
              ps.torch-no-triton
            ];

            gpuCheckArgs.nativeBuildInputs = [
              # PermissionError: [Errno 13] Permission denied: '/homeless-shelter'
              writableTmpDirAsHomeHook
            ];
          }
          ''
            # Adopted from Philippe Tillet https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

            import triton
            import triton.language as tl
            import torch
            import os

            @triton.jit
            def axpy_kernel(n, a: tl.constexpr, x_ptr, y_ptr, out, BLOCK_SIZE: tl.constexpr):
                pid = tl.program_id(axis=0)
                block_start = pid * BLOCK_SIZE
                offsets = block_start + tl.arange(0, BLOCK_SIZE)
                mask = offsets < n
                x = tl.load(x_ptr + offsets, mask=mask)
                y = tl.load(y_ptr + offsets, mask=mask)
                output = a * x + y
                tl.store(out + offsets, output, mask=mask)

            def axpy(a, x, y):
                output = torch.empty_like(x)
                assert x.is_cuda and y.is_cuda and output.is_cuda
                n_elements = output.numel()

                def grid(meta):
                    return (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

                axpy_kernel[grid](n_elements, a, x, y, output, BLOCK_SIZE=1024)
                return output

            if __name__ == "__main__":
                if os.environ.get("HOME", None) == "/homeless-shelter":
                    os.environ["HOME"] = os.environ.get("TMPDIR", "/tmp")
                if "CC" not in os.environ:
                    os.environ["CC"] = "${lib.getExe' cudaPackages.backendStdenv.cc "cc"}"
                torch.manual_seed(0)
                size = 12345
                x = torch.rand(size, device='cuda')
                y = torch.rand(size, device='cuda')
                output_torch = 3.14 * x + y
                output_triton = axpy(3.14, x, y)
                assert output_torch.sub(output_triton).abs().max().item() < 1e-6
                print("Triton axpy: OK")
          '';
    };
  };

  meta = {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/triton-lang/triton";
    platforms = lib.platforms.linux;
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      SomeoneSerge
      derdennisop
    ];
  };
}