nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at python-updates 868 lines 35 kB view raw
1diff --git a/tensilelite/Tensile/SolutionStructs/Naming.py b/tensilelite/Tensile/SolutionStructs/Naming.py 2index 4f220960db1d..99535e246650 100644 3--- a/tensilelite/Tensile/SolutionStructs/Naming.py 4+++ b/tensilelite/Tensile/SolutionStructs/Naming.py 5@@ -105,7 +105,6 @@ def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInterna 6 if splitGSU: 7 state["GlobalSplitU"] = "M" if (state["GlobalSplitU"] > 1 or state["GlobalSplitU"] == -1) else state["GlobalSplitU"] 8 9- 10 requiredParametersTemp = set(requiredParameters.union(["GlobalSplitU"])) 11 12 if ignoreInternalArgs: 13diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py 14index bab8c687509..e03f456fbec 100644 15--- a/tensilelite/Tensile/CustomYamlLoader.py 16+++ b/tensilelite/Tensile/CustomYamlLoader.py 17@@ -1,3 +1,6 @@ 18+# Copyright © Advanced Micro Devices, Inc., or its affiliates. 19+# SPDX-License-Identifier: MIT 20+ 21 import yaml 22 from pathlib import Path 23 24Author: Luna Nova <git@lunnova.dev> 25Date: Sun Oct 12 11:52:10 2025 -0700 26 27 [hipblaslt] intern strings to reduce duplicate memory for solution keys 28 29diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py 30index 685e69220c..9fdf38d8e5 100644 31--- a/tensilelite/Tensile/CustomYamlLoader.py 32+++ b/tensilelite/Tensile/CustomYamlLoader.py 33@@ -1,6 +1,7 @@ 34 # Copyright © Advanced Micro Devices, Inc., or its affiliates. 35 # SPDX-License-Identifier: MIT 36 37+import sys 38 import yaml 39 from pathlib import Path 40 41@@ -85,7 +86,7 @@ def parse_scalar(loader: yaml.Loader): 42 if not evt.style: 43 return None 44 45- return value 46+ return sys.intern(value) 47 48 def load_yaml_stream(yaml_path: Path, loader_type: yaml.Loader): 49 with open(yaml_path, 'r') as f: 50 51diff --git a/tensilelite/Tensile/Common/Parallel.py b/tensilelite/Tensile/Common/Parallel.py 52index 1a2bf9e119..f46100c7b8 100644 53--- a/tensilelite/Tensile/Common/Parallel.py 54+++ b/tensilelite/Tensile/Common/Parallel.py 55@@ -22,43 +22,58 @@ 56 # 57 ################################################################################ 58 59-import concurrent.futures 60-import itertools 61+import multiprocessing 62 import os 63+import re 64 import sys 65 import time 66- 67-from joblib import Parallel, delayed 68+from functools import partial 69+from typing import Any, Callable 70 71 from .Utilities import tqdm 72 73 74-def joblibParallelSupportsGenerator(): 75- import joblib 76- from packaging.version import Version 77+def get_inherited_job_limit() -> int: 78+ # 1. Check CMAKE_BUILD_PARALLEL_LEVEL (CMake 3.12+) 79+ if 'CMAKE_BUILD_PARALLEL_LEVEL' in os.environ: 80+ try: 81+ return int(os.environ['CMAKE_BUILD_PARALLEL_LEVEL']) 82+ except ValueError: 83+ pass 84 85- joblibVer = joblib.__version__ 86- return Version(joblibVer) >= Version("1.4.0") 87+ # 2. Parse MAKEFLAGS for -jN 88+ makeflags = os.environ.get('MAKEFLAGS', '') 89+ match = re.search(r'-j\s*(\d+)', makeflags) 90+ if match: 91+ return int(match.group(1)) 92 93+ return -1 94 95-def CPUThreadCount(enable=True): 96- from .GlobalParameters import globalParameters 97 98+def CPUThreadCount(enable=True): 99 if not enable: 100 return 1 101- else: 102+ from .GlobalParameters import globalParameters 103+ 104+ # Priority order: 105+ # 1. Inherited from build system (CMAKE_BUILD_PARALLEL_LEVEL or MAKEFLAGS) 106+ # 2. Explicit --jobs flag 107+ # 3. Auto-detect 108+ inherited_limit = get_inherited_job_limit() 109+ cpuThreads = inherited_limit if inherited_limit > 0 else globalParameters["CpuThreads"] 110+ 111+ if cpuThreads < 1: 112 if os.name == "nt": 113- # Windows supports at most 61 workers because the scheduler uses 114- # WaitForMultipleObjects directly, which has the limit (the limit 115- # is actually 64, but some handles are needed for accounting). 116- cpu_count = min(os.cpu_count(), 61) 117+ cpuThreads = os.cpu_count() 118 else: 119- cpu_count = len(os.sched_getaffinity(0)) 120- cpuThreads = globalParameters["CpuThreads"] 121- if cpuThreads == -1: 122- return cpu_count 123+ cpuThreads = len(os.sched_getaffinity(0)) 124 125- return min(cpu_count, cpuThreads) 126+ if os.name == "nt": 127+ # Windows supports at most 61 workers because the scheduler uses 128+ # WaitForMultipleObjects directly, which has the limit (the limit 129+ # is actually 64, but some handles are needed for accounting). 130+ cpuThreads = min(cpuThreads, 61) 131+ return max(1, cpuThreads) 132 133 134 def pcallWithGlobalParamsMultiArg(f, args, newGlobalParameters): 135@@ -71,19 +86,22 @@ def pcallWithGlobalParamsSingleArg(f, arg, newGlobalParameters): 136 return f(arg) 137 138 139-def apply_print_exception(item, *args): 140- # print(item, args) 141+def OverwriteGlobalParameters(newGlobalParameters): 142+ from . import GlobalParameters 143+ 144+ GlobalParameters.globalParameters.clear() 145+ GlobalParameters.globalParameters.update(newGlobalParameters) 146+ 147+ 148+def worker_function(args, function, multiArg): 149+ """Worker function that executes in the pool process.""" 150 try: 151- if len(args) > 0: 152- func = item 153- args = args[0] 154- return func(*args) 155+ if multiArg: 156+ return function(*args) 157 else: 158- func, item = item 159- return func(item) 160+ return function(args) 161 except Exception: 162 import traceback 163- 164 traceback.print_exc() 165 raise 166 finally: 167@@ -98,154 +116,121 @@ def OverwriteGlobalParameters(newGlobalParameters): 168 GlobalParameters.globalParameters.update(newGlobalParameters) 169 170 171-def ProcessingPool(enable=True, maxTasksPerChild=None): 172- import multiprocessing 173- import multiprocessing.dummy 174- 175- threadCount = CPUThreadCount() 176- 177- if (not enable) or threadCount <= 1: 178- return multiprocessing.dummy.Pool(1) 179- 180- if multiprocessing.get_start_method() == "spawn": 181- from . import GlobalParameters 182- 183- return multiprocessing.Pool( 184- threadCount, 185- initializer=OverwriteGlobalParameters, 186- maxtasksperchild=maxTasksPerChild, 187- initargs=(GlobalParameters.globalParameters,), 188- ) 189- else: 190- return multiprocessing.Pool(threadCount, maxtasksperchild=maxTasksPerChild) 191+def progress_logger(iterable, total, message, min_log_interval=5.0): 192+ """ 193+ Generator that wraps an iterable and logs progress with time-based throttling. 194 195+ Only logs progress if at least min_log_interval seconds have passed since last log. 196+ Only prints completion message if task took >= min_log_interval seconds. 197 198-def ParallelMap(function, objects, message="", enable=True, method=None, maxTasksPerChild=None): 199+ Yields (index, item) tuples. 200 """ 201- Generally equivalent to list(map(function, objects)), possibly executing in parallel. 202- 203- message: A message describing the operation to be performed. 204- enable: May be set to false to disable parallelism. 205- method: A function which can fetch the mapping function from a processing pool object. 206- Leave blank to use .map(), other possiblities: 207- - `lambda x: x.starmap` - useful if `function` takes multiple parameters. 208- - `lambda x: x.imap` - lazy evaluation 209- - `lambda x: x.imap_unordered` - lazy evaluation, does not preserve order of return value. 210- """ 211- from .GlobalParameters import globalParameters 212+ start_time = time.time() 213+ last_log_time = start_time 214+ log_interval = 1 + (total // 100) 215 216- threadCount = CPUThreadCount(enable) 217- pool = ProcessingPool(enable, maxTasksPerChild) 218- 219- if threadCount <= 1 and globalParameters["ShowProgressBar"]: 220- # Provide a progress bar for single-threaded operation. 221- # This works for method=None, and for starmap. 222- mapFunc = map 223- if method is not None: 224- # itertools provides starmap which can fill in for pool.starmap. It provides imap on Python 2.7. 225- # If this works, we will use it, otherwise we will fallback to the "dummy" pool for single threaded 226- # operation. 227- try: 228- mapFunc = method(itertools) 229- except NameError: 230- mapFunc = None 231- 232- if mapFunc is not None: 233- return list(mapFunc(function, tqdm(objects, message))) 234- 235- mapFunc = pool.map 236- if method: 237- mapFunc = method(pool) 238- 239- objects = zip(itertools.repeat(function), objects) 240- function = apply_print_exception 241- 242- countMessage = "" 243- try: 244- countMessage = " for {} tasks".format(len(objects)) 245- except TypeError: 246- pass 247+ for idx, item in enumerate(iterable): 248+ if idx % log_interval == 0: 249+ current_time = time.time() 250+ if (current_time - last_log_time) >= min_log_interval: 251+ print(f"{message}\t{idx+1: 5d}/{total: 5d}") 252+ last_log_time = current_time 253+ yield idx, item 254 255- if message != "": 256- message += ": " 257+ elapsed = time.time() - start_time 258+ final_idx = idx + 1 if 'idx' in locals() else 0 259 260- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) 261- sys.stdout.flush() 262- currentTime = time.time() 263- rv = mapFunc(function, objects) 264- totalTime = time.time() - currentTime 265- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime)) 266- sys.stdout.flush() 267- pool.close() 268- return rv 269+ if elapsed >= min_log_interval or last_log_time > start_time: 270+ print(f"{message} done in {elapsed:.1f}s!\t{final_idx: 5d}/{total: 5d}") 271 272 273-def ParallelMapReturnAsGenerator(function, objects, message="", enable=True, multiArg=True): 274- from .GlobalParameters import globalParameters 275+def imap_with_progress(pool, func, iterable, total, message, chunksize): 276+ results = [] 277+ for _, result in progress_logger(pool.imap(func, iterable, chunksize=chunksize), total, message): 278+ results.append(result) 279+ return results 280 281- threadCount = CPUThreadCount(enable) 282- print("{0}Launching {1} threads...".format(message, threadCount)) 283 284- if threadCount <= 1 and globalParameters["ShowProgressBar"]: 285- # Provide a progress bar for single-threaded operation. 286- callFunc = lambda args: function(*args) if multiArg else lambda args: function(args) 287- return [callFunc(args) for args in tqdm(objects, message)] 288+def _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild): 289+ # separate fn because yield makes the entire fn a generator even if unreachable 290+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') 291 292- with concurrent.futures.ProcessPoolExecutor(max_workers=threadCount) as executor: 293- resultFutures = (executor.submit(function, *arg if multiArg else arg) for arg in objects) 294- for result in concurrent.futures.as_completed(resultFutures): 295- yield result.result() 296+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, 297+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: 298+ for _, result in progress_logger(pool.imap_unordered(worker, objects, chunksize=chunksize), objLen, message): 299+ yield result 300 301 302 def ParallelMap2( 303- function, objects, message="", enable=True, multiArg=True, return_as="list", procs=None 304+ function: Callable, 305+ objects: Any, 306+ message: str = "", 307+ enable: bool = True, 308+ multiArg: bool = True, 309+ minChunkSize: int = 1, 310+ maxWorkers: int = -1, 311+ maxtasksperchild: int = 1024, 312+ return_as: str = "list" 313 ): 314+ """Executes a function over a list of objects in parallel or sequentially. 315+ 316+ This function is generally equivalent to ``list(map(function, objects))``. However, it provides 317+ additional functionality to run in parallel, depending on the 'enable' flag and available CPU 318+ threads. 319+ 320+ Args: 321+ function: The function to apply to each item in 'objects'. If 'multiArg' is True, 'function' 322+ should accept multiple arguments. 323+ objects: An iterable of objects to be processed by 'function'. If 'multiArg' is True, each 324+ item in 'objects' should be an iterable of arguments for 'function'. 325+ message: Optional; a message describing the operation. Default is an empty string. 326+ enable: Optional; if False, disables parallel execution and runs sequentially. Default is True. 327+ multiArg: Optional; if True, treats each item in 'objects' as multiple arguments for 328+ 'function'. Default is True. 329+ return_as: Optional; "list" (default) or "generator_unordered" for streaming results 330+ 331+ Returns: 332+ A list or generator containing the results of applying **function** to each item in **objects**. 333 """ 334- Generally equivalent to list(map(function, objects)), possibly executing in parallel. 335+ from .GlobalParameters import globalParameters 336 337- message: A message describing the operation to be performed. 338- enable: May be set to false to disable parallelism. 339- multiArg: True if objects represent multiple arguments 340- (differentiates multi args vs single collection arg) 341- """ 342- if return_as in ("generator", "generator_unordered") and not joblibParallelSupportsGenerator(): 343- return ParallelMapReturnAsGenerator(function, objects, message, enable, multiArg) 344+ threadCount = CPUThreadCount(enable) 345 346- from .GlobalParameters import globalParameters 347+ if not hasattr(objects, "__len__"): 348+ objects = list(objects) 349 350- threadCount = procs if procs else CPUThreadCount(enable) 351+ objLen = len(objects) 352+ if objLen == 0: 353+ return [] if return_as == "list" else iter([]) 354 355- threadCount = CPUThreadCount(enable) 356+ f = (lambda x: function(*x)) if multiArg else function 357+ if objLen == 1: 358+ print(f"{message}: (1 task)") 359+ result = [f(x) for x in objects] 360+ return result if return_as == "list" else iter(result) 361 362- if threadCount <= 1 and globalParameters["ShowProgressBar"]: 363- # Provide a progress bar for single-threaded operation. 364- return [function(*args) if multiArg else function(args) for args in tqdm(objects, message)] 365+ extra_message = ( 366+ f": {threadCount} thread(s)" + f", {objLen} tasks" 367+ if objLen 368+ else "" 369+ ) 370 371- countMessage = "" 372- try: 373- countMessage = " for {} tasks".format(len(objects)) 374- except TypeError: 375- pass 376- 377- if message != "": 378- message += ": " 379- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) 380- sys.stdout.flush() 381- currentTime = time.time() 382- 383- pcall = pcallWithGlobalParamsMultiArg if multiArg else pcallWithGlobalParamsSingleArg 384- pargs = zip(objects, itertools.repeat(globalParameters)) 385- 386- if joblibParallelSupportsGenerator(): 387- rv = Parallel(n_jobs=threadCount, timeout=99999, return_as=return_as)( 388- delayed(pcall)(function, a, params) for a, params in pargs 389- ) 390+ print(f"ParallelMap {message}{extra_message}") 391+ 392+ if threadCount <= 1: 393+ result = [f(x) for x in objects] 394+ return result if return_as == "list" else iter(result) 395+ 396+ if maxWorkers > 0: 397+ threadCount = min(maxWorkers, threadCount) 398+ 399+ chunksize = max(minChunkSize, objLen // 2000) 400+ worker = partial(worker_function, function=function, multiArg=multiArg) 401+ if return_as == "generator_unordered": 402+ # yield results as they complete without buffering 403+ return _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild) 404 else: 405- rv = Parallel(n_jobs=threadCount, timeout=99999)( 406- delayed(pcall)(function, a, params) for a, params in pargs 407- ) 408- 409- totalTime = time.time() - currentTime 410- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime)) 411- sys.stdout.flush() 412- return rv 413+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') 414+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, 415+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: 416+ return list(imap_with_progress(pool, worker, objects, objLen, message, chunksize)) 417diff --git a/tensilelite/Tensile/CustomKernels.py b/tensilelite/Tensile/CustomKernels.py 418index ffceb636f5..127b3386a1 100644 419--- a/tensilelite/Tensile/CustomKernels.py 420+++ b/tensilelite/Tensile/CustomKernels.py 421@@ -24,7 +24,9 @@ 422 423 from . import CUSTOM_KERNEL_PATH 424 from Tensile.Common.ValidParameters import checkParametersAreValid, validParameters, newMIValidParameters 425+from Tensile.CustomYamlLoader import DEFAULT_YAML_LOADER 426 427+from functools import lru_cache 428 import yaml 429 430 import os 431@@ -58,10 +60,13 @@ def getCustomKernelConfigAndAssembly(name, directory=CUSTOM_KERNEL_PATH): 432 433 return (config, assembly) 434 435+# getCustomKernelConfig will get called repeatedly on the same file 436+# 20x logic loading speedup for aquavanjaram_Cijk_Ailk_Bljk_F8NH_HHS_BH_Bias_HAS_SAB_SAV_freesize_custom_GSUs 437+@lru_cache 438 def readCustomKernelConfig(name, directory=CUSTOM_KERNEL_PATH): 439 rawConfig, _ = getCustomKernelConfigAndAssembly(name, directory) 440 try: 441- return yaml.safe_load(rawConfig)["custom.config"] 442+ return yaml.load(rawConfig, Loader=DEFAULT_YAML_LOADER)["custom.config"] 443 except yaml.scanner.ScannerError as e: 444 raise RuntimeError("Failed to read configuration for custom kernel: {0}\nDetails:\n{1}".format(name, e)) 445 446diff --git a/tensilelite/Tensile/TensileCreateLibrary/Run.py b/tensilelite/Tensile/TensileCreateLibrary/Run.py 447index 835ed9c019..024c6c49c1 100644 448--- a/tensilelite/Tensile/TensileCreateLibrary/Run.py 449+++ b/tensilelite/Tensile/TensileCreateLibrary/Run.py 450@@ -26,8 +26,10 @@ import rocisa 451 452 import functools 453 import glob 454+import gc 455 import itertools 456 import os 457+import resource 458 import shutil 459 from pathlib import Path 460 from timeit import default_timer as timer 461@@ -78,6 +80,25 @@ from Tensile.Utilities.Decorators.Timing import timing 462 from .ParseArguments import parseArguments 463 464 465+def getMemoryUsage(): 466+ """Get peak and current memory usage in MB.""" 467+ rusage = resource.getrusage(resource.RUSAGE_SELF) 468+ peak_memory_mb = rusage.ru_maxrss / 1024 # KB to MB on Linux 469+ 470+ # Get current memory from /proc/self/status 471+ current_memory_mb = 0 472+ try: 473+ with open('/proc/self/status') as f: 474+ for line in f: 475+ if line.startswith('VmRSS:'): 476+ current_memory_mb = int(line.split()[1]) / 1024 # KB to MB 477+ break 478+ except: 479+ current_memory_mb = peak_memory_mb # Fallback 480+ 481+ return (peak_memory_mb, current_memory_mb) 482+ 483+ 484 class KernelCodeGenResult(NamedTuple): 485 err: int 486 src: str 487@@ -115,6 +136,29 @@ def processKernelSource(kernelWriterAssembly, data, splitGSU, kernel) -> KernelC 488 ) 489 490 491+def processAndAssembleKernelTCL(kernelWriterAssembly, rocisa_data, splitGSU, kernel, assemblyTmpPath, assembler): 492+ """ 493+ Pipeline function for TCL mode that: 494+ 1. Generates kernel source 495+ 2. Writes .s file to disk 496+ 3. Assembles to .o file 497+ 4. Deletes .s file 498+ """ 499+ result = processKernelSource(kernelWriterAssembly, rocisa_data, splitGSU, kernel) 500+ return writeAndAssembleKernel(result, assemblyTmpPath, assembler) 501+ 502+ 503+def writeMasterSolutionLibrary(name_lib_tuple, newLibraryDir, splitGSU, libraryFormat): 504+ """ 505+ Write a master solution library to disk. 506+ Module-level function to support multiprocessing. 507+ """ 508+ name, lib = name_lib_tuple 509+ filename = os.path.join(newLibraryDir, name) 510+ lib.applyNaming(splitGSU) 511+ LibraryIO.write(filename, state(lib), libraryFormat) 512+ 513+ 514 def removeInvalidSolutionsAndKernels(results, kernels, solutions, errorTolerant, printLevel: bool, splitGSU: bool): 515 removeKernels = [] 516 removeKernelNames = [] 517@@ -189,6 +233,24 @@ def writeAssembly(asmPath: Union[Path, str], result: KernelCodeGenResult): 518 return path, isa, wfsize, minResult 519 520 521+def writeAndAssembleKernel(result: KernelCodeGenResult, asmPath: Union[Path, str], assembler): 522+ """Write assembly file and immediately assemble it to .o file""" 523+ if result.err: 524+ printExit(f"Failed to build kernel {result.name} because it has error code {result.err}") 525+ 526+ path = Path(asmPath) / f"{result.name}.s" 527+ with open(path, "w", encoding="utf-8") as f: 528+ f.write(result.src) 529+ 530+ # Assemble .s -> .o 531+ assembler(isaToGfx(result.isa), result.wavefrontSize, str(path), str(path.with_suffix(".o"))) 532+ 533+ # Delete assembly file immediately to save disk space 534+ path.unlink() 535+ 536+ return KernelMinResult(result.err, result.cuoccupancy, result.pgr, result.mathclk) 537+ 538+ 539 def writeHelpers( 540 outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H 541 ): 542@@ -268,13 +330,14 @@ def writeSolutionsAndKernels( 543 numAsmKernels = len(asmKernels) 544 numKernels = len(asmKernels) 545 assert numKernels == numAsmKernels, "Only assembly kernels are supported in TensileLite" 546- asmIter = zip( 547- itertools.repeat(kernelWriterAssembly), 548- itertools.repeat(rocisa.rocIsa.getInstance().getData()), 549- itertools.repeat(splitGSU), 550- asmKernels 551+ 552+ processKernelFn = functools.partial( 553+ processKernelSource, 554+ kernelWriterAssembly=kernelWriterAssembly, 555+ data=rocisa.rocIsa.getInstance().getData(), 556+ splitGSU=splitGSU 557 ) 558- asmResults = ParallelMap2(processKernelSource, asmIter, "Generating assembly kernels", return_as="list") 559+ asmResults = ParallelMap2(processKernelFn, asmKernels, "Generating assembly kernels", return_as="list", multiArg=False) 560 removeInvalidSolutionsAndKernels( 561 asmResults, asmKernels, solutions, errorTolerant, getVerbosity(), splitGSU 562 ) 563@@ -282,19 +345,21 @@ def writeSolutionsAndKernels( 564 asmResults, asmKernels, solutions, splitGSU 565 ) 566 567- def assemble(ret): 568- p, isa, wavefrontsize, result = ret 569- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o"))) 570- 571- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) 572- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) 573+ # Use functools.partial to bind assemblyTmpPath and assembler 574+ writeAndAssembleFn = functools.partial( 575+ writeAndAssembleKernel, 576+ asmPath=assemblyTmpPath, 577+ assembler=asmToolchain.assembler 578+ ) 579 ret = ParallelMap2( 580- compose(assemble, unaryWriteAssembly), 581+ writeAndAssembleFn, 582 asmResults, 583 "Writing assembly kernels", 584 return_as="list", 585 multiArg=False, 586 ) 587+ del asmResults 588+ gc.collect() 589 590 writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H) 591 srcKernelFile = Path(outputPath) / "Kernels.cpp" 592@@ -369,32 +434,31 @@ def writeSolutionsAndKernelsTCL( 593 594 uniqueAsmKernels = [k for k in asmKernels if not k.duplicate] 595 596- def assemble(ret): 597- p, isa, wavefrontsize, result = ret 598- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o"))) 599- return result 600- 601- unaryProcessKernelSource = functools.partial( 602- processKernelSource, 603+ processKernelFn = functools.partial( 604+ processAndAssembleKernelTCL, 605 kernelWriterAssembly, 606 rocisa.rocIsa.getInstance().getData(), 607 splitGSU, 608+ assemblyTmpPath=assemblyTmpPath, 609+ assembler=asmToolchain.assembler 610 ) 611 612- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) 613- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) 614- ret = ParallelMap2( 615- compose(assemble, unaryWriteAssembly, unaryProcessKernelSource), 616+ results = ParallelMap2( 617+ processKernelFn, 618 uniqueAsmKernels, 619 "Generating assembly kernels", 620 multiArg=False, 621 return_as="list" 622 ) 623+ del processKernelFn 624+ gc.collect() 625+ 626 passPostKernelInfoToSolution( 627- ret, uniqueAsmKernels, solutions, splitGSU 628+ results, uniqueAsmKernels, solutions, splitGSU 629 ) 630- # result.src is very large so let garbage collector know to clean up 631- del ret 632+ del results 633+ gc.collect() 634+ 635 buildAssemblyCodeObjectFiles( 636 asmToolchain.linker, 637 asmToolchain.bundler, 638@@ -493,6 +557,15 @@ def generateKernelHelperObjects(solutions: List[Solution], cxxCompiler: str, isa 639 return sorted(khos, key=sortByEnum, reverse=True) # Ensure that we write Enum kernel helpers are first in list 640 641 642+def libraryIter(lib: MasterSolutionLibrary): 643+ if len(lib.solutions): 644+ for i, s in enumerate(lib.solutions.items()): 645+ yield (i, *s) 646+ else: 647+ for _, lazyLib in lib.lazyLibraries.items(): 648+ yield from libraryIter(lazyLib) 649+ 650+ 651 @timing 652 def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInfoMap): 653 654@@ -508,26 +581,23 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf 655 printSolutionRejectionReason = True 656 printIndexAssignmentInfo = False 657 658- fIter = zip( 659- logicFiles, 660- itertools.repeat(assembler), 661- itertools.repeat(splitGSU), 662- itertools.repeat(printSolutionRejectionReason), 663- itertools.repeat(printIndexAssignmentInfo), 664- itertools.repeat(isaInfoMap), 665- itertools.repeat(args["LazyLibraryLoading"]), 666+ parseLogicFn = functools.partial( 667+ LibraryIO.parseLibraryLogicFile, 668+ assembler=assembler, 669+ splitGSU=splitGSU, 670+ printSolutionRejectionReason=printSolutionRejectionReason, 671+ printIndexAssignmentInfo=printIndexAssignmentInfo, 672+ isaInfoMap=isaInfoMap, 673+ lazyLibraryLoading=args["LazyLibraryLoading"] 674 ) 675 676- def libraryIter(lib: MasterSolutionLibrary): 677- if len(lib.solutions): 678- for i, s in enumerate(lib.solutions.items()): 679- yield (i, *s) 680- else: 681- for _, lazyLib in lib.lazyLibraries.items(): 682- yield from libraryIter(lazyLib) 683- 684 for library in ParallelMap2( 685- LibraryIO.parseLibraryLogicFile, fIter, "Loading Logics...", return_as="generator_unordered" 686+ parseLogicFn, logicFiles, "Loading Logics...", 687+ return_as="generator_unordered", 688+ minChunkSize=24, 689+ maxWorkers=32, 690+ maxtasksperchild=1, 691+ multiArg=False, 692 ): 693 _, architectureName, _, _, _, newLibrary = library 694 695@@ -539,6 +609,9 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf 696 else: 697 masterLibraries[architectureName] = newLibrary 698 masterLibraries[architectureName].version = args["CodeObjectVersion"] 699+ del library, newLibrary 700+ 701+ gc.collect() 702 703 # Sort masterLibraries to make global soln index values deterministic 704 solnReIndex = 0 705@@ -734,6 +807,9 @@ def run(): 706 ) 707 stop_wsk = timer() 708 print(f"Time to generate kernels (s): {(stop_wsk-start_wsk):3.2f}") 709+ numKernelHelperObjs = len(kernelHelperObjs) 710+ del kernelWriterAssembly, kernelHelperObjs 711+ gc.collect() 712 713 archs = [ # is this really different than the other archs above? 714 isaToGfx(arch) 715@@ -751,13 +827,10 @@ def run(): 716 if kName not in solDict: 717 solDict["%s"%kName] = kernel 718 719- def writeMsl(name, lib): 720- filename = os.path.join(newLibraryDir, name) 721- lib.applyNaming(splitGSU) 722- LibraryIO.write(filename, state(lib), arguments["LibraryFormat"]) 723- 724 filename = os.path.join(newLibraryDir, "TensileLiteLibrary_lazy_Mapping") 725 LibraryIO.write(filename, libraryMapping, "msgpack") 726+ del libraryMapping 727+ gc.collect() 728 729 start_msl = timer() 730 for archName, newMasterLibrary in masterLibraries.items(): 731@@ -774,12 +847,22 @@ def run(): 732 kName = getKeyNoInternalArgs(s.originalSolution, splitGSU) 733 s.sizeMapping.CUOccupancy = solDict["%s"%kName]["CUOccupancy"] 734 735- ParallelMap2(writeMsl, 736+ writeFn = functools.partial( 737+ writeMasterSolutionLibrary, 738+ newLibraryDir=newLibraryDir, 739+ splitGSU=splitGSU, 740+ libraryFormat=arguments["LibraryFormat"] 741+ ) 742+ 743+ ParallelMap2(writeFn, 744 newMasterLibrary.lazyLibraries.items(), 745 "Writing master solution libraries", 746+ multiArg=False, 747 return_as="list") 748 stop_msl = timer() 749 print(f"Time to write master solution libraries (s): {(stop_msl-start_msl):3.2f}") 750+ del masterLibraries, solutions, kernels, solDict 751+ gc.collect() 752 753 if not arguments["KeepBuildTmp"]: 754 buildTmp = Path(arguments["OutputPath"]).parent / "library" / "build_tmp" 755@@ -796,8 +879,11 @@ def run(): 756 print("") 757 758 stop = timer() 759+ peak_memory_mb, current_memory_mb = getMemoryUsage() 760 761 print(f"Total time (s): {(stop-start):3.2f}") 762 print(f"Total kernels processed: {numKernels}") 763 print(f"Kernels processed per second: {(numKernels/(stop-start)):3.2f}") 764- print(f"KernelHelperObjs: {len(kernelHelperObjs)}") 765+ print(f"KernelHelperObjs: {numKernelHelperObjs}") 766+ print(f"Peak memory usage (MB): {peak_memory_mb:,.1f}") 767+ print(f"Current memory usage (MB): {current_memory_mb:,.1f}") 768diff --git a/tensilelite/Tensile/TensileMergeLibrary.py b/tensilelite/Tensile/TensileMergeLibrary.py 769index e33c617b6f..ba163e9918 100644 770--- a/tensilelite/Tensile/TensileMergeLibrary.py 771+++ b/tensilelite/Tensile/TensileMergeLibrary.py 772@@ -303,8 +303,7 @@ def avoidRegressions(originalDir, incrementalDir, outputPath, forceMerge, noEff= 773 logicsFiles[origFile] = origFile 774 logicsFiles[incFile] = incFile 775 776- iters = zip(logicsFiles.keys()) 777- logicsList = ParallelMap2(loadData, iters, "Loading Logics...", return_as="list") 778+ logicsList = ParallelMap2(loadData, logicsFiles.keys(), "Loading Logics...", return_as="list", multiArg=False) 779 logicsDict = {} 780 for i, _ in enumerate(logicsList): 781 logicsDict[logicsList[i][0]] = logicsList[i][1] 782diff --git a/tensilelite/Tensile/TensileUpdateLibrary.py b/tensilelite/Tensile/TensileUpdateLibrary.py 783index 5ff265d0ed..c1803a6349 100644 784--- a/tensilelite/Tensile/TensileUpdateLibrary.py 785+++ b/tensilelite/Tensile/TensileUpdateLibrary.py 786@@ -26,7 +26,7 @@ from . import LibraryIO 787 from .Tensile import addCommonArguments, argUpdatedGlobalParameters 788 789 from .Common import assignGlobalParameters, print1, restoreDefaultGlobalParameters, HR, \ 790- globalParameters, architectureMap, ensurePath, ParallelMap, __version__ 791+ globalParameters, architectureMap, ensurePath, ParallelMap2, __version__ 792 793 import argparse 794 import copy 795@@ -149,7 +149,7 @@ def TensileUpdateLibrary(userArgs): 796 for logicFile in logicFiles: 797 print("# %s" % logicFile) 798 fIter = zip(logicFiles, itertools.repeat(args.logic_path), itertools.repeat(outputPath)) 799- libraries = ParallelMap(UpdateLogic, fIter, "Updating logic files", method=lambda x: x.starmap) 800+ libraries = ParallelMap2(UpdateLogic, fIter, "Updating logic files", multiArg=True, return_as="list") 801 802 803 def main(): 804diff --git a/tensilelite/Tensile/Toolchain/Assembly.py b/tensilelite/Tensile/Toolchain/Assembly.py 805index a8b91e8d62..265e1d532c 100644 806--- a/tensilelite/Tensile/Toolchain/Assembly.py 807+++ b/tensilelite/Tensile/Toolchain/Assembly.py 808@@ -30,7 +30,7 @@ import subprocess 809 from pathlib import Path 810 from typing import List, Union, NamedTuple 811 812-from Tensile.Common import print2 813+from Tensile.Common import print1, print2 814 from Tensile.Common.Architectures import isaToGfx 815 from ..SolutionStructs import Solution 816 817@@ -92,8 +92,26 @@ def buildAssemblyCodeObjectFiles( 818 if coName: 819 coFileMap[asmDir / (coName + extCoRaw)].add(str(asmDir / (kernel["BaseName"] + extObj))) 820 821+ # Build reference count map for .o files to handle shared object files 822+ # (.o files from kernels marked .duplicate in TensileCreateLibrary) 823+ objFileRefCount = collections.Counter() 824+ for coFileRaw, objFiles in coFileMap.items(): 825+ for objFile in objFiles: 826+ objFileRefCount[objFile] += 1 827+ 828+ sharedObjFiles = {objFile: count for objFile, count in objFileRefCount.items() if count > 1} 829+ if sharedObjFiles: 830+ print1(f"Found {len(sharedObjFiles)} .o files shared across multiple code objects:") 831+ 832 for coFileRaw, objFiles in coFileMap.items(): 833 linker(objFiles, str(coFileRaw)) 834+ 835+ # Delete .o files after linking once usage count reaches 0 836+ for objFile in objFiles: 837+ objFileRefCount[objFile] -= 1 838+ if objFileRefCount[objFile] == 0: 839+ Path(objFile).unlink() 840+ 841 coFile = destDir / coFileRaw.name.replace(extCoRaw, extCo) 842 if compress: 843 bundler.compress(str(coFileRaw), str(coFile), gfx) 844diff --git a/tensilelite/Tensile/Toolchain/Component.py b/tensilelite/Tensile/Toolchain/Component.py 845index 67fa35e2d8..dde83af4c3 100644 846--- a/tensilelite/Tensile/Toolchain/Component.py 847+++ b/tensilelite/Tensile/Toolchain/Component.py 848@@ -355,6 +355,7 @@ class Linker(Component): 849 when invoking the linker, LLVM allows the provision of arguments via a "response file" 850 Reference: https://llvm.org/docs/CommandLine.html#response-files 851 """ 852+ # FIXME: this prevents threading as clang_args.txt is overwritten 853 with open(Path.cwd() / "clang_args.txt", "wt") as file: 854 file.write(" ".join(srcPaths).replace('\\', '\\\\') if os_name == "nt" else " ".join(srcPaths)) 855 return [*(self.default_args), "-o", destPath, "@clang_args.txt"] 856diff --git a/tensilelite/requirements.txt b/tensilelite/requirements.txt 857index 60c4c11445..5c8fd66a88 100644 858--- a/tensilelite/requirements.txt 859+++ b/tensilelite/requirements.txt 860@@ -2,8 +2,6 @@ dataclasses; python_version == '3.6' 861 packaging 862 pyyaml 863 msgpack 864-joblib>=1.4.0; python_version >= '3.8' 865-joblib>=1.1.1; python_version < '3.8' 866 simplejson 867 ujson 868 orjson