nixpkgs mirror (for testing)
github.com/NixOS/nixpkgs
nix
1diff --git a/tensilelite/Tensile/Common/Parallel.py b/tensilelite/Tensile/Common/Parallel.py
2index 1a2bf9e119..f46100c7b8 100644
3--- a/tensilelite/Tensile/Common/Parallel.py
4+++ b/tensilelite/Tensile/Common/Parallel.py
5@@ -22,43 +22,58 @@
6 #
7 ################################################################################
8
9-import concurrent.futures
10-import itertools
11+import multiprocessing
12 import os
13+import re
14 import sys
15 import time
16-
17-from joblib import Parallel, delayed
18+from functools import partial
19+from typing import Any, Callable
20
21 from .Utilities import tqdm
22
23
24-def joblibParallelSupportsGenerator():
25- import joblib
26- from packaging.version import Version
27+def get_inherited_job_limit() -> int:
28+ # 1. Check CMAKE_BUILD_PARALLEL_LEVEL (CMake 3.12+)
29+ if 'CMAKE_BUILD_PARALLEL_LEVEL' in os.environ:
30+ try:
31+ return int(os.environ['CMAKE_BUILD_PARALLEL_LEVEL'])
32+ except ValueError:
33+ pass
34
35- joblibVer = joblib.__version__
36- return Version(joblibVer) >= Version("1.4.0")
37+ # 2. Parse MAKEFLAGS for -jN
38+ makeflags = os.environ.get('MAKEFLAGS', '')
39+ match = re.search(r'-j\s*(\d+)', makeflags)
40+ if match:
41+ return int(match.group(1))
42
43+ return -1
44
45-def CPUThreadCount(enable=True):
46- from .GlobalParameters import globalParameters
47
48+def CPUThreadCount(enable=True):
49 if not enable:
50 return 1
51- else:
52+ from .GlobalParameters import globalParameters
53+
54+ # Priority order:
55+ # 1. Inherited from build system (CMAKE_BUILD_PARALLEL_LEVEL or MAKEFLAGS)
56+ # 2. Explicit --jobs flag
57+ # 3. Auto-detect
58+ inherited_limit = get_inherited_job_limit()
59+ cpuThreads = inherited_limit if inherited_limit > 0 else globalParameters["CpuThreads"]
60+
61+ if cpuThreads < 1:
62 if os.name == "nt":
63- # Windows supports at most 61 workers because the scheduler uses
64- # WaitForMultipleObjects directly, which has the limit (the limit
65- # is actually 64, but some handles are needed for accounting).
66- cpu_count = min(os.cpu_count(), 61)
67+ cpuThreads = os.cpu_count()
68 else:
69- cpu_count = len(os.sched_getaffinity(0))
70- cpuThreads = globalParameters["CpuThreads"]
71- if cpuThreads == -1:
72- return cpu_count
73+ cpuThreads = len(os.sched_getaffinity(0))
74
75- return min(cpu_count, cpuThreads)
76+ if os.name == "nt":
77+ # Windows supports at most 61 workers because the scheduler uses
78+ # WaitForMultipleObjects directly, which has the limit (the limit
79+ # is actually 64, but some handles are needed for accounting).
80+ cpuThreads = min(cpuThreads, 61)
81+ return max(1, cpuThreads)
82
83
84 def pcallWithGlobalParamsMultiArg(f, args, newGlobalParameters):
85@@ -71,19 +86,22 @@ def pcallWithGlobalParamsSingleArg(f, arg, newGlobalParameters):
86 return f(arg)
87
88
89-def apply_print_exception(item, *args):
90- # print(item, args)
91+def OverwriteGlobalParameters(newGlobalParameters):
92+ from . import GlobalParameters
93+
94+ GlobalParameters.globalParameters.clear()
95+ GlobalParameters.globalParameters.update(newGlobalParameters)
96+
97+
98+def worker_function(args, function, multiArg):
99+ """Worker function that executes in the pool process."""
100 try:
101- if len(args) > 0:
102- func = item
103- args = args[0]
104- return func(*args)
105+ if multiArg:
106+ return function(*args)
107 else:
108- func, item = item
109- return func(item)
110+ return function(args)
111 except Exception:
112 import traceback
113-
114 traceback.print_exc()
115 raise
116 finally:
117@@ -98,154 +116,121 @@ def OverwriteGlobalParameters(newGlobalParameters):
118 GlobalParameters.globalParameters.update(newGlobalParameters)
119
120
121-def ProcessingPool(enable=True, maxTasksPerChild=None):
122- import multiprocessing
123- import multiprocessing.dummy
124-
125- threadCount = CPUThreadCount()
126-
127- if (not enable) or threadCount <= 1:
128- return multiprocessing.dummy.Pool(1)
129-
130- if multiprocessing.get_start_method() == "spawn":
131- from . import GlobalParameters
132-
133- return multiprocessing.Pool(
134- threadCount,
135- initializer=OverwriteGlobalParameters,
136- maxtasksperchild=maxTasksPerChild,
137- initargs=(GlobalParameters.globalParameters,),
138- )
139- else:
140- return multiprocessing.Pool(threadCount, maxtasksperchild=maxTasksPerChild)
141+def progress_logger(iterable, total, message, min_log_interval=5.0):
142+ """
143+ Generator that wraps an iterable and logs progress with time-based throttling.
144
145+ Only logs progress if at least min_log_interval seconds have passed since last log.
146+    Prints the completion message only if the task took >= min_log_interval seconds or any progress line was printed.
147
148-def ParallelMap(function, objects, message="", enable=True, method=None, maxTasksPerChild=None):
149+ Yields (index, item) tuples.
150 """
151- Generally equivalent to list(map(function, objects)), possibly executing in parallel.
152-
153- message: A message describing the operation to be performed.
154- enable: May be set to false to disable parallelism.
155- method: A function which can fetch the mapping function from a processing pool object.
156- Leave blank to use .map(), other possiblities:
157- - `lambda x: x.starmap` - useful if `function` takes multiple parameters.
158- - `lambda x: x.imap` - lazy evaluation
159- - `lambda x: x.imap_unordered` - lazy evaluation, does not preserve order of return value.
160- """
161- from .GlobalParameters import globalParameters
162+ start_time = time.time()
163+ last_log_time = start_time
164+ log_interval = 1 + (total // 100)
165
166- threadCount = CPUThreadCount(enable)
167- pool = ProcessingPool(enable, maxTasksPerChild)
168-
169- if threadCount <= 1 and globalParameters["ShowProgressBar"]:
170- # Provide a progress bar for single-threaded operation.
171- # This works for method=None, and for starmap.
172- mapFunc = map
173- if method is not None:
174- # itertools provides starmap which can fill in for pool.starmap. It provides imap on Python 2.7.
175- # If this works, we will use it, otherwise we will fallback to the "dummy" pool for single threaded
176- # operation.
177- try:
178- mapFunc = method(itertools)
179- except NameError:
180- mapFunc = None
181-
182- if mapFunc is not None:
183- return list(mapFunc(function, tqdm(objects, message)))
184-
185- mapFunc = pool.map
186- if method:
187- mapFunc = method(pool)
188-
189- objects = zip(itertools.repeat(function), objects)
190- function = apply_print_exception
191-
192- countMessage = ""
193- try:
194- countMessage = " for {} tasks".format(len(objects))
195- except TypeError:
196- pass
197+ for idx, item in enumerate(iterable):
198+ if idx % log_interval == 0:
199+ current_time = time.time()
200+ if (current_time - last_log_time) >= min_log_interval:
201+ print(f"{message}\t{idx+1: 5d}/{total: 5d}")
202+ last_log_time = current_time
203+ yield idx, item
204
205- if message != "":
206- message += ": "
207+ elapsed = time.time() - start_time
208+ final_idx = idx + 1 if 'idx' in locals() else 0
209
210- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage))
211- sys.stdout.flush()
212- currentTime = time.time()
213- rv = mapFunc(function, objects)
214- totalTime = time.time() - currentTime
215- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime))
216- sys.stdout.flush()
217- pool.close()
218- return rv
219+ if elapsed >= min_log_interval or last_log_time > start_time:
220+ print(f"{message} done in {elapsed:.1f}s!\t{final_idx: 5d}/{total: 5d}")
221
222
223-def ParallelMapReturnAsGenerator(function, objects, message="", enable=True, multiArg=True):
224- from .GlobalParameters import globalParameters
225+def imap_with_progress(pool, func, iterable, total, message, chunksize):
226+ results = []
227+ for _, result in progress_logger(pool.imap(func, iterable, chunksize=chunksize), total, message):
228+ results.append(result)
229+ return results
230
231- threadCount = CPUThreadCount(enable)
232- print("{0}Launching {1} threads...".format(message, threadCount))
233
234- if threadCount <= 1 and globalParameters["ShowProgressBar"]:
235- # Provide a progress bar for single-threaded operation.
236- callFunc = lambda args: function(*args) if multiArg else lambda args: function(args)
237- return [callFunc(args) for args in tqdm(objects, message)]
238+def _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild):
239+ # separate fn because yield makes the entire fn a generator even if unreachable
240+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn')
241
242- with concurrent.futures.ProcessPoolExecutor(max_workers=threadCount) as executor:
243- resultFutures = (executor.submit(function, *arg if multiArg else arg) for arg in objects)
244- for result in concurrent.futures.as_completed(resultFutures):
245- yield result.result()
246+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild,
247+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool:
248+ for _, result in progress_logger(pool.imap_unordered(worker, objects, chunksize=chunksize), objLen, message):
249+ yield result
250
251
252 def ParallelMap2(
253- function, objects, message="", enable=True, multiArg=True, return_as="list", procs=None
254+ function: Callable,
255+ objects: Any,
256+ message: str = "",
257+ enable: bool = True,
258+ multiArg: bool = True,
259+ minChunkSize: int = 1,
260+ maxWorkers: int = -1,
261+ maxtasksperchild: int = 1024,
262+ return_as: str = "list"
263 ):
264+ """Executes a function over a list of objects in parallel or sequentially.
265+
266+ This function is generally equivalent to ``list(map(function, objects))``. However, it provides
267+ additional functionality to run in parallel, depending on the 'enable' flag and available CPU
268+ threads.
269+
270+ Args:
271+ function: The function to apply to each item in 'objects'. If 'multiArg' is True, 'function'
272+ should accept multiple arguments.
273+ objects: An iterable of objects to be processed by 'function'. If 'multiArg' is True, each
274+ item in 'objects' should be an iterable of arguments for 'function'.
275+ message: Optional; a message describing the operation. Default is an empty string.
276+ enable: Optional; if False, disables parallel execution and runs sequentially. Default is True.
277+ multiArg: Optional; if True, treats each item in 'objects' as multiple arguments for
278+ 'function'. Default is True.
279+        return_as: Optional; "list" (default) or "generator_unordered" to stream results as they complete (order not preserved).
280+
281+ Returns:
282+ A list or generator containing the results of applying **function** to each item in **objects**.
283 """
284- Generally equivalent to list(map(function, objects)), possibly executing in parallel.
285+ from .GlobalParameters import globalParameters
286
287- message: A message describing the operation to be performed.
288- enable: May be set to false to disable parallelism.
289- multiArg: True if objects represent multiple arguments
290- (differentiates multi args vs single collection arg)
291- """
292- if return_as in ("generator", "generator_unordered") and not joblibParallelSupportsGenerator():
293- return ParallelMapReturnAsGenerator(function, objects, message, enable, multiArg)
294+ threadCount = CPUThreadCount(enable)
295
296- from .GlobalParameters import globalParameters
297+ if not hasattr(objects, "__len__"):
298+ objects = list(objects)
299
300- threadCount = procs if procs else CPUThreadCount(enable)
301+ objLen = len(objects)
302+ if objLen == 0:
303+ return [] if return_as == "list" else iter([])
304
305- threadCount = CPUThreadCount(enable)
306+ f = (lambda x: function(*x)) if multiArg else function
307+ if objLen == 1:
308+ print(f"{message}: (1 task)")
309+ result = [f(x) for x in objects]
310+ return result if return_as == "list" else iter(result)
311
312- if threadCount <= 1 and globalParameters["ShowProgressBar"]:
313- # Provide a progress bar for single-threaded operation.
314- return [function(*args) if multiArg else function(args) for args in tqdm(objects, message)]
315+ extra_message = (
316+ f": {threadCount} thread(s)" + f", {objLen} tasks"
317+ if objLen
318+ else ""
319+ )
320
321- countMessage = ""
322- try:
323- countMessage = " for {} tasks".format(len(objects))
324- except TypeError:
325- pass
326-
327- if message != "":
328- message += ": "
329- print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage))
330- sys.stdout.flush()
331- currentTime = time.time()
332-
333- pcall = pcallWithGlobalParamsMultiArg if multiArg else pcallWithGlobalParamsSingleArg
334- pargs = zip(objects, itertools.repeat(globalParameters))
335-
336- if joblibParallelSupportsGenerator():
337- rv = Parallel(n_jobs=threadCount, timeout=99999, return_as=return_as)(
338- delayed(pcall)(function, a, params) for a, params in pargs
339- )
340+ print(f"ParallelMap {message}{extra_message}")
341+
342+ if threadCount <= 1:
343+ result = [f(x) for x in objects]
344+ return result if return_as == "list" else iter(result)
345+
346+ if maxWorkers > 0:
347+ threadCount = min(maxWorkers, threadCount)
348+
349+ chunksize = max(minChunkSize, objLen // 2000)
350+ worker = partial(worker_function, function=function, multiArg=multiArg)
351+ if return_as == "generator_unordered":
352+ # yield results as they complete without buffering
353+ return _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild)
354 else:
355- rv = Parallel(n_jobs=threadCount, timeout=99999)(
356- delayed(pcall)(function, a, params) for a, params in pargs
357- )
358-
359- totalTime = time.time() - currentTime
360- print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime))
361- sys.stdout.flush()
362- return rv
363+ ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn')
364+ with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild,
365+ initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool:
366+ return list(imap_with_progress(pool, worker, objects, objLen, message, chunksize))
367diff --git a/tensilelite/Tensile/CustomKernels.py b/tensilelite/Tensile/CustomKernels.py
368index ffceb636f5..127b3386a1 100644
369--- a/tensilelite/Tensile/CustomKernels.py
370+++ b/tensilelite/Tensile/CustomKernels.py
371@@ -24,7 +24,9 @@
372
373 from . import CUSTOM_KERNEL_PATH
374 from Tensile.Common.ValidParameters import checkParametersAreValid, validParameters, newMIValidParameters
375+from Tensile.CustomYamlLoader import DEFAULT_YAML_LOADER
376
377+from functools import lru_cache
378 import yaml
379
380 import os
381@@ -58,10 +60,13 @@ def getCustomKernelConfigAndAssembly(name, directory=CUSTOM_KERNEL_PATH):
382
383 return (config, assembly)
384
385+# readCustomKernelConfig will get called repeatedly on the same file
386+# 20x logic loading speedup for aquavanjaram_Cijk_Ailk_Bljk_F8NH_HHS_BH_Bias_HAS_SAB_SAV_freesize_custom_GSUs
387+@lru_cache
388 def readCustomKernelConfig(name, directory=CUSTOM_KERNEL_PATH):
389 rawConfig, _ = getCustomKernelConfigAndAssembly(name, directory)
390 try:
391- return yaml.safe_load(rawConfig)["custom.config"]
392+ return yaml.load(rawConfig, Loader=DEFAULT_YAML_LOADER)["custom.config"]
393 except yaml.scanner.ScannerError as e:
394 raise RuntimeError("Failed to read configuration for custom kernel: {0}\nDetails:\n{1}".format(name, e))
395
396diff --git a/tensilelite/Tensile/CustomYamlLoader.py b/tensilelite/Tensile/CustomYamlLoader.py
397index e03f456fbe..ed7510f2ce 100644
398--- a/tensilelite/Tensile/CustomYamlLoader.py
399+++ b/tensilelite/Tensile/CustomYamlLoader.py
400@@ -1,6 +1,7 @@
401 # Copyright © Advanced Micro Devices, Inc., or its affiliates.
402 # SPDX-License-Identifier: MIT
403
404+import sys
405 import yaml
406 from pathlib import Path
407
408@@ -70,7 +71,7 @@ def parse_scalar(loader: yaml.Loader):
409 elif is_float(value_lower):
410 return float(value_lower)
411
412- return value
413+ return sys.intern(value)
414
415 def load_yaml_stream(yaml_path: Path, loader_type: yaml.Loader):
416 with open(yaml_path, 'r') as f:
417diff --git a/tensilelite/Tensile/TensileCreateLibrary/Run.py b/tensilelite/Tensile/TensileCreateLibrary/Run.py
418index 22d19851a3..348068b3bf 100644
419--- a/tensilelite/Tensile/TensileCreateLibrary/Run.py
420+++ b/tensilelite/Tensile/TensileCreateLibrary/Run.py
421@@ -26,8 +26,10 @@ import rocisa
422
423 import functools
424 import glob
425+import gc
426 import itertools
427 import os
428+import resource
429 import shutil
430 from pathlib import Path
431 from timeit import default_timer as timer
432@@ -78,6 +80,25 @@ from Tensile.Utilities.Decorators.Timing import timing
433 from .ParseArguments import parseArguments
434
435
436+def getMemoryUsage():
437+ """Get peak and current memory usage in MB."""
438+ rusage = resource.getrusage(resource.RUSAGE_SELF)
439+ peak_memory_mb = rusage.ru_maxrss / 1024 # KB to MB on Linux
440+
441+ # Get current memory from /proc/self/status
442+ current_memory_mb = 0
443+ try:
444+ with open('/proc/self/status') as f:
445+ for line in f:
446+ if line.startswith('VmRSS:'):
447+ current_memory_mb = int(line.split()[1]) / 1024 # KB to MB
448+ break
449+ except:
450+ current_memory_mb = peak_memory_mb # Fallback
451+
452+ return (peak_memory_mb, current_memory_mb)
453+
454+
455 class KernelCodeGenResult(NamedTuple):
456 err: int
457 src: str
458@@ -115,6 +136,29 @@ def processKernelSource(kernelWriterAssembly, data, outOptions, splitGSU, kernel
459 )
460
461
462+def processAndAssembleKernelTCL(kernelWriterAssembly, rocisa_data, outOptions, splitGSU, kernel, assemblyTmpPath, assembler):
463+ """
464+ Pipeline function for TCL mode that:
465+ 1. Generates kernel source
466+ 2. Writes .s file to disk
467+ 3. Assembles to .o file
468+ 4. Deletes .s file
469+ """
470+ result = processKernelSource(kernelWriterAssembly, rocisa_data, outOptions, splitGSU, kernel)
471+ return writeAndAssembleKernel(result, assemblyTmpPath, assembler)
472+
473+
474+def writeMasterSolutionLibrary(name_lib_tuple, newLibraryDir, splitGSU, libraryFormat):
475+ """
476+ Write a master solution library to disk.
477+ Module-level function to support multiprocessing.
478+ """
479+ name, lib = name_lib_tuple
480+ filename = os.path.join(newLibraryDir, name)
481+ lib.applyNaming(splitGSU)
482+ LibraryIO.write(filename, state(lib), libraryFormat)
483+
484+
485 def removeInvalidSolutionsAndKernels(results, kernels, solutions, errorTolerant, printLevel: bool, splitGSU: bool):
486 removeKernels = []
487 removeKernelNames = []
488@@ -189,6 +233,24 @@ def writeAssembly(asmPath: Union[Path, str], result: KernelCodeGenResult):
489 return path, isa, wfsize, minResult
490
491
492+def writeAndAssembleKernel(result: KernelCodeGenResult, asmPath: Union[Path, str], assembler):
493+ """Write assembly file and immediately assemble it to .o file"""
494+ if result.err:
495+ printExit(f"Failed to build kernel {result.name} because it has error code {result.err}")
496+
497+ path = Path(asmPath) / f"{result.name}.s"
498+ with open(path, "w", encoding="utf-8") as f:
499+ f.write(result.src)
500+
501+ # Assemble .s -> .o
502+ assembler(isaToGfx(result.isa), result.wavefrontSize, str(path), str(path.with_suffix(".o")))
503+
504+ # Delete assembly file immediately to save disk space
505+ path.unlink()
506+
507+ return KernelMinResult(result.err, result.cuoccupancy, result.pgr, result.mathclk)
508+
509+
510 def writeHelpers(
511 outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H
512 ):
513@@ -272,14 +334,15 @@ def writeSolutionsAndKernels(
514 numAsmKernels = len(asmKernels)
515 numKernels = len(asmKernels)
516 assert numKernels == numAsmKernels, "Only assembly kernels are supported in TensileLite"
517- asmIter = zip(
518- itertools.repeat(kernelWriterAssembly),
519- itertools.repeat(rocisa.rocIsa.getInstance().getData()),
520- itertools.repeat(outOptions),
521- itertools.repeat(splitGSU),
522- asmKernels
523+
524+ processKernelFn = functools.partial(
525+ processKernelSource,
526+ kernelWriterAssembly=kernelWriterAssembly,
527+ data=rocisa.rocIsa.getInstance().getData(),
528+ outOptions=outOptions,
529+ splitGSU=splitGSU
530 )
531- asmResults = ParallelMap2(processKernelSource, asmIter, "Generating assembly kernels", return_as="list")
532+ asmResults = ParallelMap2(processKernelFn, asmKernels, "Generating assembly kernels", return_as="list", multiArg=False)
533 removeInvalidSolutionsAndKernels(
534 asmResults, asmKernels, solutions, errorTolerant, getVerbosity(), splitGSU
535 )
536@@ -287,19 +350,21 @@ def writeSolutionsAndKernels(
537 asmResults, asmKernels, solutions, splitGSU
538 )
539
540- def assemble(ret):
541- p, isa, wavefrontsize, _ = ret
542- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o")))
543-
544- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath)
545- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F)
546+ # Use functools.partial to bind assemblyTmpPath and assembler
547+ writeAndAssembleFn = functools.partial(
548+ writeAndAssembleKernel,
549+ asmPath=assemblyTmpPath,
550+ assembler=asmToolchain.assembler
551+ )
552 ret = ParallelMap2(
553- compose(assemble, unaryWriteAssembly),
554+ writeAndAssembleFn,
555 asmResults,
556 "Writing assembly kernels",
557 return_as="list",
558 multiArg=False,
559 )
560+ del asmResults
561+ gc.collect()
562
563 writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H)
564 srcKernelFile = Path(outputPath) / "Kernels.cpp"
565@@ -376,40 +441,35 @@ def writeSolutionsAndKernelsTCL(
566
567 uniqueAsmKernels = [k for k in asmKernels if not k.duplicate]
568
569- def assemble(ret, removeTemporaries: bool):
570- asmPath, isa, wavefrontsize, result = ret
571- asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(asmPath), str(asmPath.with_suffix(".o")))
572- if removeTemporaries:
573- asmPath.unlink()
574- return result
575-
576- unaryAssemble = functools.partial(assemble, removeTemporaries=removeTemporaries)
577-
578 outOptions = rocisa.rocIsa.getInstance().getOutputOptions()
579 outOptions.outputNoComment = not disableAsmComments
580
581- unaryProcessKernelSource = functools.partial(
582- processKernelSource,
583+ processKernelFn = functools.partial(
584+ processAndAssembleKernelTCL,
585 kernelWriterAssembly,
586 rocisa.rocIsa.getInstance().getData(),
587 outOptions,
588 splitGSU,
589+ assemblyTmpPath=assemblyTmpPath,
590+ assembler=asmToolchain.assembler
591 )
592
593- unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath)
594- compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F)
595- ret = ParallelMap2(
596- compose(unaryAssemble, unaryWriteAssembly, unaryProcessKernelSource),
597+ results = ParallelMap2(
598+ processKernelFn,
599 uniqueAsmKernels,
600 "Generating assembly kernels",
601 multiArg=False,
602 return_as="list"
603 )
604+ del processKernelFn
605+ gc.collect()
606+
607 passPostKernelInfoToSolution(
608- ret, uniqueAsmKernels, solutions, splitGSU
609+ results, uniqueAsmKernels, solutions, splitGSU
610 )
611- # result.src is very large so let garbage collector know to clean up
612- del ret
613+ del results
614+ gc.collect()
615+
616 buildAssemblyCodeObjectFiles(
617 asmToolchain.linker,
618 asmToolchain.bundler,
619@@ -508,6 +568,15 @@ def generateKernelHelperObjects(solutions: List[Solution], cxxCompiler: str, isa
620 return sorted(khos, key=sortByEnum, reverse=True) # Ensure that we write Enum kernel helpers are first in list
621
622
623+def libraryIter(lib: MasterSolutionLibrary):
624+ if len(lib.solutions):
625+ for i, s in enumerate(lib.solutions.items()):
626+ yield (i, *s)
627+ else:
628+ for _, lazyLib in lib.lazyLibraries.items():
629+ yield from libraryIter(lazyLib)
630+
631+
632 @timing
633 def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInfoMap):
634
635@@ -523,26 +592,23 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf
636 printSolutionRejectionReason = True
637 printIndexAssignmentInfo = False
638
639- fIter = zip(
640- logicFiles,
641- itertools.repeat(assembler),
642- itertools.repeat(splitGSU),
643- itertools.repeat(printSolutionRejectionReason),
644- itertools.repeat(printIndexAssignmentInfo),
645- itertools.repeat(isaInfoMap),
646- itertools.repeat(args["LazyLibraryLoading"]),
647+ parseLogicFn = functools.partial(
648+ LibraryIO.parseLibraryLogicFile,
649+ assembler=assembler,
650+ splitGSU=splitGSU,
651+ printSolutionRejectionReason=printSolutionRejectionReason,
652+ printIndexAssignmentInfo=printIndexAssignmentInfo,
653+ isaInfoMap=isaInfoMap,
654+ lazyLibraryLoading=args["LazyLibraryLoading"]
655 )
656
657- def libraryIter(lib: MasterSolutionLibrary):
658- if len(lib.solutions):
659- for i, s in enumerate(lib.solutions.items()):
660- yield (i, *s)
661- else:
662- for _, lazyLib in lib.lazyLibraries.items():
663- yield from libraryIter(lazyLib)
664-
665 for library in ParallelMap2(
666- LibraryIO.parseLibraryLogicFile, fIter, "Loading Logics...", return_as="generator_unordered"
667+ parseLogicFn, logicFiles, "Loading Logics...",
668+ return_as="generator_unordered",
669+ minChunkSize=24,
670+ maxWorkers=32,
671+ maxtasksperchild=1,
672+ multiArg=False,
673 ):
674 _, architectureName, _, _, _, newLibrary = library
675
676@@ -554,6 +620,9 @@ def generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInf
677 else:
678 masterLibraries[architectureName] = newLibrary
679 masterLibraries[architectureName].version = args["CodeObjectVersion"]
680+ del library, newLibrary
681+
682+ gc.collect()
683
684 # Sort masterLibraries to make global soln index values deterministic
685 solnReIndex = 0
686@@ -751,6 +820,9 @@ def run():
687 )
688 stop_wsk = timer()
689 print(f"Time to generate kernels (s): {(stop_wsk-start_wsk):3.2f}")
690+ numKernelHelperObjs = len(kernelHelperObjs)
691+ del kernelWriterAssembly, kernelHelperObjs
692+ gc.collect()
693
694 archs = [ # is this really different than the other archs above?
695 isaToGfx(arch)
696@@ -768,13 +840,10 @@ def run():
697 if kName not in solDict:
698 solDict["%s"%kName] = kernel
699
700- def writeMsl(name, lib):
701- filename = os.path.join(newLibraryDir, name)
702- lib.applyNaming(splitGSU)
703- LibraryIO.write(filename, state(lib), arguments["LibraryFormat"])
704-
705 filename = os.path.join(newLibraryDir, "TensileLiteLibrary_lazy_Mapping")
706 LibraryIO.write(filename, libraryMapping, "msgpack")
707+ del libraryMapping
708+ gc.collect()
709
710 start_msl = timer()
711 for archName, newMasterLibrary in masterLibraries.items():
712@@ -791,12 +860,22 @@ def run():
713 kName = getKeyNoInternalArgs(s.originalSolution, splitGSU)
714 s.sizeMapping.CUOccupancy = solDict["%s"%kName]["CUOccupancy"]
715
716- ParallelMap2(writeMsl,
717+ writeFn = functools.partial(
718+ writeMasterSolutionLibrary,
719+ newLibraryDir=newLibraryDir,
720+ splitGSU=splitGSU,
721+ libraryFormat=arguments["LibraryFormat"]
722+ )
723+
724+ ParallelMap2(writeFn,
725 newMasterLibrary.lazyLibraries.items(),
726 "Writing master solution libraries",
727+ multiArg=False,
728 return_as="list")
729 stop_msl = timer()
730 print(f"Time to write master solution libraries (s): {(stop_msl-start_msl):3.2f}")
731+ del masterLibraries, solutions, kernels, solDict
732+ gc.collect()
733
734 if not arguments["KeepBuildTmp"]:
735 buildTmp = Path(arguments["OutputPath"]).parent / "library" / "build_tmp"
736@@ -813,8 +892,11 @@ def run():
737 print("")
738
739 stop = timer()
740+ peak_memory_mb, current_memory_mb = getMemoryUsage()
741
742 print(f"Total time (s): {(stop-start):3.2f}")
743 print(f"Total kernels processed: {numKernels}")
744 print(f"Kernels processed per second: {(numKernels/(stop-start)):3.2f}")
745- print(f"KernelHelperObjs: {len(kernelHelperObjs)}")
746+ print(f"KernelHelperObjs: {numKernelHelperObjs}")
747+ print(f"Peak memory usage (MB): {peak_memory_mb:,.1f}")
748+ print(f"Current memory usage (MB): {current_memory_mb:,.1f}")
749diff --git a/tensilelite/Tensile/TensileMergeLibrary.py b/tensilelite/Tensile/TensileMergeLibrary.py
750index e33c617b6f..ba163e9918 100644
751--- a/tensilelite/Tensile/TensileMergeLibrary.py
752+++ b/tensilelite/Tensile/TensileMergeLibrary.py
753@@ -303,8 +303,7 @@ def avoidRegressions(originalDir, incrementalDir, outputPath, forceMerge, noEff=
754 logicsFiles[origFile] = origFile
755 logicsFiles[incFile] = incFile
756
757- iters = zip(logicsFiles.keys())
758- logicsList = ParallelMap2(loadData, iters, "Loading Logics...", return_as="list")
759+ logicsList = ParallelMap2(loadData, logicsFiles.keys(), "Loading Logics...", return_as="list", multiArg=False)
760 logicsDict = {}
761 for i, _ in enumerate(logicsList):
762 logicsDict[logicsList[i][0]] = logicsList[i][1]
763diff --git a/tensilelite/Tensile/TensileUpdateLibrary.py b/tensilelite/Tensile/TensileUpdateLibrary.py
764index 5ff265d0ed..c1803a6349 100644
765--- a/tensilelite/Tensile/TensileUpdateLibrary.py
766+++ b/tensilelite/Tensile/TensileUpdateLibrary.py
767@@ -26,7 +26,7 @@ from . import LibraryIO
768 from .Tensile import addCommonArguments, argUpdatedGlobalParameters
769
770 from .Common import assignGlobalParameters, print1, restoreDefaultGlobalParameters, HR, \
771- globalParameters, architectureMap, ensurePath, ParallelMap, __version__
772+ globalParameters, architectureMap, ensurePath, ParallelMap2, __version__
773
774 import argparse
775 import copy
776@@ -149,7 +149,7 @@ def TensileUpdateLibrary(userArgs):
777 for logicFile in logicFiles:
778 print("# %s" % logicFile)
779 fIter = zip(logicFiles, itertools.repeat(args.logic_path), itertools.repeat(outputPath))
780- libraries = ParallelMap(UpdateLogic, fIter, "Updating logic files", method=lambda x: x.starmap)
781+ libraries = ParallelMap2(UpdateLogic, fIter, "Updating logic files", multiArg=True, return_as="list")
782
783
784 def main():
785diff --git a/tensilelite/Tensile/Toolchain/Assembly.py b/tensilelite/Tensile/Toolchain/Assembly.py
786index a8b91e8d62..265e1d532c 100644
787--- a/tensilelite/Tensile/Toolchain/Assembly.py
788+++ b/tensilelite/Tensile/Toolchain/Assembly.py
789@@ -30,7 +30,7 @@ import subprocess
790 from pathlib import Path
791 from typing import List, Union, NamedTuple
792
793-from Tensile.Common import print2
794+from Tensile.Common import print1, print2
795 from Tensile.Common.Architectures import isaToGfx
796 from ..SolutionStructs import Solution
797
798@@ -92,8 +92,26 @@ def buildAssemblyCodeObjectFiles(
799 if coName:
800 coFileMap[asmDir / (coName + extCoRaw)].add(str(asmDir / (kernel["BaseName"] + extObj)))
801
802+ # Build reference count map for .o files to handle shared object files
803+ # (.o files from kernels marked .duplicate in TensileCreateLibrary)
804+ objFileRefCount = collections.Counter()
805+ for coFileRaw, objFiles in coFileMap.items():
806+ for objFile in objFiles:
807+ objFileRefCount[objFile] += 1
808+
809+ sharedObjFiles = {objFile: count for objFile, count in objFileRefCount.items() if count > 1}
810+ if sharedObjFiles:
811+ print1(f"Found {len(sharedObjFiles)} .o files shared across multiple code objects:")
812+
813 for coFileRaw, objFiles in coFileMap.items():
814 linker(objFiles, str(coFileRaw))
815+
816+ # Delete .o files after linking once usage count reaches 0
817+ for objFile in objFiles:
818+ objFileRefCount[objFile] -= 1
819+ if objFileRefCount[objFile] == 0:
820+ Path(objFile).unlink()
821+
822 coFile = destDir / coFileRaw.name.replace(extCoRaw, extCo)
823 if compress:
824 bundler.compress(str(coFileRaw), str(coFile), gfx)
825diff --git a/tensilelite/Tensile/Toolchain/Component.py b/tensilelite/Tensile/Toolchain/Component.py
826index 67fa35e2d8..dde83af4c3 100644
827--- a/tensilelite/Tensile/Toolchain/Component.py
828+++ b/tensilelite/Tensile/Toolchain/Component.py
829@@ -355,6 +355,7 @@ class Linker(Component):
830 when invoking the linker, LLVM allows the provision of arguments via a "response file"
831 Reference: https://llvm.org/docs/CommandLine.html#response-files
832 """
833+        # FIXME: the response file is written to the CWD, so concurrent linker invocations overwrite each other's clang_args.txt — this prevents parallel linking
834 with open(Path.cwd() / "clang_args.txt", "wt") as file:
835 file.write(" ".join(srcPaths).replace('\\', '\\\\') if os_name == "nt" else " ".join(srcPaths))
836 return [*(self.default_args), "-o", destPath, "@clang_args.txt"]
837diff --git a/tensilelite/requirements.txt b/tensilelite/requirements.txt
838index 60c4c11445..5c8fd66a88 100644
839--- a/tensilelite/requirements.txt
840+++ b/tensilelite/requirements.txt
841@@ -2,8 +2,6 @@ dataclasses; python_version == '3.6'
842 packaging
843 pyyaml
844 msgpack
845-joblib>=1.4.0; python_version >= '3.8'
846-joblib>=1.1.1; python_version < '3.8'
847 simplejson
848 ujson
849 orjson