# ROCm AOTriton: ahead-of-time compiled Triton kernels.
#
# When none of the requested `gpuTargets` are supported by aotriton, the
# package is built in no-image mode and named `aotriton-shim`.
{
  lib,
  stdenv,
  fetchFromGitHub,
  fetchpatch,
  cmake,
  rocm-cmake,
  clr,
  python3,
  ninja,
  xz,
  writableTmpDirAsHomeHook,
  pkg-config,
  gpuTargets ? clr.localGpuTargets or clr.gpuTargets,
  # the two arguments below are only used by passthru.tests
  aotriton,
  hello,
}:

let
  # GPU architectures aotriton is able to build kernels for, see
  # https://github.com/ROCm/aotriton/blob/main/v2python/gpu_targets.py
  aotritonKnownTargets = [
    "gfx90a"
    "gfx942"
    "gfx950"
    # some gfx1100 kernels fail with error: branch size exceeds simm16
    # but the build proceeds and those ops fall back, so it's ok
    "gfx1100"
    "gfx1151"
    "gfx1150"
    "gfx1201"
    "gfx1200"
  ];

  # The requested targets that are also on aotriton's support list
  enabledTargets = lib.intersectLists aotritonKnownTargets gpuTargets;

  # True when at least one requested target can get real kernels
  buildsKernels = enabledTargets != [ ];

  # Value passed as AOTRITON_TARGET_ARCH. When nothing we were asked for is
  # supported, pick a single arbitrary target to speed up the shim build.
  targetArch = if buildsKernels then lib.concatStringsSep ";" enabledTargets else "gfx1200";

  # Python package set the build-time Python dependencies are taken from
  py = python3.pkgs;
in
stdenv.mkDerivation (finalAttrs: {
  pname = "aotriton" + lib.optionalString (!buildsKernels) "-shim";
  version = "0.11.1b";

  src = fetchFromGitHub {
    owner = "ROCm";
    repo = "aotriton";
    tag = finalAttrs.version;
    hash = "sha256-F7JjyS+6gMdCpOFLldTsNJdVzzVwd6lwW7+V8ZOZfig=";
    # .git is kept only so postFetch can initialise submodules; it is deleted
    # again at the end of postFetch.
    leaveDotGit = true;
    # Fetch every third_party/ submodule except the unused triton submodule,
    # which is ~500MB
    postFetch = ''
      cd $out
      git reset --hard HEAD
      for submodule in $(git config --file .gitmodules --get-regexp path | awk '{print $2}' | grep '^third_party/' | grep -v '^third_party/triton$'); do
        git submodule update --init --recursive "$submodule"
      done
      find "$out" -name .git -print0 | xargs -0 rm -rf
    '';
  };

  __structuredAttrs = true;
  strictDeps = true;

  cmakeBuildType = "RelWithDebInfo";
  separateDebugInfo = true;

  # big-parallel is only requested when kernels are actually built;
  # a no-image mode build is faster
  requiredSystemFeatures = lib.optionals buildsKernels [ "big-parallel" ];

  env = {
    AOTRITON_CI_SUPPLIED_SHA1 = finalAttrs.version;
    ROCM_PATH = "${clr}";
    CFLAGS = "-w -g1 -gz -Wno-c++11-narrowing";
    # C++ flags follow whatever CFLAGS ends up being (including overrides)
    CXXFLAGS = finalAttrs.env.CFLAGS;
    # reduce triton disk space usage
    TRITON_STORE_BINARY_ONLY = 1;
  };

  nativeBuildInputs = [
    cmake
    rocm-cmake
    pkg-config
    python3
    ninja
    clr
    # venv wants to cache in ~
    writableTmpDirAsHomeHook
  ];

  buildInputs = [
    clr
    xz
    py.wheel
    py.packaging
    py.pyyaml
    py.numpy
    py.filelock
    py.iniconfig
    py.pluggy
    py.pybind11
    py.pandas
    py.triton
  ];

  # VENV_DIR and VENV_BIN_PYTHON must be absolute paths, otherwise the
  # build fails with "AOTRITON_INHERIT_SYSTEM_SITE_TRITON is enabled
  # but triton is not available … no such file or directory".
  # They are set from a preConfigure hook so that a valid absolute path is
  # also picked when nix-shell is used against this package.
  preConfigure = ''
    cmakeFlagsArray+=(
      "-DVENV_DIR=$(pwd)/build/venv/"
      "-DVENV_BIN_PYTHON=$(pwd)/build/venv/bin/python"
    )
  '';

  cmakeFlags = [
    # Skip building kernels when no supported target is enabled
    (lib.cmakeBool "AOTRITON_NOIMAGE_MODE" (!buildsKernels))
    # Use the preinstalled triton from our python's site-packages
    (lib.cmakeBool "AOTRITON_INHERIT_SYSTEM_SITE_TRITON" true)
    # Disabled: circular dependency
    (lib.cmakeBool "AOTRITON_USE_TORCH" false)
    # FP32 kernels are optional; turning them off speeds up builds and saves
    # space. Perf sensitive code should be using BF16 or F16.
    (lib.cmakeBool "AOTRITON_ENABLE_FP32_INPUTS" false)
    # Avoid kernels being skipped if the build host is overloaded
    (lib.cmakeFeature "AOTRITON_GPU_BUILD_TIMEOUT" "0")
    # Manually define CMAKE_INSTALL_<DIR>
    # See: https://github.com/NixOS/nixpkgs/pull/197838
    (lib.cmakeFeature "CMAKE_INSTALL_BINDIR" "bin")
    (lib.cmakeFeature "CMAKE_INSTALL_LIBDIR" "lib")
    (lib.cmakeFeature "CMAKE_INSTALL_INCLUDEDIR" "include")
    # Note: the build will warn
    # "AMDGPU_TARGETS was not set, and system GPU detection was unsuccsesful."
    # This can safely be ignored, aotriton uses a different approach to
    # pass targets.
    (lib.cmakeFeature "AOTRITON_TARGET_ARCH" targetArch)
  ];

  # Excerpt from README:
  # Note: do not run ninja separately, due to the limit of the current build system,
  # ninja install will run the whole build process unconditionally.
  dontBuild = true;
  # Hence the install phase both builds and installs
  installPhase = ''
    runHook preInstall
    ninja -v install
    runHook postInstall
  '';

  # Tests are intended to be run manually as test/ python scripts and
  # need an accelerator
  doCheck = false;
  doInstallCheck = false;

  # Regression test that the aotriton .so doesn't crash in a static constructor.
  # Currently known to fail on the rocm toolchain but fine with default stdenv.
  passthru.tests.ld-preload-into-hello =
    let
      # aotriton built with no GPU targets requested
      aotritonNoTargets = aotriton.override {
        gpuTargets = [ ];
      };
    in
    stdenv.mkDerivation {
      name = "aotriton-basic-load-test";
      nativeBuildInputs = [ hello ];
      buildCommand = ''
        set -e
        LD_PRELOAD=${aotritonNoTargets}/lib/libaotriton_v2.so ${hello}/bin/hello > /dev/null
        echo "ld-preload-into-hello" > $out
      '';
    };

  meta = {
    description = "ROCm Ahead of Time (AOT) Triton Math Library";
    homepage = "https://github.com/ROCm/aotriton";
    license = lib.licenses.mit;
    teams = [ lib.teams.rocm ];
    platforms = lib.platforms.linux;
    # ld: error: unable to insert .comment after .comment
    broken = stdenv.cc.isClang;
  };
})