diff --git a/tensilelite/Tensile/SolutionStructs/Naming.py b/tensilelite/Tensile/SolutionStructs/Naming.py index 4f220960db1d..99535e246650 100644 --- a/tensilelite/Tensile/SolutionStructs/Naming.py +++ b/tensilelite/Tensile/SolutionStructs/Naming.py @@ -105,7 +105,6 @@ def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInterna if splitGSU: state["GlobalSplitU"] = "M" if (state["GlobalSplitU"] > 1 or state["GlobalSplitU"] == -1) else state["GlobalSplitU"] - requiredParametersTemp = set(requiredParameters.union(["GlobalSplitU"])) if ignoreInternalArgs: diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py index bab8c687509..e03f456fbec 100644 --- a/tensilelite/Tensile/CustomYamlLoader.py +++ b/tensilelite/Tensile/CustomYamlLoader.py @@ -1,3 +1,6 @@ +# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import yaml from pathlib import Path Author: Luna Nova Date: Sun Oct 12 11:52:10 2025 -0700 [hipblaslt] intern strings to reduce duplicate memory for solution keys diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py index 685e69220c..9fdf38d8e5 100644 --- a/tensilelite/Tensile/CustomYamlLoader.py +++ b/tensilelite/Tensile/CustomYamlLoader.py @@ -1,6 +1,7 @@ # Copyright © Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +import sys import yaml from pathlib import Path @@ -85,7 +86,7 @@ def parse_scalar(loader: yaml.Loader): if not evt.style: return None - return value + return sys.intern(value) def load_yaml_stream(yaml_path: Path, loader_type: yaml.Loader): with open(yaml_path, 'r') as f: diff --git a/tensilelite/Tensile/Common/Parallel.py b/tensilelite/Tensile/Common/Parallel.py index 1a2bf9e119..f46100c7b8 100644 --- a/tensilelite/Tensile/Common/Parallel.py +++ b/tensilelite/Tensile/Common/Parallel.py @@ -22,43 +22,58 @@ # ################################################################################ -import concurrent.futures -import itertools +import multiprocessing import os +import re import sys import time - -from joblib import Parallel, delayed +from functools import partial +from typing import Any, Callable from .Utilities import tqdm -def joblibParallelSupportsGenerator(): - import joblib - from packaging.version import Version +def get_inherited_job_limit() -> int: + # 1. Check CMAKE_BUILD_PARALLEL_LEVEL (CMake 3.12+) + if 'CMAKE_BUILD_PARALLEL_LEVEL' in os.environ: + try: + return int(os.environ['CMAKE_BUILD_PARALLEL_LEVEL']) + except ValueError: + pass - joblibVer = joblib.__version__ - return Version(joblibVer) >= Version("1.4.0") + # 2. Parse MAKEFLAGS for -jN + makeflags = os.environ.get('MAKEFLAGS', '') + match = re.search(r'-j\s*(\d+)', makeflags) + if match: + return int(match.group(1)) + return -1 -def CPUThreadCount(enable=True): - from .GlobalParameters import globalParameters +def CPUThreadCount(enable=True): if not enable: return 1 - else: + from .GlobalParameters import globalParameters + + # Priority order: + # 1. Inherited from build system (CMAKE_BUILD_PARALLEL_LEVEL or MAKEFLAGS) + # 2. Explicit --jobs flag + # 3. Auto-detect + inherited_limit = get_inherited_job_limit() + cpuThreads = inherited_limit if inherited_limit > 0 else globalParameters["CpuThreads"] + + if cpuThreads < 1: if os.name == "nt": - # Windows supports at most 61 workers because the scheduler uses - # WaitForMultipleObjects directly, which has the limit (the limit - # is actually 64, but some handles are needed for accounting). - cpu_count = min(os.cpu_count(), 61) + cpuThreads = os.cpu_count() else: - cpu_count = len(os.sched_getaffinity(0)) - cpuThreads = globalParameters["CpuThreads"] - if cpuThreads == -1: - return cpu_count + cpuThreads = len(os.sched_getaffinity(0)) - return min(cpu_count, cpuThreads) + if os.name == "nt": + # Windows supports at most 61 workers because the scheduler uses + # WaitForMultipleObjects directly, which has the limit (the limit + # is actually 64, but some handles are needed for accounting). + cpuThreads = min(cpuThreads, 61) + return max(1, cpuThreads) def pcallWithGlobalParamsMultiArg(f, args, newGlobalParameters): @@ -71,19 +86,22 @@ def pcallWithGlobalParamsSingleArg(f, arg, newGlobalParameters): return f(arg) -def apply_print_exception(item, *args): - # print(item, args) +def OverwriteGlobalParameters(newGlobalParameters): + from . import GlobalParameters + + GlobalParameters.globalParameters.clear() + GlobalParameters.globalParameters.update(newGlobalParameters) + + +def worker_function(args, function, multiArg): + """Worker function that executes in the pool process.""" try: - if len(args) > 0: - func = item - args = args[0] - return func(*args) + if multiArg: + return function(*args) else: - func, item = item - return func(item) + return function(args) except Exception: import traceback - traceback.print_exc() raise finally: @@ -98,154 +116,121 @@ def OverwriteGlobalParameters(newGlobalParameters): GlobalParameters.globalParameters.update(newGlobalParameters) -def ProcessingPool(enable=True, maxTasksPerChild=None): - import multiprocessing - import multiprocessing.dummy - - threadCount = CPUThreadCount() - - if (not enable) or threadCount <= 1: - return multiprocessing.dummy.Pool(1) - - if multiprocessing.get_start_method() == "spawn": - from . import GlobalParameters - - return multiprocessing.Pool( - threadCount, - initializer=OverwriteGlobalParameters, - maxtasksperchild=maxTasksPerChild, - initargs=(GlobalParameters.globalParameters,), - ) - else: - return multiprocessing.Pool(threadCount, maxtasksperchild=maxTasksPerChild) +def progress_logger(iterable, total, message, min_log_interval=5.0): + """ + Generator that wraps an iterable and logs progress with time-based throttling. + Only logs progress if at least min_log_interval seconds have passed since last log. + Only prints completion message if task took >= min_log_interval seconds. -def ParallelMap(function, objects, message="", enable=True, method=None, maxTasksPerChild=None): + Yields (index, item) tuples. """ - Generally equivalent to list(map(function, objects)), possibly executing in parallel. - - message: A message describing the operation to be performed. - enable: May be set to false to disable parallelism. - method: A function which can fetch the mapping function from a processing pool object. - Leave blank to use .map(), other possiblities: - - `lambda x: x.starmap` - useful if `function` takes multiple parameters. - - `lambda x: x.imap` - lazy evaluation - - `lambda x: x.imap_unordered` - lazy evaluation, does not preserve order of return value. - """ - from .GlobalParameters import globalParameters + start_time = time.time() + last_log_time = start_time + log_interval = 1 + (total // 100) - threadCount = CPUThreadCount(enable) - pool = ProcessingPool(enable, maxTasksPerChild) - - if threadCount <= 1 and globalParameters["ShowProgressBar"]: - # Provide a progress bar for single-threaded operation. - # This works for method=None, and for starmap. - mapFunc = map - if method is not None: - # itertools provides starmap which can fill in for pool.starmap. It provides imap on Python 2.7. - # If this works, we will use it, otherwise we will fallback to the "dummy" pool for single threaded - # operation. - try: - mapFunc = method(itertools) - except NameError: - mapFunc = None - - if mapFunc is not None: - return list(mapFunc(function, tqdm(objects, message))) - - mapFunc = pool.map - if method: - mapFunc = method(pool) - - objects = zip(itertools.repeat(function), objects) - function = apply_print_exception - - countMessage = "" - try: - countMessage = " for {} tasks".format(len(objects)) - except TypeError: - pass + for idx, item in enumerate(iterable): + if idx % log_interval == 0: + current_time = time.time() + if (current_time - last_log_time) >= min_log_interval: + print(f"{message}\t{idx+1: 5d}/{total: 5d}") + last_log_time = current_time + yield idx, item - if message != "": - message += ": " + elapsed = time.time() - start_time + final_idx = idx + 1 if 'idx' in locals() else 0 - print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) - sys.stdout.flush() - currentTime = time.time() - rv = mapFunc(function, objects) - totalTime = time.time() - currentTime - print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime)) - sys.stdout.flush() - pool.close() - return rv + if elapsed >= min_log_interval or last_log_time > start_time: + print(f"{message} done in {elapsed:.1f}s!\t{final_idx: 5d}/{total: 5d}") -def ParallelMapReturnAsGenerator(function, objects, message="", enable=True, multiArg=True): - from .GlobalParameters import globalParameters +def imap_with_progress(pool, func, iterable, total, message, chunksize): + results = [] + for _, result in progress_logger(pool.imap(func, iterable, chunksize=chunksize), total, message): + results.append(result) + return results - threadCount = CPUThreadCount(enable) - print("{0}Launching {1} threads...".format(message, threadCount)) - if threadCount <= 1 and globalParameters["ShowProgressBar"]: - # Provide a progress bar for single-threaded operation. - callFunc = lambda args: function(*args) if multiArg else lambda args: function(args) - return [callFunc(args) for args in tqdm(objects, message)] +def _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild): + # separate fn because yield makes the entire fn a generator even if unreachable + ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') - with concurrent.futures.ProcessPoolExecutor(max_workers=threadCount) as executor: - resultFutures = (executor.submit(function, *arg if multiArg else arg) for arg in objects) - for result in concurrent.futures.as_completed(resultFutures): - yield result.result() + with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, + initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: + for _, result in progress_logger(pool.imap_unordered(worker, objects, chunksize=chunksize), objLen, message): + yield result def ParallelMap2( - function, objects, message="", enable=True, multiArg=True, return_as="list", procs=None + function: Callable, + objects: Any, + message: str = "", + enable: bool = True, + multiArg: bool = True, + minChunkSize: int = 1, + maxWorkers: int = -1, + maxtasksperchild: int = 1024, + return_as: str = "list" ): + """Executes a function over a list of objects in parallel or sequentially. + + This function is generally equivalent to ``list(map(function, objects))``. However, it provides + additional functionality to run in parallel, depending on the 'enable' flag and available CPU + threads. + + Args: + function: The function to apply to each item in 'objects'. If 'multiArg' is True, 'function' + should accept multiple arguments. + objects: An iterable of objects to be processed by 'function'. If 'multiArg' is True, each + item in 'objects' should be an iterable of arguments for 'function'. + message: Optional; a message describing the operation. Default is an empty string. + enable: Optional; if False, disables parallel execution and runs sequentially. Default is True. + multiArg: Optional; if True, treats each item in 'objects' as multiple arguments for + 'function'. Default is True. + return_as: Optional; "list" (default) or "generator_unordered" for streaming results + + Returns: + A list or generator containing the results of applying **function** to each item in **objects**. """ - Generally equivalent to list(map(function, objects)), possibly executing in parallel. + from .GlobalParameters import globalParameters - message: A message describing the operation to be performed. - enable: May be set to false to disable parallelism. - multiArg: True if objects represent multiple arguments - (differentiates multi args vs single collection arg) - """ - if return_as in ("generator", "generator_unordered") and not joblibParallelSupportsGenerator(): - return ParallelMapReturnAsGenerator(function, objects, message, enable, multiArg) + threadCount = CPUThreadCount(enable) - from .GlobalParameters import globalParameters + if not hasattr(objects, "__len__"): + objects = list(objects) - threadCount = procs if procs else CPUThreadCount(enable) + objLen = len(objects) + if objLen == 0: + return [] if return_as == "list" else iter([]) - threadCount = CPUThreadCount(enable) + f = (lambda x: function(*x)) if multiArg else function + if objLen == 1: + print(f"{message}: (1 task)") + result = [f(x) for x in objects] + return result if return_as == "list" else iter(result) - if threadCount <= 1 and globalParameters["ShowProgressBar"]: - # Provide a progress bar for single-threaded operation. - return [function(*args) if multiArg else function(args) for args in tqdm(objects, message)] + extra_message = ( + f": {threadCount} thread(s)" + f", {objLen} tasks" + if objLen + else "" + ) - countMessage = "" - try: - countMessage = " for {} tasks".format(len(objects)) - except TypeError: - pass - - if message != "": - message += ": " - print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) - sys.stdout.flush() - currentTime = time.time() - - pcall = pcallWithGlobalParamsMultiArg if multiArg else pcallWithGlobalParamsSingleArg - pargs = zip(objects, itertools.repeat(globalParameters)) - - if joblibParallelSupportsGenerator(): - rv = Parallel(n_jobs=threadCount, timeout=99999, return_as=return_as)( - delayed(pcall)(function, a, params) for a, params in pargs - ) + print(f"ParallelMap {message}{extra_message}") + + if threadCount <= 1: + result = [f(x) for x in objects] + return result if return_as == "list" else iter(result) + + if maxWorkers > 0: + threadCount = min(maxWorkers, threadCount) + + chunksize = max(minChunkSize, objLen // 2000) + worker = partial(worker_function, function=function, multiArg=multiArg) + if return_as == "generator_unordered": + # yield results as they complete without buffering + return _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild) else: - rv = Parallel(n_jobs=threadCount, timeout=99999)( - delayed(pcall)(function, a, params) for a, params in pargs - ) - - totalTime = time.time() - currentTime - print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime)) - sys.stdout.flush() - return rv + ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') + with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, + initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: + return list(imap_with_progress(pool, worker, objects, objLen, message, chunksize)) diff --git a/tensilelite/Tensile/CustomKernels.py b/tensilelite/Tensile/CustomKernels.py index ffceb636f5..127b3386a1 100644 --- a/tensilelite/Tensile/CustomKernels.py +++ b/tensilelite/Tensile/CustomKernels.py @@ -24,7 +24,9 @@ from . import CUSTOM_KERNEL_PATH from Tensile.Common.ValidParameters import checkParametersAreValid, validParameters, newMIValidParameters +from Tensile.CustomYamlLoader import DEFAULT_YAML_LOADER +from functools import lru_cache import yaml import os @@ -58,10 +60,13 @@ def getCustomKernelConfigAndAssembly(name, directory=CUSTOM_KERNEL_PATH): return (config, assembly) +# getCustomKernelConfig will get called repeatedly on the same file +# 20x logic loading speedup for aquavanjaram_Cijk_Ailk_Bljk_F8NH_HHS_BH_Bias_HAS_SAB_SAV_freesize_custom_GSUs +@lru_cache def readCustomKernelConfig(name, directory=CUSTOM_KERNEL_PATH): rawConfig, _ = getCustomKernelConfigAndAssembly(name, directory) try: - return yaml.safe_load(rawConfig)["custom.config"] + return yaml.load(rawConfig, Loader=DEFAULT_YAML_LOADER)["custom.config"] except yaml.scanner.ScannerError as e: raise RuntimeError("Failed to read configuration for custom kernel: {0}\nDetails:\n{1}".format(name, e)) diff --git a/tensilelite/Tensile/TensileCreateLibrary/Run.py b/tensilelite/Tensile/TensileCreateLibrary/Run.py index 835ed9c019..024c6c49c1 100644 --- a/tensilelite/Tensile/TensileCreateLibrary/Run.py +++ b/tensilelite/Tensile/TensileCreateLibrary/Run.py @@ -26,8 +26,10 @@ import rocisa import functools import glob +import gc import itertools import os +import resource import shutil from pathlib import Path from timeit import default_timer as timer @@ -78,6 +80,25 @@ from Tensile.Utilities.Decorators.Timing import timing from .ParseArguments import parseArguments +def getMemoryUsage(): + """Get peak and current memory usage in MB.""" + rusage = resource.getrusage(resource.RUSAGE_SELF) + peak_memory_mb = rusage.ru_maxrss / 1024 # KB to MB on Linux + + # Get current memory from /proc/self/status + current_memory_mb = 0 + try: + with open('/proc/self/status') as f: + for line in f: + if line.startswith('VmRSS:'): + current_memory_mb = int(line.split()[1]) / 1024 # KB to MB + break + except: + current_memory_mb = peak_memory_mb # Fallback + + return (peak_memory_mb, current_memory_mb) + + class KernelCodeGenResult(NamedTuple): err: int src: str @@ -115,6 +136,29 @@ def processKernelSource(kernelWriterAssembly, data, splitGSU, kernel) -> KernelC ) +def processAndAssembleKernelTCL(kernelWriterAssembly, rocisa_data, splitGSU, kernel, assemblyTmpPath, assembler): + """ + Pipeline function for TCL mode that: + 1. Generates kernel source + 2. Writes .s file to disk + 3. Assembles to .o file + 4. Deletes .s file + """ + result = processKernelSource(kernelWriterAssembly, rocisa_data, splitGSU, kernel) + return writeAndAssembleKernel(result, assemblyTmpPath, assembler) + + +def writeMasterSolutionLibrary(name_lib_tuple, newLibraryDir, splitGSU, libraryFormat): + """ + Write a master solution library to disk. + Module-level function to support multiprocessing. + """ + name, lib = name_lib_tuple + filename = os.path.join(newLibraryDir, name) + lib.applyNaming(splitGSU) + LibraryIO.write(filename, state(lib), libraryFormat) + + def removeInvalidSolutionsAndKernels(results, kernels, solutions, errorTolerant, printLevel: bool, splitGSU: bool): removeKernels = [] removeKernelNames = [] @@ -189,6 +233,24 @@ def writeAssembly(asmPath: Union[Path, str], result: KernelCodeGenResult): return path, isa, wfsize, minResult +def writeAndAssembleKernel(result: KernelCodeGenResult, asmPath: Union[Path, str], assembler): + """Write assembly file and immediately assemble it to .o file""" + if result.err: + printExit(f"Failed to build kernel {result.name} because it has error code {result.err}") + + path = Path(asmPath) / f"{result.name}.s" + with open(path, "w", encoding="utf-8") as f: + f.write(result.src) + + # Assemble .s -> .o + assembler(isaToGfx(result.isa), result.wavefrontSize, str(path), str(path.with_suffix(".o"))) + + # Delete assembly file immediately to save disk space + path.unlink() + + return KernelMinResult(result.err, result.cuoccupancy, result.pgr, result.mathclk) + + def writeHelpers( outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H ): @@ -268,13 +330,14 @@ def writeSolutionsAndKernels( numAsmKernels = len(asmKernels) numKernels = len(asmKernels) assert numKernels == numAsmKernels, "Only assembly kernels are supported in TensileLite" - asmIter = zip( - itertools.repeat(kernelWriterAssembly), - itertools.repeat(rocisa.rocIsa.getInstance().getData()), - itertools.repeat(splitGSU), - asmKernels + + processKernelFn = functools.partial( + processKernelSource, + kernelWriterAssembly=kernelWriterAssembly, + data=rocisa.rocIsa.getInstance().getData(), + splitGSU=splitGSU ) - asmResults = ParallelMap2(processKernelSource, asmIter, "Generating assembly kernels", return_as="list") + asmResults = ParallelMap2(processKernelFn, asmKernels, "Generating assembly kernels", return_as="list", multiArg=False) removeInvalidSolutionsAndKernels( asmResults, asmKernels, solutions, errorTolerant, getVerbosity(), splitGSU ) @@ -282,19 +345,21 @@ def writeSolutionsAndKernels( asmResults, asmKernels, solutions, splitGSU ) - def assemble(ret): - p, isa, wavefrontsize, result = ret - asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o"))) - - unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) - compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) + # Use functools.partial to bind assemblyTmpPath and assembler + writeAndAssembleFn = functools.partial( + writeAndAssembleKernel, + asmPath=assemblyTmpPath, + assembler=asmToolchain.assembler + ) ret = ParallelMap2( - compose(assemble, unaryWriteAssembly), + writeAndAssembleFn, asmResults, "Writing assembly kernels", return_as="list", multiArg=False, ) + del asmResults + gc.collect() writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H) srcKernelFile = Path(outputPath) / "Kernels.cpp" @@ -369,32 +434,31 @@ def writeSolutionsAndKernelsTCL( uniqueAsmKernels = [k for k in asmKernels if not k.duplicate] - def assemble(ret): - p, isa, wavefrontsize, result = ret - asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o"))) - return result - - unaryProcessKernelSource = functools.partial( - processKernelSource, + processKernelFn = functools.partial( + processAndAssembleKernelTCL, kernelWriterAssembly, rocisa.rocIsa.getInstance().getData(), splitGSU, + assemblyTmpPath=assemblyTmpPath, + assembler=asmToolchain.assembler ) - unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) - compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) - ret = ParallelMap2( - compose(assemble, unaryWriteAssembly, unaryProcessKernelSource), + results = ParallelMap2( + processKernelFn, uniqueAsmKernels, "Generating assembly kernels", multiArg=False, return_as="list" ) + del processKernelFn + gc.collect() + passPostKernelInfoToSolution( - ret, uniqueAsmKernels, solutions, splitGSU + results, uniqueAsmKernels, solutions, splitGSU ) - # result.src is very large so let garbage collector know to clean up - del ret + del results + gc.collect() + buildAssemblyCodeObjectFiles( asmToolchain.linker, asmToolchain.bundler, @@ -493,6 +557,15 @@ def generateKernelHelperObjects(solutions: List[Solution], cxxCompiler: str, isa return sorted(khos, key=sortByEnum, reverse=True) # Ensure that we write Enum kernel helpers are first in list +def libraryIter(lib: MasterSolutionLibrary): + if len(lib.solutions): + for i, s in enumerate(lib.solutions.items()): + yield (i, *s) + else: + for _, lazyLib in lib.lazyLibraries.items(): + yield from libraryIter(lazyLib) + + @timing def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInfoMap): @@ -508,26 +581,23 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf printSolutionRejectionReason = True printIndexAssignmentInfo = False - fIter = zip( - logicFiles, - itertools.repeat(assembler), - itertools.repeat(splitGSU), - itertools.repeat(printSolutionRejectionReason), - itertools.repeat(printIndexAssignmentInfo), - itertools.repeat(isaInfoMap), - itertools.repeat(args["LazyLibraryLoading"]), + parseLogicFn = functools.partial( + LibraryIO.parseLibraryLogicFile, + assembler=assembler, + splitGSU=splitGSU, + printSolutionRejectionReason=printSolutionRejectionReason, + printIndexAssignmentInfo=printIndexAssignmentInfo, + isaInfoMap=isaInfoMap, + lazyLibraryLoading=args["LazyLibraryLoading"] ) - def libraryIter(lib: MasterSolutionLibrary): - if len(lib.solutions): - for i, s in enumerate(lib.solutions.items()): - yield (i, *s) - else: - for _, lazyLib in lib.lazyLibraries.items(): - yield from libraryIter(lazyLib) - for library in ParallelMap2( - LibraryIO.parseLibraryLogicFile, fIter, "Loading Logics...", return_as="generator_unordered" + parseLogicFn, logicFiles, "Loading Logics...", + return_as="generator_unordered", + minChunkSize=24, + maxWorkers=32, + maxtasksperchild=1, + multiArg=False, ): _, architectureName, _, _, _, newLibrary = library @@ -539,6 +609,9 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf else: masterLibraries[architectureName] = newLibrary masterLibraries[architectureName].version = args["CodeObjectVersion"] + del library, newLibrary + + gc.collect() # Sort masterLibraries to make global soln index values deterministic solnReIndex = 0 @@ -734,6 +807,9 @@ def run(): ) stop_wsk = timer() print(f"Time to generate kernels (s): {(stop_wsk-start_wsk):3.2f}") + numKernelHelperObjs = len(kernelHelperObjs) + del kernelWriterAssembly, kernelHelperObjs + gc.collect() archs = [ # is this really different than the other archs above? isaToGfx(arch) @@ -751,13 +827,10 @@ def run(): if kName not in solDict: solDict["%s"%kName] = kernel - def writeMsl(name, lib): - filename = os.path.join(newLibraryDir, name) - lib.applyNaming(splitGSU) - LibraryIO.write(filename, state(lib), arguments["LibraryFormat"]) - filename = os.path.join(newLibraryDir, "TensileLiteLibrary_lazy_Mapping") LibraryIO.write(filename, libraryMapping, "msgpack") + del libraryMapping + gc.collect() start_msl = timer() for archName, newMasterLibrary in masterLibraries.items(): @@ -774,12 +847,22 @@ def run(): kName = getKeyNoInternalArgs(s.originalSolution, splitGSU) s.sizeMapping.CUOccupancy = solDict["%s"%kName]["CUOccupancy"] - ParallelMap2(writeMsl, + writeFn = functools.partial( + writeMasterSolutionLibrary, + newLibraryDir=newLibraryDir, + splitGSU=splitGSU, + libraryFormat=arguments["LibraryFormat"] + ) + + ParallelMap2(writeFn, newMasterLibrary.lazyLibraries.items(), "Writing master solution libraries", + multiArg=False, return_as="list") stop_msl = timer() print(f"Time to write master solution libraries (s): {(stop_msl-start_msl):3.2f}") + del masterLibraries, solutions, kernels, solDict + gc.collect() if not arguments["KeepBuildTmp"]: buildTmp = Path(arguments["OutputPath"]).parent / "library" / "build_tmp" @@ -796,8 +879,11 @@ def run(): print("") stop = timer() + peak_memory_mb, current_memory_mb = getMemoryUsage() print(f"Total time (s): {(stop-start):3.2f}") print(f"Total kernels processed: {numKernels}") print(f"Kernels processed per second: {(numKernels/(stop-start)):3.2f}") - print(f"KernelHelperObjs: {len(kernelHelperObjs)}") + print(f"KernelHelperObjs: {numKernelHelperObjs}") + print(f"Peak memory usage (MB): {peak_memory_mb:,.1f}") + print(f"Current memory usage (MB): {current_memory_mb:,.1f}") diff --git a/tensilelite/Tensile/TensileMergeLibrary.py b/tensilelite/Tensile/TensileMergeLibrary.py index e33c617b6f..ba163e9918 100644 --- a/tensilelite/Tensile/TensileMergeLibrary.py +++ b/tensilelite/Tensile/TensileMergeLibrary.py @@ -303,8 +303,7 @@ def avoidRegressions(originalDir, incrementalDir, outputPath, forceMerge, noEff= logicsFiles[origFile] = origFile logicsFiles[incFile] = incFile - iters = zip(logicsFiles.keys()) - logicsList = ParallelMap2(loadData, iters, "Loading Logics...", return_as="list") + logicsList = ParallelMap2(loadData, logicsFiles.keys(), "Loading Logics...", return_as="list", multiArg=False) logicsDict = {} for i, _ in enumerate(logicsList): logicsDict[logicsList[i][0]] = logicsList[i][1] diff --git a/tensilelite/Tensile/TensileUpdateLibrary.py b/tensilelite/Tensile/TensileUpdateLibrary.py index 5ff265d0ed..c1803a6349 100644 --- a/tensilelite/Tensile/TensileUpdateLibrary.py +++ b/tensilelite/Tensile/TensileUpdateLibrary.py @@ -26,7 +26,7 @@ from . import LibraryIO from .Tensile import addCommonArguments, argUpdatedGlobalParameters from .Common import assignGlobalParameters, print1, restoreDefaultGlobalParameters, HR, \ - globalParameters, architectureMap, ensurePath, ParallelMap, __version__ + globalParameters, architectureMap, ensurePath, ParallelMap2, __version__ import argparse import copy @@ -149,7 +149,7 @@ def TensileUpdateLibrary(userArgs): for logicFile in logicFiles: print("# %s" % logicFile) fIter = zip(logicFiles, itertools.repeat(args.logic_path), itertools.repeat(outputPath)) - libraries = ParallelMap(UpdateLogic, fIter, "Updating logic files", method=lambda x: x.starmap) + libraries = ParallelMap2(UpdateLogic, fIter, "Updating logic files", multiArg=True, return_as="list") def main(): diff --git a/tensilelite/Tensile/Toolchain/Assembly.py b/tensilelite/Tensile/Toolchain/Assembly.py index a8b91e8d62..265e1d532c 100644 --- a/tensilelite/Tensile/Toolchain/Assembly.py +++ b/tensilelite/Tensile/Toolchain/Assembly.py @@ -30,7 +30,7 @@ import subprocess from pathlib import Path from typing import List, Union, NamedTuple -from Tensile.Common import print2 +from Tensile.Common import print1, print2 from Tensile.Common.Architectures import isaToGfx from ..SolutionStructs import Solution @@ -92,8 +92,26 @@ def buildAssemblyCodeObjectFiles( if coName: coFileMap[asmDir / (coName + extCoRaw)].add(str(asmDir / (kernel["BaseName"] + extObj))) + # Build reference count map for .o files to handle shared object files + # (.o files from kernels marked .duplicate in TensileCreateLibrary) + objFileRefCount = collections.Counter() + for coFileRaw, objFiles in coFileMap.items(): + for objFile in objFiles: + objFileRefCount[objFile] += 1 + + sharedObjFiles = {objFile: count for objFile, count in objFileRefCount.items() if count > 1} + if sharedObjFiles: + print1(f"Found {len(sharedObjFiles)} .o files shared across multiple code objects:") + for coFileRaw, objFiles in coFileMap.items(): linker(objFiles, str(coFileRaw)) + + # Delete .o files after linking once usage count reaches 0 + for objFile in objFiles: + objFileRefCount[objFile] -= 1 + if objFileRefCount[objFile] == 0: + Path(objFile).unlink() + coFile = destDir / coFileRaw.name.replace(extCoRaw, extCo) if compress: bundler.compress(str(coFileRaw), str(coFile), gfx) diff --git a/tensilelite/Tensile/Toolchain/Component.py b/tensilelite/Tensile/Toolchain/Component.py index 67fa35e2d8..dde83af4c3 100644 --- a/tensilelite/Tensile/Toolchain/Component.py +++ b/tensilelite/Tensile/Toolchain/Component.py @@ -355,6 +355,7 @@ class Linker(Component): when invoking the linker, LLVM allows the provision of arguments via a "response file" Reference: https://llvm.org/docs/CommandLine.html#response-files """ + # FIXME: this prevents threading as clang_args.txt is overwritten with open(Path.cwd() / "clang_args.txt", "wt") as file: file.write(" ".join(srcPaths).replace('\\', '\\\\') if os_name == "nt" else " ".join(srcPaths)) return [*(self.default_args), "-o", destPath, "@clang_args.txt"] diff --git a/tensilelite/requirements.txt b/tensilelite/requirements.txt index 60c4c11445..5c8fd66a88 100644 --- a/tensilelite/requirements.txt +++ b/tensilelite/requirements.txt @@ -2,8 +2,6 @@ dataclasses; python_version == '3.6' packaging pyyaml msgpack -joblib>=1.4.0; python_version >= '3.8' -joblib>=1.1.1; python_version < '3.8' simplejson ujson orjson