hipBLASLt TensileLite patch series (memory and parallelism improvements)
github.com/ROCm/hipBLASLt (tensilelite/Tensile)
python
diff --git a/tensilelite/Tensile/SolutionStructs/Naming.py b/tensilelite/Tensile/SolutionStructs/Naming.py
index 4f220960db1d..99535e246650 100644
--- a/tensilelite/Tensile/SolutionStructs/Naming.py
+++ b/tensilelite/Tensile/SolutionStructs/Naming.py
@@ -105,7 +105,6 @@ def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInterna
     if splitGSU:
         state["GlobalSplitU"] = "M" if (state["GlobalSplitU"] > 1 or state["GlobalSplitU"] == -1) else state["GlobalSplitU"]
 
-
     requiredParametersTemp = set(requiredParameters.union(["GlobalSplitU"]))
 
     if ignoreInternalArgs:
diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py
index bab8c687509..e03f456fbec 100644
--- a/tensilelite/Tensile/CustomYamlLoader.py
+++ b/tensilelite/Tensile/CustomYamlLoader.py
@@ -1,3 +1,6 @@
+# Copyright © Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
 import yaml
 from pathlib import Path
 
Author: Luna Nova <git@lunnova.dev>
Date:   Sun Oct 12 11:52:10 2025 -0700

    [hipblaslt] intern strings to reduce duplicate memory for solution keys

diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py
index 685e69220c..9fdf38d8e5 100644
--- a/tensilelite/Tensile/CustomYamlLoader.py
+++ b/tensilelite/Tensile/CustomYamlLoader.py
@@ -1,6 +1,7 @@
 # Copyright © Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
+import sys
 import yaml
 from pathlib import Path
 
@@ -85,7 +86,7 @@ def parse_scalar(loader: yaml.Loader):
     if not evt.style:
         return None
 
-    return value
+    return sys.intern(value)
 
 def load_yaml_stream(yaml_path: Path, loader_type: yaml.Loader):
     with open(yaml_path, 'r') as f:

51diff --git a/tensilelite/Tensile/Common/Parallel.py b/tensilelite/Tensile/Common/Parallel.py
52index 1a2bf9e119..f46100c7b8 100644
53--- a/tensilelite/Tensile/Common/Parallel.py
54+++ b/tensilelite/Tensile/Common/Parallel.py
55@@ -22,43 +22,58 @@
56 #
57 ################################################################################
58
59-import concurrent.futures
60-import itertools
61+import multiprocessing
62 import os
63+import re
64 import sys
65 import time
66-
67-from joblib import Parallel, delayed
68+from functools import partial
69+from typing import Any, Callable
70
71 from .Utilities import tqdm
72
73
74-def joblibParallelSupportsGenerator():
75- import joblib
76- from packaging.version import Version
77+def get_inherited_job_limit() -> int:
78+ # 1. Check CMAKE_BUILD_PARALLEL_LEVEL (CMake 3.12+)
79+ if 'CMAKE_BUILD_PARALLEL_LEVEL' in os.environ:
80+ try:
81+ return int(os.environ['CMAKE_BUILD_PARALLEL_LEVEL'])
82+ except ValueError:
83+ pass
84
85- joblibVer = joblib.__version__
86- return Version(joblibVer) >= Version("1.4.0")
87+ # 2. Parse MAKEFLAGS for -jN
88+ makeflags = os.environ.get('MAKEFLAGS', '')
89+ match = re.search(r'-j\s*(\d+)', makeflags)
90+ if match:
91+ return int(match.group(1))
92
93+ return -1
94
95-def CPUThreadCount(enable=True):
96- from .GlobalParameters import globalParameters
97
98+def CPUThreadCount(enable=True):
99 if not enable:
100 return 1
101- else:
102+ from .GlobalParameters import globalParameters
103+
104+ # Priority order:
105+ # 1. Inherited from build system (CMAKE_BUILD_PARALLEL_LEVEL or MAKEFLAGS)
106+ # 2. Explicit --jobs flag
107+ # 3. Auto-detect
108+ inherited_limit = get_inherited_job_limit()
109+ cpuThreads = inherited_limit if inherited_limit > 0 else globalParameters["CpuThreads"]
110+
111+ if cpuThreads < 1:
112 if os.name == "nt":
113- # Windows supports at most 61 workers because the scheduler uses
114- # WaitForMultipleObjects directly, which has the limit (the limit
115- # is actually 64, but some handles are needed for accounting).
116- cpu_count = min(os.cpu_count(), 61)
117+ cpuThreads = os.cpu_count()
118 else:
119- cpu_count = len(os.sched_getaffinity(0))
120- cpuThreads = globalParameters["CpuThreads"]
121- if cpuThreads == -1:
122- return cpu_count
123+ cpuThreads = len(os.sched_getaffinity(0))
124
125- return min(cpu_count, cpuThreads)
126+ if os.name == "nt":
127+ # Windows supports at most 61 workers because the scheduler uses
128+ # WaitForMultipleObjects directly, which has the limit (the limit
129+ # is actually 64, but some handles are needed for accounting).
130+ cpuThreads = min(cpuThreads, 61)
131+ return max(1, cpuThreads)
132
133
134 def pcallWithGlobalParamsMultiArg(f, args, newGlobalParameters):
135@@ -71,19 +86,22 @@ def pcallWithGlobalParamsSingleArg(f, arg, newGlobalParameters):
136 return f(arg)
137
138
139-def apply_print_exception(item, *args):
140- # print(item, args)
141+def OverwriteGlobalParameters(newGlobalParameters):
142+ from . import GlobalParameters
143+
144+ GlobalParameters.globalParameters.clear()
145+ GlobalParameters.globalParameters.update(newGlobalParameters)
146+
147+
148+def worker_function(args, function, multiArg):
149+ """Worker function that executes in the pool process."""
150 try:
151- if len(args) > 0:
152- func = item
153- args = args[0]
154- return func(*args)
155+ if multiArg:
156+ return function(*args)
157 else:
158- func, item = item
159- return func(item)
160+ return function(args)
161 except Exception:
162 import traceback
163-
164 traceback.print_exc()
165 raise
166 finally:
167@@ -98,154 +116,121 @@ def OverwriteGlobalParameters(newGlobalParameters):
168 GlobalParameters.globalParameters.update(newGlobalParameters)
169
170
171-def ProcessingPool(enable=True, maxTasksPerChild=None):
172- import multiprocessing
173- import multiprocessing.dummy
174-
175- threadCount = CPUThreadCount()
176-
177- if (not enable) or threadCount <= 1:
178- return multiprocessing.dummy.Pool(1)
179-
180- if multiprocessing.get_start_method() == "spawn":
181- from . import GlobalParameters
182-
183- return multiprocessing.Pool(
184- threadCount,
185- initializer=OverwriteGlobalParameters,
186- maxtasksperchild=maxTasksPerChild,
187- initargs=(GlobalParameters.globalParameters,),
188- )
189- else:
190- return multiprocessing.Pool(threadCount, maxtasksperchild=maxTasksPerChild)
191+def progress_logger(iterable, total, message, min_log_interval=5.0):
192+ """
193+ Generator that wraps an iterable and logs progress with time-based throttling.
194
195+ Only logs progress if at least min_log_interval seconds have passed since last log.
196+ Only prints completion message if task took >= min_log_interval seconds.
197
198-def ParallelMap(function, objects, message="", enable=True, method=None, maxTasksPerChild=None):
199+ Yields (index, item) tuples.
200 """
201- Generally equivalent to list(map(function, objects)), possibly executing in parallel.
202-
203- message: A message describing the operation to be performed.
204- enable: May be set to false to disable parallelism.
205- method: A function which can fetch the mapping function from a processing pool object.
206- Leave blank to use .map(), other possiblities:
207- - `lambda x: x.starmap` - useful if `function` takes multiple parameters.
208- - `lambda x: x.imap` - lazy evaluation
209- - `lambda x: x.imap_unordered` - lazy evaluation, does not preserve order of return value.
210- """
211- from .GlobalParameters import globalParameters
212+ start_time = time.time()
213+ last_log_time = start_time
214+ log_interval = 1 + (total // 100)
215
216- threadCount = CPUThreadCount(enable)
217- pool = ProcessingPool(enable, maxTasksPerChild)
218-
219- if threadCount <= 1 and globalParameters["ShowProgressBar"]:
220- # Provide a progress bar for single-threaded operation.
221- # This works for method=None, and for starmap.
222- mapFunc = map
223- if method is not None:
224- # itertools provides starmap which can fill in for pool.starmap. It provides imap on Python 2.7.
225- # If this works, we will use it, otherwise we will fallback to the "dummy" pool for single threaded
226- # operation.
227- try:
228- mapFunc = method(itertools)
229- except NameError:
230- mapFunc = None
231-
232- if mapFunc is not None:
233- return list(mapFunc(function, tqdm(objects, message)))
234-
235- mapFunc = pool.map
236- if method:
237- mapFunc = method(pool)
238-
239- objects = zip(itertools.repeat(function), objects)
240- function = apply_print_exception
241-
242- countMessage = ""
243- try:
244- countMessage = " for {} tasks".format(len(objects))
245- except TypeError:
246- pass
247+ for idx, item in enumerate(iterable):
248+ if idx % log_interval == 0:
249+ current_time = time.time()
250+ if (current_time - last_log_time) >= min_log_interval:
251+ print(f"{message}\t{idx+1: 5d}/{total: 5d}")
252+ last_log_time = current_time
253+ yield idx, item
254
255- if message != "":
256- message += ": "
257+ elapsed = time.time() - start_time
258+ final_idx = idx + 1 if 'idx' in locals() else 0
259
260- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage))
261- sys.stdout.flush()
262- currentTime = time.time()
263- rv = mapFunc(function, objects)
264- totalTime = time.time() - currentTime
265- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime))
266- sys.stdout.flush()
267- pool.close()
268- return rv
269+ if elapsed >= min_log_interval or last_log_time > start_time:
270+ print(f"{message} done in {elapsed:.1f}s!\t{final_idx: 5d}/{total: 5d}")
271
272
273-def ParallelMapReturnAsGenerator(function, objects, message="", enable=True, multiArg=True):
274- from .GlobalParameters import globalParameters
275+def imap_with_progress(pool, func, iterable, total, message, chunksize):
276+ results = []
277+ for _, result in progress_logger(pool.imap(func, iterable, chunksize=chunksize), total, message):
278+ results.append(result)
279+ return results
280
281- threadCount = CPUThreadCount(enable)
282- print("{0}Launching {1} threads...".format(message, threadCount))
283
284- if threadCount <= 1 and globalParameters["ShowProgressBar"]:
285- # Provide a progress bar for single-threaded operation.
286- callFunc = lambda args: function(*args) if multiArg else lambda args: function(args)
287- return [callFunc(args) for args in tqdm(objects, message)]
288+def _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild):
289+ # separate fn because yield makes the entire fn a generator even if unreachable
290+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn')
291
292- with concurrent.futures.ProcessPoolExecutor(max_workers=threadCount) as executor:
293- resultFutures = (executor.submit(function, *arg if multiArg else arg) for arg in objects)
294- for result in concurrent.futures.as_completed(resultFutures):
295- yield result.result()
296+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild,
297+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool:
298+ for _, result in progress_logger(pool.imap_unordered(worker, objects, chunksize=chunksize), objLen, message):
299+ yield result
300
301
302 def ParallelMap2(
303- function, objects, message="", enable=True, multiArg=True, return_as="list", procs=None
304+ function: Callable,
305+ objects: Any,
306+ message: str = "",
307+ enable: bool = True,
308+ multiArg: bool = True,
309+ minChunkSize: int = 1,
310+ maxWorkers: int = -1,
311+ maxtasksperchild: int = 1024,
312+ return_as: str = "list"
313 ):
314+ """Executes a function over a list of objects in parallel or sequentially.
315+
316+ This function is generally equivalent to ``list(map(function, objects))``. However, it provides
317+ additional functionality to run in parallel, depending on the 'enable' flag and available CPU
318+ threads.
319+
320+ Args:
321+ function: The function to apply to each item in 'objects'. If 'multiArg' is True, 'function'
322+ should accept multiple arguments.
323+ objects: An iterable of objects to be processed by 'function'. If 'multiArg' is True, each
324+ item in 'objects' should be an iterable of arguments for 'function'.
325+ message: Optional; a message describing the operation. Default is an empty string.
326+ enable: Optional; if False, disables parallel execution and runs sequentially. Default is True.
327+ multiArg: Optional; if True, treats each item in 'objects' as multiple arguments for
328+ 'function'. Default is True.
329+ return_as: Optional; "list" (default) or "generator_unordered" for streaming results
330+
331+ Returns:
332+ A list or generator containing the results of applying **function** to each item in **objects**.
333 """
334- Generally equivalent to list(map(function, objects)), possibly executing in parallel.
335+ from .GlobalParameters import globalParameters
336
337- message: A message describing the operation to be performed.
338- enable: May be set to false to disable parallelism.
339- multiArg: True if objects represent multiple arguments
340- (differentiates multi args vs single collection arg)
341- """
342- if return_as in ("generator", "generator_unordered") and not joblibParallelSupportsGenerator():
343- return ParallelMapReturnAsGenerator(function, objects, message, enable, multiArg)
344+ threadCount = CPUThreadCount(enable)
345
346- from .GlobalParameters import globalParameters
347+ if not hasattr(objects, "__len__"):
348+ objects = list(objects)
349
350- threadCount = procs if procs else CPUThreadCount(enable)
351+ objLen = len(objects)
352+ if objLen == 0:
353+ return [] if return_as == "list" else iter([])
354
355- threadCount = CPUThreadCount(enable)
356+ f = (lambda x: function(*x)) if multiArg else function
357+ if objLen == 1:
358+ print(f"{message}: (1 task)")
359+ result = [f(x) for x in objects]
360+ return result if return_as == "list" else iter(result)
361
362- if threadCount <= 1 and globalParameters["ShowProgressBar"]:
363- # Provide a progress bar for single-threaded operation.
364- return [function(*args) if multiArg else function(args) for args in tqdm(objects, message)]
365+ extra_message = (
366+ f": {threadCount} thread(s)" + f", {objLen} tasks"
367+ if objLen
368+ else ""
369+ )
370
371- countMessage = ""
372- try:
373- countMessage = " for {} tasks".format(len(objects))
374- except TypeError:
375- pass
376-
377- if message != "":
378- message += ": "
379- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage))
380- sys.stdout.flush()
381- currentTime = time.time()
382-
383- pcall = pcallWithGlobalParamsMultiArg if multiArg else pcallWithGlobalParamsSingleArg
384- pargs = zip(objects, itertools.repeat(globalParameters))
385-
386- if joblibParallelSupportsGenerator():
387- rv = Parallel(n_jobs=threadCount, timeout=99999, return_as=return_as)(
388- delayed(pcall)(function, a, params) for a, params in pargs
389- )
390+ print(f"ParallelMap {message}{extra_message}")
391+
392+ if threadCount <= 1:
393+ result = [f(x) for x in objects]
394+ return result if return_as == "list" else iter(result)
395+
396+ if maxWorkers > 0:
397+ threadCount = min(maxWorkers, threadCount)
398+
399+ chunksize = max(minChunkSize, objLen // 2000)
400+ worker = partial(worker_function, function=function, multiArg=multiArg)
401+ if return_as == "generator_unordered":
402+ # yield results as they complete without buffering
403+ return _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild)
404 else:
405- rv = Parallel(n_jobs=threadCount, timeout=99999)(
406- delayed(pcall)(function, a, params) for a, params in pargs
407- )
408-
409- totalTime = time.time() - currentTime
410- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime))
411- sys.stdout.flush()
412- return rv
413+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn')
414+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild,
415+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool:
416+ return list(imap_with_progress(pool, worker, objects, objLen, message, chunksize))
diff --git a/tensilelite/Tensile/CustomKernels.py b/tensilelite/Tensile/CustomKernels.py
index ffceb636f5..127b3386a1 100644
--- a/tensilelite/Tensile/CustomKernels.py
+++ b/tensilelite/Tensile/CustomKernels.py
@@ -24,7 +24,9 @@
 
 from . import CUSTOM_KERNEL_PATH
 from Tensile.Common.ValidParameters import checkParametersAreValid, validParameters, newMIValidParameters
+from Tensile.CustomYamlLoader import DEFAULT_YAML_LOADER
 
+from functools import lru_cache
 import yaml
 
 import os
@@ -58,10 +60,13 @@ def getCustomKernelConfigAndAssembly(name, directory=CUSTOM_KERNEL_PATH):
 
     return (config, assembly)
 
+# getCustomKernelConfig will get called repeatedly on the same file
+# 20x logic loading speedup for aquavanjaram_Cijk_Ailk_Bljk_F8NH_HHS_BH_Bias_HAS_SAB_SAV_freesize_custom_GSUs
+@lru_cache
 def readCustomKernelConfig(name, directory=CUSTOM_KERNEL_PATH):
     rawConfig, _ = getCustomKernelConfigAndAssembly(name, directory)
     try:
-        return yaml.safe_load(rawConfig)["custom.config"]
+        return yaml.load(rawConfig, Loader=DEFAULT_YAML_LOADER)["custom.config"]
     except yaml.scanner.ScannerError as e:
         raise RuntimeError("Failed to read configuration for custom kernel: {0}\nDetails:\n{1}".format(name, e))
 
446diff --git a/tensilelite/Tensile/TensileCreateLibrary/Run.py b/tensilelite/Tensile/TensileCreateLibrary/Run.py
447index 835ed9c019..024c6c49c1 100644
448--- a/tensilelite/Tensile/TensileCreateLibrary/Run.py
449+++ b/tensilelite/Tensile/TensileCreateLibrary/Run.py
450@@ -26,8 +26,10 @@ import rocisa
451
452 import functools
453 import glob
454+import gc
455 import itertools
456 import os
457+import resource
458 import shutil
459 from pathlib import Path
460 from timeit import default_timer as timer
461@@ -78,6 +80,25 @@ from Tensile.Utilities.Decorators.Timing import timing
462 from .ParseArguments import parseArguments
463
464
465+def getMemoryUsage():
466+ """Get peak and current memory usage in MB."""
467+ rusage = resource.getrusage(resource.RUSAGE_SELF)
468+ peak_memory_mb = rusage.ru_maxrss / 1024 # KB to MB on Linux
469+
470+ # Get current memory from /proc/self/status
471+ current_memory_mb = 0
472+ try:
473+ with open('/proc/self/status') as f:
474+ for line in f:
475+ if line.startswith('VmRSS:'):
476+ current_memory_mb = int(line.split()[1]) / 1024 # KB to MB
477+ break
478+ except:
479+ current_memory_mb = peak_memory_mb # Fallback
480+
481+ return (peak_memory_mb, current_memory_mb)
482+
483+
484 class KernelCodeGenResult(NamedTuple):
485 err: int
486 src: str
487@@ -115,6 +136,29 @@ def processKernelSource(kernelWriterAssembly, data, splitGSU, kernel) -> KernelC
488 )
489
490
491+def processAndAssembleKernelTCL(kernelWriterAssembly, rocisa_data, splitGSU, kernel, assemblyTmpPath, assembler):
492+ """
493+ Pipeline function for TCL mode that:
494+ 1. Generates kernel source
495+ 2. Writes .s file to disk
496+ 3. Assembles to .o file
497+ 4. Deletes .s file
498+ """
499+ result = processKernelSource(kernelWriterAssembly, rocisa_data, splitGSU, kernel)
500+ return writeAndAssembleKernel(result, assemblyTmpPath, assembler)
501+
502+
503+def writeMasterSolutionLibrary(name_lib_tuple, newLibraryDir, splitGSU, libraryFormat):
504+ """
505+ Write a master solution library to disk.
506+ Module-level function to support multiprocessing.
507+ """
508+ name, lib = name_lib_tuple
509+ filename = os.path.join(newLibraryDir, name)
510+ lib.applyNaming(splitGSU)
511+ LibraryIO.write(filename, state(lib), libraryFormat)
512+
513+
514 def removeInvalidSolutionsAndKernels(results, kernels, solutions, errorTolerant, printLevel: bool, splitGSU: bool):
515 removeKernels = []
516 removeKernelNames = []
517@@ -189,6 +233,24 @@ def writeAssembly(asmPath: Union[Path, str], result: KernelCodeGenResult):
518 return path, isa, wfsize, minResult
519
520
521+def writeAndAssembleKernel(result: KernelCodeGenResult, asmPath: Union[Path, str], assembler):
522+ """Write assembly file and immediately assemble it to .o file"""
523+ if result.err:
524+ printExit(f"Failed to build kernel {result.name} because it has error code {result.err}")
525+
526+ path = Path(asmPath) / f"{result.name}.s"
527+ with open(path, "w", encoding="utf-8") as f:
528+ f.write(result.src)
529+
530+ # Assemble .s -> .o
531+ assembler(isaToGfx(result.isa), result.wavefrontSize, str(path), str(path.with_suffix(".o")))
532+
533+ # Delete assembly file immediately to save disk space
534+ path.unlink()
535+
536+ return KernelMinResult(result.err, result.cuoccupancy, result.pgr, result.mathclk)
537+
538+
539 def writeHelpers(
540 outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H
541 ):
542@@ -268,13 +330,14 @@ def writeSolutionsAndKernels(
543 numAsmKernels = len(asmKernels)
544 numKernels = len(asmKernels)
545 assert numKernels == numAsmKernels, "Only assembly kernels are supported in TensileLite"
546- asmIter = zip(
547- itertools.repeat(kernelWriterAssembly),
548- itertools.repeat(rocisa.rocIsa.getInstance().getData()),
549- itertools.repeat(splitGSU),
550- asmKernels
551+
552+ processKernelFn = functools.partial(
553+ processKernelSource,
554+ kernelWriterAssembly=kernelWriterAssembly,
555+ data=rocisa.rocIsa.getInstance().getData(),
556+ splitGSU=splitGSU
557 )
558- asmResults = ParallelMap2(processKernelSource, asmIter, "Generating assembly kernels", return_as="list")
559+ asmResults = ParallelMap2(processKernelFn, asmKernels, "Generating assembly kernels", return_as="list", multiArg=False)
560 removeInvalidSolutionsAndKernels(
561 asmResults, asmKernels, solutions, errorTolerant, getVerbosity(), splitGSU
562 )
563@@ -282,19 +345,21 @@ def writeSolutionsAndKernels(
564 asmResults, asmKernels, solutions, splitGSU
565 )
566
567- def assemble(ret):
568- p, isa, wavefrontsize, result = ret
569- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o")))
570-
571- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath)
572- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F)
573+ # Use functools.partial to bind assemblyTmpPath and assembler
574+ writeAndAssembleFn = functools.partial(
575+ writeAndAssembleKernel,
576+ asmPath=assemblyTmpPath,
577+ assembler=asmToolchain.assembler
578+ )
579 ret = ParallelMap2(
580- compose(assemble, unaryWriteAssembly),
581+ writeAndAssembleFn,
582 asmResults,
583 "Writing assembly kernels",
584 return_as="list",
585 multiArg=False,
586 )
587+ del asmResults
588+ gc.collect()
589
590 writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H)
591 srcKernelFile = Path(outputPath) / "Kernels.cpp"
592@@ -369,32 +434,31 @@ def writeSolutionsAndKernelsTCL(
593
594 uniqueAsmKernels = [k for k in asmKernels if not k.duplicate]
595
596- def assemble(ret):
597- p, isa, wavefrontsize, result = ret
598- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o")))
599- return result
600-
601- unaryProcessKernelSource = functools.partial(
602- processKernelSource,
603+ processKernelFn = functools.partial(
604+ processAndAssembleKernelTCL,
605 kernelWriterAssembly,
606 rocisa.rocIsa.getInstance().getData(),
607 splitGSU,
608+ assemblyTmpPath=assemblyTmpPath,
609+ assembler=asmToolchain.assembler
610 )
611
612- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath)
613- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F)
614- ret = ParallelMap2(
615- compose(assemble, unaryWriteAssembly, unaryProcessKernelSource),
616+ results = ParallelMap2(
617+ processKernelFn,
618 uniqueAsmKernels,
619 "Generating assembly kernels",
620 multiArg=False,
621 return_as="list"
622 )
623+ del processKernelFn
624+ gc.collect()
625+
626 passPostKernelInfoToSolution(
627- ret, uniqueAsmKernels, solutions, splitGSU
628+ results, uniqueAsmKernels, solutions, splitGSU
629 )
630- # result.src is very large so let garbage collector know to clean up
631- del ret
632+ del results
633+ gc.collect()
634+
635 buildAssemblyCodeObjectFiles(
636 asmToolchain.linker,
637 asmToolchain.bundler,
638@@ -493,6 +557,15 @@ def generateKernelHelperObjects(solutions: List[Solution], cxxCompiler: str, isa
639 return sorted(khos, key=sortByEnum, reverse=True) # Ensure that we write Enum kernel helpers are first in list
640
641
642+def libraryIter(lib: MasterSolutionLibrary):
643+ if len(lib.solutions):
644+ for i, s in enumerate(lib.solutions.items()):
645+ yield (i, *s)
646+ else:
647+ for _, lazyLib in lib.lazyLibraries.items():
648+ yield from libraryIter(lazyLib)
649+
650+
651 @timing
652 def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInfoMap):
653
654@@ -508,26 +581,23 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf
655 printSolutionRejectionReason = True
656 printIndexAssignmentInfo = False
657
658- fIter = zip(
659- logicFiles,
660- itertools.repeat(assembler),
661- itertools.repeat(splitGSU),
662- itertools.repeat(printSolutionRejectionReason),
663- itertools.repeat(printIndexAssignmentInfo),
664- itertools.repeat(isaInfoMap),
665- itertools.repeat(args["LazyLibraryLoading"]),
666+ parseLogicFn = functools.partial(
667+ LibraryIO.parseLibraryLogicFile,
668+ assembler=assembler,
669+ splitGSU=splitGSU,
670+ printSolutionRejectionReason=printSolutionRejectionReason,
671+ printIndexAssignmentInfo=printIndexAssignmentInfo,
672+ isaInfoMap=isaInfoMap,
673+ lazyLibraryLoading=args["LazyLibraryLoading"]
674 )
675
676- def libraryIter(lib: MasterSolutionLibrary):
677- if len(lib.solutions):
678- for i, s in enumerate(lib.solutions.items()):
679- yield (i, *s)
680- else:
681- for _, lazyLib in lib.lazyLibraries.items():
682- yield from libraryIter(lazyLib)
683-
684 for library in ParallelMap2(
685- LibraryIO.parseLibraryLogicFile, fIter, "Loading Logics...", return_as="generator_unordered"
686+ parseLogicFn, logicFiles, "Loading Logics...",
687+ return_as="generator_unordered",
688+ minChunkSize=24,
689+ maxWorkers=32,
690+ maxtasksperchild=1,
691+ multiArg=False,
692 ):
693 _, architectureName, _, _, _, newLibrary = library
694
695@@ -539,6 +609,9 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf
696 else:
697 masterLibraries[architectureName] = newLibrary
698 masterLibraries[architectureName].version = args["CodeObjectVersion"]
699+ del library, newLibrary
700+
701+ gc.collect()
702
703 # Sort masterLibraries to make global soln index values deterministic
704 solnReIndex = 0
705@@ -734,6 +807,9 @@ def run():
706 )
707 stop_wsk = timer()
708 print(f"Time to generate kernels (s): {(stop_wsk-start_wsk):3.2f}")
709+ numKernelHelperObjs = len(kernelHelperObjs)
710+ del kernelWriterAssembly, kernelHelperObjs
711+ gc.collect()
712
713 archs = [ # is this really different than the other archs above?
714 isaToGfx(arch)
715@@ -751,13 +827,10 @@ def run():
716 if kName not in solDict:
717 solDict["%s"%kName] = kernel
718
719- def writeMsl(name, lib):
720- filename = os.path.join(newLibraryDir, name)
721- lib.applyNaming(splitGSU)
722- LibraryIO.write(filename, state(lib), arguments["LibraryFormat"])
723-
724 filename = os.path.join(newLibraryDir, "TensileLiteLibrary_lazy_Mapping")
725 LibraryIO.write(filename, libraryMapping, "msgpack")
726+ del libraryMapping
727+ gc.collect()
728
729 start_msl = timer()
730 for archName, newMasterLibrary in masterLibraries.items():
731@@ -774,12 +847,22 @@ def run():
732 kName = getKeyNoInternalArgs(s.originalSolution, splitGSU)
733 s.sizeMapping.CUOccupancy = solDict["%s"%kName]["CUOccupancy"]
734
735- ParallelMap2(writeMsl,
736+ writeFn = functools.partial(
737+ writeMasterSolutionLibrary,
738+ newLibraryDir=newLibraryDir,
739+ splitGSU=splitGSU,
740+ libraryFormat=arguments["LibraryFormat"]
741+ )
742+
743+ ParallelMap2(writeFn,
744 newMasterLibrary.lazyLibraries.items(),
745 "Writing master solution libraries",
746+ multiArg=False,
747 return_as="list")
748 stop_msl = timer()
749 print(f"Time to write master solution libraries (s): {(stop_msl-start_msl):3.2f}")
750+ del masterLibraries, solutions, kernels, solDict
751+ gc.collect()
752
753 if not arguments["KeepBuildTmp"]:
754 buildTmp = Path(arguments["OutputPath"]).parent / "library" / "build_tmp"
755@@ -796,8 +879,11 @@ def run():
756 print("")
757
758 stop = timer()
759+ peak_memory_mb, current_memory_mb = getMemoryUsage()
760
761 print(f"Total time (s): {(stop-start):3.2f}")
762 print(f"Total kernels processed: {numKernels}")
763 print(f"Kernels processed per second: {(numKernels/(stop-start)):3.2f}")
764- print(f"KernelHelperObjs: {len(kernelHelperObjs)}")
765+ print(f"KernelHelperObjs: {numKernelHelperObjs}")
766+ print(f"Peak memory usage (MB): {peak_memory_mb:,.1f}")
767+ print(f"Current memory usage (MB): {current_memory_mb:,.1f}")
diff --git a/tensilelite/Tensile/TensileMergeLibrary.py b/tensilelite/Tensile/TensileMergeLibrary.py
index e33c617b6f..ba163e9918 100644
--- a/tensilelite/Tensile/TensileMergeLibrary.py
+++ b/tensilelite/Tensile/TensileMergeLibrary.py
@@ -303,8 +303,7 @@ def avoidRegressions(originalDir, incrementalDir, outputPath, forceMerge, noEff=
             logicsFiles[origFile] = origFile
             logicsFiles[incFile] = incFile
 
-    iters = zip(logicsFiles.keys())
-    logicsList = ParallelMap2(loadData, iters, "Loading Logics...", return_as="list")
+    logicsList = ParallelMap2(loadData, logicsFiles.keys(), "Loading Logics...", return_as="list", multiArg=False)
     logicsDict = {}
     for i, _ in enumerate(logicsList):
         logicsDict[logicsList[i][0]] = logicsList[i][1]
diff --git a/tensilelite/Tensile/TensileUpdateLibrary.py b/tensilelite/Tensile/TensileUpdateLibrary.py
index 5ff265d0ed..c1803a6349 100644
--- a/tensilelite/Tensile/TensileUpdateLibrary.py
+++ b/tensilelite/Tensile/TensileUpdateLibrary.py
@@ -26,7 +26,7 @@ from . import LibraryIO
 from .Tensile import addCommonArguments, argUpdatedGlobalParameters
 
 from .Common import assignGlobalParameters, print1, restoreDefaultGlobalParameters, HR, \
-    globalParameters, architectureMap, ensurePath, ParallelMap, __version__
+    globalParameters, architectureMap, ensurePath, ParallelMap2, __version__
 
 import argparse
 import copy
@@ -149,7 +149,7 @@ def TensileUpdateLibrary(userArgs):
     for logicFile in logicFiles:
         print("# %s" % logicFile)
     fIter = zip(logicFiles, itertools.repeat(args.logic_path), itertools.repeat(outputPath))
-    libraries = ParallelMap(UpdateLogic, fIter, "Updating logic files", method=lambda x: x.starmap)
+    libraries = ParallelMap2(UpdateLogic, fIter, "Updating logic files", multiArg=True, return_as="list")
 
 
 def main():
diff --git a/tensilelite/Tensile/Toolchain/Assembly.py b/tensilelite/Tensile/Toolchain/Assembly.py
index a8b91e8d62..265e1d532c 100644
--- a/tensilelite/Tensile/Toolchain/Assembly.py
+++ b/tensilelite/Tensile/Toolchain/Assembly.py
@@ -30,7 +30,7 @@ import subprocess
 from pathlib import Path
 from typing import List, Union, NamedTuple
 
-from Tensile.Common import print2
+from Tensile.Common import print1, print2
 from Tensile.Common.Architectures import isaToGfx
 from ..SolutionStructs import Solution
 
@@ -92,8 +92,26 @@ def buildAssemblyCodeObjectFiles(
         if coName:
             coFileMap[asmDir / (coName + extCoRaw)].add(str(asmDir / (kernel["BaseName"] + extObj)))
 
+    # Build reference count map for .o files to handle shared object files
+    # (.o files from kernels marked .duplicate in TensileCreateLibrary)
+    objFileRefCount = collections.Counter()
+    for coFileRaw, objFiles in coFileMap.items():
+        for objFile in objFiles:
+            objFileRefCount[objFile] += 1
+
+    sharedObjFiles = {objFile: count for objFile, count in objFileRefCount.items() if count > 1}
+    if sharedObjFiles:
+        print1(f"Found {len(sharedObjFiles)} .o files shared across multiple code objects:")
+
     for coFileRaw, objFiles in coFileMap.items():
         linker(objFiles, str(coFileRaw))
+
+        # Delete .o files after linking once usage count reaches 0
+        for objFile in objFiles:
+            objFileRefCount[objFile] -= 1
+            if objFileRefCount[objFile] == 0:
+                Path(objFile).unlink()
+
         coFile = destDir / coFileRaw.name.replace(extCoRaw, extCo)
         if compress:
             bundler.compress(str(coFileRaw), str(coFile), gfx)
diff --git a/tensilelite/Tensile/Toolchain/Component.py b/tensilelite/Tensile/Toolchain/Component.py
index 67fa35e2d8..dde83af4c3 100644
--- a/tensilelite/Tensile/Toolchain/Component.py
+++ b/tensilelite/Tensile/Toolchain/Component.py
@@ -355,6 +355,7 @@ class Linker(Component):
         when invoking the linker, LLVM allows the provision of arguments via a "response file"
         Reference: https://llvm.org/docs/CommandLine.html#response-files
         """
+        # FIXME: this prevents threading as clang_args.txt is overwritten
        with open(Path.cwd() / "clang_args.txt", "wt") as file:
            file.write(" ".join(srcPaths).replace('\\', '\\\\') if os_name == "nt" else " ".join(srcPaths))
        return [*(self.default_args), "-o", destPath, "@clang_args.txt"]
diff --git a/tensilelite/requirements.txt b/tensilelite/requirements.txt
index 60c4c11445..5c8fd66a88 100644
--- a/tensilelite/requirements.txt
+++ b/tensilelite/requirements.txt
@@ -2,8 +2,6 @@ dataclasses; python_version == '3.6'
 packaging
 pyyaml
 msgpack
-joblib>=1.4.0; python_version >= '3.8'
-joblib>=1.1.1; python_version < '3.8'
 simplejson
 ujson
 orjson