nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at r-updates 849 lines 34 kB view raw
1diff --git a/tensilelite/Tensile/Common/Parallel.py b/tensilelite/Tensile/Common/Parallel.py 2index 1a2bf9e119..f46100c7b8 100644 3--- a/tensilelite/Tensile/Common/Parallel.py 4+++ b/tensilelite/Tensile/Common/Parallel.py 5@@ -22,43 +22,58 @@ 6 # 7 ################################################################################ 8 9-import concurrent.futures 10-import itertools 11+import multiprocessing 12 import os 13+import re 14 import sys 15 import time 16- 17-from joblib import Parallel, delayed 18+from functools import partial 19+from typing import Any, Callable 20 21 from .Utilities import tqdm 22 23 24-def joblibParallelSupportsGenerator(): 25- import joblib 26- from packaging.version import Version 27+def get_inherited_job_limit() -> int: 28+ # 1. Check CMAKE_BUILD_PARALLEL_LEVEL (CMake 3.12+) 29+ if 'CMAKE_BUILD_PARALLEL_LEVEL' in os.environ: 30+ try: 31+ return int(os.environ['CMAKE_BUILD_PARALLEL_LEVEL']) 32+ except ValueError: 33+ pass 34 35- joblibVer = joblib.__version__ 36- return Version(joblibVer) >= Version("1.4.0") 37+ # 2. Parse MAKEFLAGS for -jN 38+ makeflags = os.environ.get('MAKEFLAGS', '') 39+ match = re.search(r'-j\s*(\d+)', makeflags) 40+ if match: 41+ return int(match.group(1)) 42 43+ return -1 44 45-def CPUThreadCount(enable=True): 46- from .GlobalParameters import globalParameters 47 48+def CPUThreadCount(enable=True): 49 if not enable: 50 return 1 51- else: 52+ from .GlobalParameters import globalParameters 53+ 54+ # Priority order: 55+ # 1. Inherited from build system (CMAKE_BUILD_PARALLEL_LEVEL or MAKEFLAGS) 56+ # 2. Explicit --jobs flag 57+ # 3. Auto-detect 58+ inherited_limit = get_inherited_job_limit() 59+ cpuThreads = inherited_limit if inherited_limit > 0 else globalParameters["CpuThreads"] 60+ 61+ if cpuThreads < 1: 62 if os.name == "nt": 63- # Windows supports at most 61 workers because the scheduler uses 64- # WaitForMultipleObjects directly, which has the limit (the limit 65- # is actually 64, but some handles are needed for accounting). 66- cpu_count = min(os.cpu_count(), 61) 67+ cpuThreads = os.cpu_count() 68 else: 69- cpu_count = len(os.sched_getaffinity(0)) 70- cpuThreads = globalParameters["CpuThreads"] 71- if cpuThreads == -1: 72- return cpu_count 73+ cpuThreads = len(os.sched_getaffinity(0)) 74 75- return min(cpu_count, cpuThreads) 76+ if os.name == "nt": 77+ # Windows supports at most 61 workers because the scheduler uses 78+ # WaitForMultipleObjects directly, which has the limit (the limit 79+ # is actually 64, but some handles are needed for accounting). 80+ cpuThreads = min(cpuThreads, 61) 81+ return max(1, cpuThreads) 82 83 84 def pcallWithGlobalParamsMultiArg(f, args, newGlobalParameters): 85@@ -71,19 +86,22 @@ def pcallWithGlobalParamsSingleArg(f, arg, newGlobalParameters): 86 return f(arg) 87 88 89-def apply_print_exception(item, *args): 90- # print(item, args) 91+def OverwriteGlobalParameters(newGlobalParameters): 92+ from . import GlobalParameters 93+ 94+ GlobalParameters.globalParameters.clear() 95+ GlobalParameters.globalParameters.update(newGlobalParameters) 96+ 97+ 98+def worker_function(args, function, multiArg): 99+ """Worker function that executes in the pool process.""" 100 try: 101- if len(args) > 0: 102- func = item 103- args = args[0] 104- return func(*args) 105+ if multiArg: 106+ return function(*args) 107 else: 108- func, item = item 109- return func(item) 110+ return function(args) 111 except Exception: 112 import traceback 113- 114 traceback.print_exc() 115 raise 116 finally: 117@@ -98,154 +116,121 @@ def OverwriteGlobalParameters(newGlobalParameters): 118 GlobalParameters.globalParameters.update(newGlobalParameters) 119 120 121-def ProcessingPool(enable=True, maxTasksPerChild=None): 122- import multiprocessing 123- import multiprocessing.dummy 124- 125- threadCount = CPUThreadCount() 126- 127- if (not enable) or threadCount <= 1: 128- return multiprocessing.dummy.Pool(1) 129- 130- if multiprocessing.get_start_method() == "spawn": 131- from . import GlobalParameters 132- 133- return multiprocessing.Pool( 134- threadCount, 135- initializer=OverwriteGlobalParameters, 136- maxtasksperchild=maxTasksPerChild, 137- initargs=(GlobalParameters.globalParameters,), 138- ) 139- else: 140- return multiprocessing.Pool(threadCount, maxtasksperchild=maxTasksPerChild) 141+def progress_logger(iterable, total, message, min_log_interval=5.0): 142+ """ 143+ Generator that wraps an iterable and logs progress with time-based throttling. 144 145+ Only logs progress if at least min_log_interval seconds have passed since last log. 146+ Only prints completion message if task took >= min_log_interval seconds. 147 148-def ParallelMap(function, objects, message="", enable=True, method=None, maxTasksPerChild=None): 149+ Yields (index, item) tuples. 150 """ 151- Generally equivalent to list(map(function, objects)), possibly executing in parallel. 152- 153- message: A message describing the operation to be performed. 154- enable: May be set to false to disable parallelism. 155- method: A function which can fetch the mapping function from a processing pool object. 156- Leave blank to use .map(), other possiblities: 157- - `lambda x: x.starmap` - useful if `function` takes multiple parameters. 158- - `lambda x: x.imap` - lazy evaluation 159- - `lambda x: x.imap_unordered` - lazy evaluation, does not preserve order of return value. 160- """ 161- from .GlobalParameters import globalParameters 162+ start_time = time.time() 163+ last_log_time = start_time 164+ log_interval = 1 + (total // 100) 165 166- threadCount = CPUThreadCount(enable) 167- pool = ProcessingPool(enable, maxTasksPerChild) 168- 169- if threadCount <= 1 and globalParameters["ShowProgressBar"]: 170- # Provide a progress bar for single-threaded operation. 171- # This works for method=None, and for starmap. 172- mapFunc = map 173- if method is not None: 174- # itertools provides starmap which can fill in for pool.starmap. It provides imap on Python 2.7. 175- # If this works, we will use it, otherwise we will fallback to the "dummy" pool for single threaded 176- # operation. 177- try: 178- mapFunc = method(itertools) 179- except NameError: 180- mapFunc = None 181- 182- if mapFunc is not None: 183- return list(mapFunc(function, tqdm(objects, message))) 184- 185- mapFunc = pool.map 186- if method: 187- mapFunc = method(pool) 188- 189- objects = zip(itertools.repeat(function), objects) 190- function = apply_print_exception 191- 192- countMessage = "" 193- try: 194- countMessage = " for {} tasks".format(len(objects)) 195- except TypeError: 196- pass 197+ for idx, item in enumerate(iterable): 198+ if idx % log_interval == 0: 199+ current_time = time.time() 200+ if (current_time - last_log_time) >= min_log_interval: 201+ print(f"{message}\t{idx+1: 5d}/{total: 5d}") 202+ last_log_time = current_time 203+ yield idx, item 204 205- if message != "": 206- message += ": " 207+ elapsed = time.time() - start_time 208+ final_idx = idx + 1 if 'idx' in locals() else 0 209 210- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) 211- sys.stdout.flush() 212- currentTime = time.time() 213- rv = mapFunc(function, objects) 214- totalTime = time.time() - currentTime 215- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime)) 216- sys.stdout.flush() 217- pool.close() 218- return rv 219+ if elapsed >= min_log_interval or last_log_time > start_time: 220+ print(f"{message} done in {elapsed:.1f}s!\t{final_idx: 5d}/{total: 5d}") 221 222 223-def ParallelMapReturnAsGenerator(function, objects, message="", enable=True, multiArg=True): 224- from .GlobalParameters import globalParameters 225+def imap_with_progress(pool, func, iterable, total, message, chunksize): 226+ results = [] 227+ for _, result in progress_logger(pool.imap(func, iterable, chunksize=chunksize), total, message): 228+ results.append(result) 229+ return results 230 231- threadCount = CPUThreadCount(enable) 232- print("{0}Launching {1} threads...".format(message, threadCount)) 233 234- if threadCount <= 1 and globalParameters["ShowProgressBar"]: 235- # Provide a progress bar for single-threaded operation. 236- callFunc = lambda args: function(*args) if multiArg else lambda args: function(args) 237- return [callFunc(args) for args in tqdm(objects, message)] 238+def _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild): 239+ # separate fn because yield makes the entire fn a generator even if unreachable 240+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') 241 242- with concurrent.futures.ProcessPoolExecutor(max_workers=threadCount) as executor: 243- resultFutures = (executor.submit(function, *arg if multiArg else arg) for arg in objects) 244- for result in concurrent.futures.as_completed(resultFutures): 245- yield result.result() 246+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, 247+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: 248+ for _, result in progress_logger(pool.imap_unordered(worker, objects, chunksize=chunksize), objLen, message): 249+ yield result 250 251 252 def ParallelMap2( 253- function, objects, message="", enable=True, multiArg=True, return_as="list", procs=None 254+ function: Callable, 255+ objects: Any, 256+ message: str = "", 257+ enable: bool = True, 258+ multiArg: bool = True, 259+ minChunkSize: int = 1, 260+ maxWorkers: int = -1, 261+ maxtasksperchild: int = 1024, 262+ return_as: str = "list" 263 ): 264+ """Executes a function over a list of objects in parallel or sequentially. 265+ 266+ This function is generally equivalent to ``list(map(function, objects))``. However, it provides 267+ additional functionality to run in parallel, depending on the 'enable' flag and available CPU 268+ threads. 269+ 270+ Args: 271+ function: The function to apply to each item in 'objects'. If 'multiArg' is True, 'function' 272+ should accept multiple arguments. 273+ objects: An iterable of objects to be processed by 'function'. If 'multiArg' is True, each 274+ item in 'objects' should be an iterable of arguments for 'function'. 275+ message: Optional; a message describing the operation. Default is an empty string. 276+ enable: Optional; if False, disables parallel execution and runs sequentially. Default is True. 277+ multiArg: Optional; if True, treats each item in 'objects' as multiple arguments for 278+ 'function'. Default is True. 279+ return_as: Optional; "list" (default) or "generator_unordered" for streaming results 280+ 281+ Returns: 282+ A list or generator containing the results of applying **function** to each item in **objects**. 283 """ 284- Generally equivalent to list(map(function, objects)), possibly executing in parallel. 285+ from .GlobalParameters import globalParameters 286 287- message: A message describing the operation to be performed. 288- enable: May be set to false to disable parallelism. 289- multiArg: True if objects represent multiple arguments 290- (differentiates multi args vs single collection arg) 291- """ 292- if return_as in ("generator", "generator_unordered") and not joblibParallelSupportsGenerator(): 293- return ParallelMapReturnAsGenerator(function, objects, message, enable, multiArg) 294+ threadCount = CPUThreadCount(enable) 295 296- from .GlobalParameters import globalParameters 297+ if not hasattr(objects, "__len__"): 298+ objects = list(objects) 299 300- threadCount = procs if procs else CPUThreadCount(enable) 301+ objLen = len(objects) 302+ if objLen == 0: 303+ return [] if return_as == "list" else iter([]) 304 305- threadCount = CPUThreadCount(enable) 306+ f = (lambda x: function(*x)) if multiArg else function 307+ if objLen == 1: 308+ print(f"{message}: (1 task)") 309+ result = [f(x) for x in objects] 310+ return result if return_as == "list" else iter(result) 311 312- if threadCount <= 1 and globalParameters["ShowProgressBar"]: 313- # Provide a progress bar for single-threaded operation. 314- return [function(*args) if multiArg else function(args) for args in tqdm(objects, message)] 315+ extra_message = ( 316+ f": {threadCount} thread(s)" + f", {objLen} tasks" 317+ if objLen 318+ else "" 319+ ) 320 321- countMessage = "" 322- try: 323- countMessage = " for {} tasks".format(len(objects)) 324- except TypeError: 325- pass 326- 327- if message != "": 328- message += ": " 329- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) 330- sys.stdout.flush() 331- currentTime = time.time() 332- 333- pcall = pcallWithGlobalParamsMultiArg if multiArg else pcallWithGlobalParamsSingleArg 334- pargs = zip(objects, itertools.repeat(globalParameters)) 335- 336- if joblibParallelSupportsGenerator(): 337- rv = Parallel(n_jobs=threadCount, timeout=99999, return_as=return_as)( 338- delayed(pcall)(function, a, params) for a, params in pargs 339- ) 340+ print(f"ParallelMap {message}{extra_message}") 341+ 342+ if threadCount <= 1: 343+ result = [f(x) for x in objects] 344+ return result if return_as == "list" else iter(result) 345+ 346+ if maxWorkers > 0: 347+ threadCount = min(maxWorkers, threadCount) 348+ 349+ chunksize = max(minChunkSize, objLen // 2000) 350+ worker = partial(worker_function, function=function, multiArg=multiArg) 351+ if return_as == "generator_unordered": 352+ # yield results as they complete without buffering 353+ return _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild) 354 else: 355- rv = Parallel(n_jobs=threadCount, timeout=99999)( 356- delayed(pcall)(function, a, params) for a, params in pargs 357- ) 358- 359- totalTime = time.time() - currentTime 360- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime)) 361- sys.stdout.flush() 362- return rv 363+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') 364+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, 365+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: 366+ return list(imap_with_progress(pool, worker, objects, objLen, message, chunksize)) 367diff --git a/tensilelite/Tensile/CustomKernels.py b/tensilelite/Tensile/CustomKernels.py 368index ffceb636f5..127b3386a1 100644 369--- a/tensilelite/Tensile/CustomKernels.py 370+++ b/tensilelite/Tensile/CustomKernels.py 371@@ -24,7 +24,9 @@ 372 373 from . import CUSTOM_KERNEL_PATH 374 from Tensile.Common.ValidParameters import checkParametersAreValid, validParameters, newMIValidParameters 375+from Tensile.CustomYamlLoader import DEFAULT_YAML_LOADER 376 377+from functools import lru_cache 378 import yaml 379 380 import os 381@@ -58,10 +60,13 @@ def getCustomKernelConfigAndAssembly(name, directory=CUSTOM_KERNEL_PATH): 382 383 return (config, assembly) 384 385+# getCustomKernelConfig will get called repeatedly on the same file 386+# 20x logic loading speedup for aquavanjaram_Cijk_Ailk_Bljk_F8NH_HHS_BH_Bias_HAS_SAB_SAV_freesize_custom_GSUs 387+@lru_cache 388 def readCustomKernelConfig(name, directory=CUSTOM_KERNEL_PATH): 389 rawConfig, _ = getCustomKernelConfigAndAssembly(name, directory) 390 try: 391- return yaml.safe_load(rawConfig)["custom.config"] 392+ return yaml.load(rawConfig, Loader=DEFAULT_YAML_LOADER)["custom.config"] 393 except yaml.scanner.ScannerError as e: 394 raise RuntimeError("Failed to read configuration for custom kernel: {0}\nDetails:\n{1}".format(name, e)) 395 396diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py 397index e03f456fbe..ed7510f2ce 100644 398--- a/tensilelite/Tensile/CustomYamlLoader.py 399+++ b/tensilelite/Tensile/CustomYamlLoader.py 400@@ -1,6 +1,7 @@ 401 # Copyright © Advanced Micro Devices, Inc., or its affiliates. 402 # SPDX-License-Identifier: MIT 403 404+import sys 405 import yaml 406 from pathlib import Path 407 408@@ -70,7 +71,7 @@ def parse_scalar(loader: yaml.Loader): 409 elif is_float(value_lower): 410 return float(value_lower) 411 412- return value 413+ return sys.intern(value) 414 415 def load_yaml_stream(yaml_path: Path, loader_type: yaml.Loader): 416 with open(yaml_path, 'r') as f: 417diff --git a/tensilelite/Tensile/TensileCreateLibrary/Run.py b/tensilelite/Tensile/TensileCreateLibrary/Run.py 418index 22d19851a3..348068b3bf 100644 419--- a/tensilelite/Tensile/TensileCreateLibrary/Run.py 420+++ b/tensilelite/Tensile/TensileCreateLibrary/Run.py 421@@ -26,8 +26,10 @@ import rocisa 422 423 import functools 424 import glob 425+import gc 426 import itertools 427 import os 428+import resource 429 import shutil 430 from pathlib import Path 431 from timeit import default_timer as timer 432@@ -78,6 +80,25 @@ from Tensile.Utilities.Decorators.Timing import timing 433 from .ParseArguments import parseArguments 434 435 436+def getMemoryUsage(): 437+ """Get peak and current memory usage in MB.""" 438+ rusage = resource.getrusage(resource.RUSAGE_SELF) 439+ peak_memory_mb = rusage.ru_maxrss / 1024 # KB to MB on Linux 440+ 441+ # Get current memory from /proc/self/status 442+ current_memory_mb = 0 443+ try: 444+ with open('/proc/self/status') as f: 445+ for line in f: 446+ if line.startswith('VmRSS:'): 447+ current_memory_mb = int(line.split()[1]) / 1024 # KB to MB 448+ break 449+ except: 450+ current_memory_mb = peak_memory_mb # Fallback 451+ 452+ return (peak_memory_mb, current_memory_mb) 453+ 454+ 455 class KernelCodeGenResult(NamedTuple): 456 err: int 457 src: str 458@@ -115,6 +136,29 @@ def processKernelSource(kernelWriterAssembly, data, outOptions, splitGSU, kernel 459 ) 460 461 462+def processAndAssembleKernelTCL(kernelWriterAssembly, rocisa_data, outOptions, splitGSU, kernel, assemblyTmpPath, assembler): 463+ """ 464+ Pipeline function for TCL mode that: 465+ 1. Generates kernel source 466+ 2. Writes .s file to disk 467+ 3. Assembles to .o file 468+ 4. Deletes .s file 469+ """ 470+ result = processKernelSource(kernelWriterAssembly, rocisa_data, outOptions, splitGSU, kernel) 471+ return writeAndAssembleKernel(result, assemblyTmpPath, assembler) 472+ 473+ 474+def writeMasterSolutionLibrary(name_lib_tuple, newLibraryDir, splitGSU, libraryFormat): 475+ """ 476+ Write a master solution library to disk. 477+ Module-level function to support multiprocessing. 478+ """ 479+ name, lib = name_lib_tuple 480+ filename = os.path.join(newLibraryDir, name) 481+ lib.applyNaming(splitGSU) 482+ LibraryIO.write(filename, state(lib), libraryFormat) 483+ 484+ 485 def removeInvalidSolutionsAndKernels(results, kernels, solutions, errorTolerant, printLevel: bool, splitGSU: bool): 486 removeKernels = [] 487 removeKernelNames = [] 488@@ -189,6 +233,24 @@ def writeAssembly(asmPath: Union[Path, str], result: KernelCodeGenResult): 489 return path, isa, wfsize, minResult 490 491 492+def writeAndAssembleKernel(result: KernelCodeGenResult, asmPath: Union[Path, str], assembler): 493+ """Write assembly file and immediately assemble it to .o file""" 494+ if result.err: 495+ printExit(f"Failed to build kernel {result.name} because it has error code {result.err}") 496+ 497+ path = Path(asmPath) / f"{result.name}.s" 498+ with open(path, "w", encoding="utf-8") as f: 499+ f.write(result.src) 500+ 501+ # Assemble .s -> .o 502+ assembler(isaToGfx(result.isa), result.wavefrontSize, str(path), str(path.with_suffix(".o"))) 503+ 504+ # Delete assembly file immediately to save disk space 505+ path.unlink() 506+ 507+ return KernelMinResult(result.err, result.cuoccupancy, result.pgr, result.mathclk) 508+ 509+ 510 def writeHelpers( 511 outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H 512 ): 513@@ -272,14 +334,15 @@ def writeSolutionsAndKernels( 514 numAsmKernels = len(asmKernels) 515 numKernels = len(asmKernels) 516 assert numKernels == numAsmKernels, "Only assembly kernels are supported in TensileLite" 517- asmIter = zip( 518- itertools.repeat(kernelWriterAssembly), 519- itertools.repeat(rocisa.rocIsa.getInstance().getData()), 520- itertools.repeat(outOptions), 521- itertools.repeat(splitGSU), 522- asmKernels 523+ 524+ processKernelFn = functools.partial( 525+ processKernelSource, 526+ kernelWriterAssembly=kernelWriterAssembly, 527+ data=rocisa.rocIsa.getInstance().getData(), 528+ outOptions=outOptions, 529+ splitGSU=splitGSU 530 ) 531- asmResults = ParallelMap2(processKernelSource, asmIter, "Generating assembly kernels", return_as="list") 532+ asmResults = ParallelMap2(processKernelFn, asmKernels, "Generating assembly kernels", return_as="list", multiArg=False) 533 removeInvalidSolutionsAndKernels( 534 asmResults, asmKernels, solutions, errorTolerant, getVerbosity(), splitGSU 535 ) 536@@ -287,19 +350,21 @@ def writeSolutionsAndKernels( 537 asmResults, asmKernels, solutions, splitGSU 538 ) 539 540- def assemble(ret): 541- p, isa, wavefrontsize, _ = ret 542- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o"))) 543- 544- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) 545- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) 546+ # Use functools.partial to bind assemblyTmpPath and assembler 547+ writeAndAssembleFn = functools.partial( 548+ writeAndAssembleKernel, 549+ asmPath=assemblyTmpPath, 550+ assembler=asmToolchain.assembler 551+ ) 552 ret = ParallelMap2( 553- compose(assemble, unaryWriteAssembly), 554+ writeAndAssembleFn, 555 asmResults, 556 "Writing assembly kernels", 557 return_as="list", 558 multiArg=False, 559 ) 560+ del asmResults 561+ gc.collect() 562 563 writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H) 564 srcKernelFile = Path(outputPath) / "Kernels.cpp" 565@@ -376,40 +441,35 @@ def writeSolutionsAndKernelsTCL( 566 567 uniqueAsmKernels = [k for k in asmKernels if not k.duplicate] 568 569- def assemble(ret, removeTemporaries: bool): 570- asmPath, isa, wavefrontsize, result = ret 571- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(asmPath), str(asmPath.with_suffix(".o"))) 572- if removeTemporaries: 573- asmPath.unlink() 574- return result 575- 576- unaryAssemble = functools.partial(assemble, removeTemporaries=removeTemporaries) 577- 578 outOptions = rocisa.rocIsa.getInstance().getOutputOptions() 579 outOptions.outputNoComment = not disableAsmComments 580 581- unaryProcessKernelSource = functools.partial( 582- processKernelSource, 583+ processKernelFn = functools.partial( 584+ processAndAssembleKernelTCL, 585 kernelWriterAssembly, 586 rocisa.rocIsa.getInstance().getData(), 587 outOptions, 588 splitGSU, 589+ assemblyTmpPath=assemblyTmpPath, 590+ assembler=asmToolchain.assembler 591 ) 592 593- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) 594- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) 595- ret = ParallelMap2( 596- compose(unaryAssemble, unaryWriteAssembly, unaryProcessKernelSource), 597+ results = ParallelMap2( 598+ processKernelFn, 599 uniqueAsmKernels, 600 "Generating assembly kernels", 601 multiArg=False, 602 return_as="list" 603 ) 604+ del processKernelFn 605+ gc.collect() 606+ 607 passPostKernelInfoToSolution( 608- ret, uniqueAsmKernels, solutions, splitGSU 609+ results, uniqueAsmKernels, solutions, splitGSU 610 ) 611- # result.src is very large so let garbage collector know to clean up 612- del ret 613+ del results 614+ gc.collect() 615+ 616 buildAssemblyCodeObjectFiles( 617 asmToolchain.linker, 618 asmToolchain.bundler, 619@@ -508,6 +568,15 @@ def generateKernelHelperObjects(solutions: List[Solution], cxxCompiler: str, isa 620 return sorted(khos, key=sortByEnum, reverse=True) # Ensure that we write Enum kernel helpers are first in list 621 622 623+def libraryIter(lib: MasterSolutionLibrary): 624+ if len(lib.solutions): 625+ for i, s in enumerate(lib.solutions.items()): 626+ yield (i, *s) 627+ else: 628+ for _, lazyLib in lib.lazyLibraries.items(): 629+ yield from libraryIter(lazyLib) 630+ 631+ 632 @timing 633 def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInfoMap): 634 635@@ -523,26 +592,23 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf 636 printSolutionRejectionReason = True 637 printIndexAssignmentInfo = False 638 639- fIter = zip( 640- logicFiles, 641- itertools.repeat(assembler), 642- itertools.repeat(splitGSU), 643- itertools.repeat(printSolutionRejectionReason), 644- itertools.repeat(printIndexAssignmentInfo), 645- itertools.repeat(isaInfoMap), 646- itertools.repeat(args["LazyLibraryLoading"]), 647+ parseLogicFn = functools.partial( 648+ LibraryIO.parseLibraryLogicFile, 649+ assembler=assembler, 650+ splitGSU=splitGSU, 651+ printSolutionRejectionReason=printSolutionRejectionReason, 652+ printIndexAssignmentInfo=printIndexAssignmentInfo, 653+ isaInfoMap=isaInfoMap, 654+ lazyLibraryLoading=args["LazyLibraryLoading"] 655 ) 656 657- def libraryIter(lib: MasterSolutionLibrary): 658- if len(lib.solutions): 659- for i, s in enumerate(lib.solutions.items()): 660- yield (i, *s) 661- else: 662- for _, lazyLib in lib.lazyLibraries.items(): 663- yield from libraryIter(lazyLib) 664- 665 for library in ParallelMap2( 666- LibraryIO.parseLibraryLogicFile, fIter, "Loading Logics...", return_as="generator_unordered" 667+ parseLogicFn, logicFiles, "Loading Logics...", 668+ return_as="generator_unordered", 669+ minChunkSize=24, 670+ maxWorkers=32, 671+ maxtasksperchild=1, 672+ multiArg=False, 673 ): 674 _, architectureName, _, _, _, newLibrary = library 675 676@@ -554,6 +620,9 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf 677 else: 678 masterLibraries[architectureName] = newLibrary 679 masterLibraries[architectureName].version = args["CodeObjectVersion"] 680+ del library, newLibrary 681+ 682+ gc.collect() 683 684 # Sort masterLibraries to make global soln index values deterministic 685 solnReIndex = 0 686@@ -751,6 +820,9 @@ def run(): 687 ) 688 stop_wsk = timer() 689 print(f"Time to generate kernels (s): {(stop_wsk-start_wsk):3.2f}") 690+ numKernelHelperObjs = len(kernelHelperObjs) 691+ del kernelWriterAssembly, kernelHelperObjs 692+ gc.collect() 693 694 archs = [ # is this really different than the other archs above? 695 isaToGfx(arch) 696@@ -768,13 +840,10 @@ def run(): 697 if kName not in solDict: 698 solDict["%s"%kName] = kernel 699 700- def writeMsl(name, lib): 701- filename = os.path.join(newLibraryDir, name) 702- lib.applyNaming(splitGSU) 703- LibraryIO.write(filename, state(lib), arguments["LibraryFormat"]) 704- 705 filename = os.path.join(newLibraryDir, "TensileLiteLibrary_lazy_Mapping") 706 LibraryIO.write(filename, libraryMapping, "msgpack") 707+ del libraryMapping 708+ gc.collect() 709 710 start_msl = timer() 711 for archName, newMasterLibrary in masterLibraries.items(): 712@@ -791,12 +860,22 @@ def run(): 713 kName = getKeyNoInternalArgs(s.originalSolution, splitGSU) 714 s.sizeMapping.CUOccupancy = solDict["%s"%kName]["CUOccupancy"] 715 716- ParallelMap2(writeMsl, 717+ writeFn = functools.partial( 718+ writeMasterSolutionLibrary, 719+ newLibraryDir=newLibraryDir, 720+ splitGSU=splitGSU, 721+ libraryFormat=arguments["LibraryFormat"] 722+ ) 723+ 724+ ParallelMap2(writeFn, 725 newMasterLibrary.lazyLibraries.items(), 726 "Writing master solution libraries", 727+ multiArg=False, 728 return_as="list") 729 stop_msl = timer() 730 print(f"Time to write master solution libraries (s): {(stop_msl-start_msl):3.2f}") 731+ del masterLibraries, solutions, kernels, solDict 732+ gc.collect() 733 734 if not arguments["KeepBuildTmp"]: 735 buildTmp = Path(arguments["OutputPath"]).parent / "library" / "build_tmp" 736@@ -813,8 +892,11 @@ def run(): 737 print("") 738 739 stop = timer() 740+ peak_memory_mb, current_memory_mb = getMemoryUsage() 741 742 print(f"Total time (s): {(stop-start):3.2f}") 743 print(f"Total kernels processed: {numKernels}") 744 print(f"Kernels processed per second: {(numKernels/(stop-start)):3.2f}") 745- print(f"KernelHelperObjs: {len(kernelHelperObjs)}") 746+ print(f"KernelHelperObjs: {numKernelHelperObjs}") 747+ print(f"Peak memory usage (MB): {peak_memory_mb:,.1f}") 748+ print(f"Current memory usage (MB): {current_memory_mb:,.1f}") 749diff --git a/tensilelite/Tensile/TensileMergeLibrary.py b/tensilelite/Tensile/TensileMergeLibrary.py 750index e33c617b6f..ba163e9918 100644 751--- a/tensilelite/Tensile/TensileMergeLibrary.py 752+++ b/tensilelite/Tensile/TensileMergeLibrary.py 753@@ -303,8 +303,7 @@ def avoidRegressions(originalDir, incrementalDir, outputPath, forceMerge, noEff= 754 logicsFiles[origFile] = origFile 755 logicsFiles[incFile] = incFile 756 757- iters = zip(logicsFiles.keys()) 758- logicsList = ParallelMap2(loadData, iters, "Loading Logics...", return_as="list") 759+ logicsList = ParallelMap2(loadData, logicsFiles.keys(), "Loading Logics...", return_as="list", multiArg=False) 760 logicsDict = {} 761 for i, _ in enumerate(logicsList): 762 logicsDict[logicsList[i][0]] = logicsList[i][1] 763diff --git a/tensilelite/Tensile/TensileUpdateLibrary.py b/tensilelite/Tensile/TensileUpdateLibrary.py 764index 5ff265d0ed..c1803a6349 100644 765--- a/tensilelite/Tensile/TensileUpdateLibrary.py 766+++ b/tensilelite/Tensile/TensileUpdateLibrary.py 767@@ -26,7 +26,7 @@ from . import LibraryIO 768 from .Tensile import addCommonArguments, argUpdatedGlobalParameters 769 770 from .Common import assignGlobalParameters, print1, restoreDefaultGlobalParameters, HR, \ 771- globalParameters, architectureMap, ensurePath, ParallelMap, __version__ 772+ globalParameters, architectureMap, ensurePath, ParallelMap2, __version__ 773 774 import argparse 775 import copy 776@@ -149,7 +149,7 @@ def TensileUpdateLibrary(userArgs): 777 for logicFile in logicFiles: 778 print("# %s" % logicFile) 779 fIter = zip(logicFiles, itertools.repeat(args.logic_path), itertools.repeat(outputPath)) 780- libraries = ParallelMap(UpdateLogic, fIter, "Updating logic files", method=lambda x: x.starmap) 781+ libraries = ParallelMap2(UpdateLogic, fIter, "Updating logic files", multiArg=True, return_as="list") 782 783 784 def main(): 785diff --git a/tensilelite/Tensile/Toolchain/Assembly.py b/tensilelite/Tensile/Toolchain/Assembly.py 786index a8b91e8d62..265e1d532c 100644 787--- a/tensilelite/Tensile/Toolchain/Assembly.py 788+++ b/tensilelite/Tensile/Toolchain/Assembly.py 789@@ -30,7 +30,7 @@ import subprocess 790 from pathlib import Path 791 from typing import List, Union, NamedTuple 792 793-from Tensile.Common import print2 794+from Tensile.Common import print1, print2 795 from Tensile.Common.Architectures import isaToGfx 796 from ..SolutionStructs import Solution 797 798@@ -92,8 +92,26 @@ def buildAssemblyCodeObjectFiles( 799 if coName: 800 coFileMap[asmDir / (coName + extCoRaw)].add(str(asmDir / (kernel["BaseName"] + extObj))) 801 802+ # Build reference count map for .o files to handle shared object files 803+ # (.o files from kernels marked .duplicate in TensileCreateLibrary) 804+ objFileRefCount = collections.Counter() 805+ for coFileRaw, objFiles in coFileMap.items(): 806+ for objFile in objFiles: 807+ objFileRefCount[objFile] += 1 808+ 809+ sharedObjFiles = {objFile: count for objFile, count in objFileRefCount.items() if count > 1} 810+ if sharedObjFiles: 811+ print1(f"Found {len(sharedObjFiles)} .o files shared across multiple code objects:") 812+ 813 for coFileRaw, objFiles in coFileMap.items(): 814 linker(objFiles, str(coFileRaw)) 815+ 816+ # Delete .o files after linking once usage count reaches 0 817+ for objFile in objFiles: 818+ objFileRefCount[objFile] -= 1 819+ if objFileRefCount[objFile] == 0: 820+ Path(objFile).unlink() 821+ 822 coFile = destDir / coFileRaw.name.replace(extCoRaw, extCo) 823 if compress: 824 bundler.compress(str(coFileRaw), str(coFile), gfx) 825diff --git a/tensilelite/Tensile/Toolchain/Component.py b/tensilelite/Tensile/Toolchain/Component.py 826index 67fa35e2d8..dde83af4c3 100644 827--- a/tensilelite/Tensile/Toolchain/Component.py 828+++ b/tensilelite/Tensile/Toolchain/Component.py 829@@ -355,6 +355,7 @@ class Linker(Component): 830 when invoking the linker, LLVM allows the provision of arguments via a "response file" 831 Reference: https://llvm.org/docs/CommandLine.html#response-files 832 """ 833+ # FIXME: this prevents threading as clang_args.txt is overwritten 834 with open(Path.cwd() / "clang_args.txt", "wt") as file: 835 file.write(" ".join(srcPaths).replace('\\', '\\\\') if os_name == "nt" else " ".join(srcPaths)) 836 return [*(self.default_args), "-o", destPath, "@clang_args.txt"] 837diff --git a/tensilelite/requirements.txt b/tensilelite/requirements.txt 838index 60c4c11445..5c8fd66a88 100644 839--- a/tensilelite/requirements.txt 840+++ b/tensilelite/requirements.txt 841@@ -2,8 +2,6 @@ dataclasses; python_version == '3.6' 842 packaging 843 pyyaml 844 msgpack 845-joblib>=1.4.0; python_version >= '3.8' 846-joblib>=1.1.1; python_version < '3.8' 847 simplejson 848 ujson 849 orjson