# nixpkgs mirror (for testing)
# Source: github.com/NixOS/nixpkgs
# Language: nix
# ROCm-enabled Triton: starts from the CUDA-free `triton-no-cuda` package,
# switches it to the ROCm toolchain (rocmClangStdenv + triton-llvm), pins the
# Triton 3.2 release branch, and removes the NVIDIA backend from the source
# tree so only the AMD backend is built and packaged.
{
  triton-no-cuda,
  rocmPackages,
  fetchFromGitHub,
}:
(triton-no-cuda.override (_old: {
  inherit rocmPackages;
  rocmSupport = true;
  stdenv = rocmPackages.llvm.rocmClangStdenv;
  llvm = rocmPackages.triton-llvm;
})).overridePythonAttrs (old: {
  doCheck = false;
  # Same stdenv as in the `.override` above, also passed to the Python builder.
  stdenv = rocmPackages.llvm.rocmClangStdenv;
  version = "3.2.0";
  src = fetchFromGitHub {
    owner = "triton-lang";
    repo = "triton";
    rev = "9641643da6c52000c807b5eeed05edaec4402a67"; # "release/3.2.x";
    hash = "sha256-V1lpARwOLn28ZHfjiWR/JJWGw3MB34c+gz6Tq1GOVfo=";
  };
  # `or` defaults keep this override evaluating even if the base package
  # stops defining these attributes.
  buildInputs = (old.buildInputs or [ ]) ++ [
    rocmPackages.clr
  ];
  dontStrip = true;
  env = (old.env or { }) // {
    # Only the HIP runtime (clr) headers are added. No NVIDIA backend include
    # dir is added: postPatch deletes third_party/nvidia, and a hard-coded
    # /build/... sandbox path would not exist for this source layout anyway.
    CXXFLAGS = "-O3 -I${rocmPackages.clr}/include";
    TRITON_OFFLINE_BUILD = 1;
  };
  # Drop all patches inherited from the base package.
  # NOTE(review): presumably they do not apply to this pinned revision — confirm.
  patches = [ ];
  postPatch = ''
    # Remove nvidia backend so we don't depend on unfree nvidia headers
    # when we only want to target ROCm
    rm -rf third_party/nvidia
    substituteInPlace CMakeLists.txt \
      --replace-fail "add_subdirectory(test)" ""
    # Delete lines that reference NVIDIA/NVGPU or test-only registrations
    sed -i '/nvidia\|NVGPU\|registerConvertTritonGPUToLLVMPass\|mlir::test::/Id' bin/RegisterTritonDialects.h
    sed -i '/TritonTestAnalysis/Id' bin/CMakeLists.txt
    # Only install the AMD backend
    substituteInPlace python/setup.py \
      --replace-fail 'backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]' \
        'backends = [*BackendInstaller.copy(["amd"]), *BackendInstaller.copy_externals()]'
    # Rewrite <cupti.h>/<cuda.h> includes to the quoted form in a single pass
    # over the tree; dots are escaped so only the literal header names match
    find . -type f -exec sed -i \
      -e 's|[<]cupti\.h[>]|"cupti.h"|g' \
      -e 's|[<]cuda\.h[>]|"cuda.h"|g' \
      {} +
    # remove any downloads
    substituteInPlace python/setup.py \
      --replace-fail "[get_json_package_info()]" "[]" \
      --replace-fail "[get_llvm_package_info()]" "[]" \
      --replace-fail "curr_version != version" "False"
    # Don't fetch googletest
    substituteInPlace cmake/AddTritonUnitTest.cmake \
      --replace-fail 'include(''${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)' "" \
      --replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
    # Resolve ld.lld under $ROCM_PATH at runtime. When ROCM_PATH is unset this
    # falls back to the original /opt/rocm prefix instead of raising KeyError.
    substituteInPlace third_party/amd/backend/compiler.py \
      --replace-fail '"/opt/rocm/llvm/bin/ld.lld"' \
        "os.environ.get('ROCM_PATH', '/opt/rocm')"' + "/llvm/bin/ld.lld"'
  '';
})