{ lib
, config
, buildPythonPackage
, fetchFromGitHub
, addOpenGLRunpath
, pytestCheckHook
, pythonRelaxDepsHook
, pkgsTargetTarget
, cmake
, ninja
, pybind11
, gtest
, zlib
, ncurses
, libxml2
, lit
, llvm
, filelock
, torchWithRocm
, python
, cudaPackages
, cudaSupport ? config.cudaSupport
}:

let
  # A time may come we'll want to be cross-friendly
  #
  # Short explanation: we need pkgsTargetTarget, because we use string
  # interpolation instead of buildInputs.
  #
  # Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
  # ptxas compiler. We're not running this ptxas on the build machine, but on
  # the user's machine, i.e. our Target platform. The second "Target" in
  # pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
  # be executed on the GPU.
  # Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
  ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
in
buildPythonPackage rec {
  pname = "triton";
  version = "2.0.0";
  format = "setuptools";

  src = fetchFromGitHub {
    owner = "openai";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
  };

  patches = [
    # TODO: there have been commits upstream aimed at removing the "torch"
    # circular dependency, but the patches fail to apply on the release
    # revision. Keeping the link for future reference
    # Also cf. https://github.com/openai/triton/issues/1374

    # (fetchpatch {
    #   url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
    #   hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
    # })
  ] ++ lib.optionals (!cudaSupport) [
    ./0000-dont-download-ptxas.patch
  ];

  nativeBuildInputs = [
    pythonRelaxDepsHook
    # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
    cmake
    ninja

    # Note for future:
    # These *probably* should go in depsTargetTarget
    # ...but we cannot test cross right now anyway
    # because we only support cudaPackages on x86_64-linux atm
    lit
    llvm
  ];

  buildInputs = [
    gtest
    libxml2.dev
    ncurses
    pybind11
    zlib
  ];

  propagatedBuildInputs = [ filelock ];

  postPatch = let
    # Bash was getting weird without linting,
    # but basically upstream contains [cc, ..., "-lcuda", ...]
    # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
    old = [ "-lcuda" ];
    new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cudaPackages.cuda_cudart}/lib/stubs/" ];

    quote = x: ''"${x}"'';
    oldStr = lib.concatMapStringsSep ", " quote old;
    newStr = lib.concatMapStringsSep ", " quote new;
  in ''
    # Use our `cmakeFlags` instead and avoid downloading dependencies
    substituteInPlace python/setup.py \
      --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"

    # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
    substituteInPlace bin/CMakeLists.txt \
      --replace "add_subdirectory(FileCheck)" ""

    # Don't fetch googletest
    substituteInPlace unittest/CMakeLists.txt \
      --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
      --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
  '' + lib.optionalString cudaSupport ''
    # Use our linker flags
    substituteInPlace python/triton/compiler.py \
      --replace '${oldStr}' '${newStr}'
  '';

  # Avoid GLIBCXX mismatch with other cuda-enabled python packages
  preConfigure = ''
    # Upstream's setup.py tries to write cache somewhere in ~/
    export HOME=$(mktemp -d)

    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
    echo "
    [build_ext]
    base-dir=$PWD" >> python/setup.cfg

    # The rest (including buildPhase) is relative to ./python/
    cd python
  '' + lib.optionalString cudaSupport ''
    export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
    export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;

    # Work around download_and_copy_ptxas()
    mkdir -p $PWD/triton/third_party/cuda/bin
    ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
  '';

  # CMake is run by setup.py instead
  dontUseCmakeConfigure = true;

  # Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
  postFixup = lib.optionalString cudaSupport ''
    rm -f $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
    ln -s ${ptxas} $out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas
  '';

  checkInputs = [ cmake ]; # ctest
  dontUseSetuptoolsCheck = true;

  preCheck = ''
    # build/temp* refers to build_ext.build_temp (looked up in the build logs)
    (cd /build/source/python/build/temp* ; ctest)

    # For pytestCheckHook
    cd test/unit
  '';

  # Circular dependency on torch
  # pythonImportsCheck = [
  #   "triton"
  #   "triton.language"
  # ];

  # Ultimately, torch is our test suite:
  passthru.tests = { inherit torchWithRocm; };

  pythonRemoveDeps = [
    # Circular dependency, cf. https://github.com/openai/triton/issues/1374
    "torch"

    # CLI tools without dist-info
    "cmake"
    "lit"
  ];

  meta = with lib; {
    description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
    homepage = "https://github.com/openai/triton";
    platforms = lib.platforms.unix;
    license = licenses.mit;
    maintainers = with maintainers; [ SomeoneSerge Madouura ];
  };
}
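# A minimal consumption sketch, kept as comments so the expression above stays valid.
# The attribute name `openai-triton`, the `<nixpkgs>` channel, and the config flags
# below are assumptions for illustration, not taken from this file; the exact
# attribute path depends on how the surrounding nixpkgs checkout wires this package up.
#
#   let
#     pkgs = import <nixpkgs> {
#       config = {
#         allowUnfree = true;  # the vendored CUDA bits are unfree
#         cudaSupport = true;  # feeds the `cudaSupport ? config.cudaSupport` argument above
#       };
#     };
#   in
#     pkgs.python3.withPackages (ps: [ ps.openai-triton ])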