nixpkgs mirror (for testing) github.com/NixOS/nixpkgs
nix
at python-updates 181 lines 6.8 kB view raw
1diff --git a/tensilelite/Tensile/Common/Utilities.py b/tensilelite/Tensile/Common/Utilities.py 2index 0a9d9db5b3..cb9779eaac 100644 3--- a/tensilelite/Tensile/Common/Utilities.py 4+++ b/tensilelite/Tensile/Common/Utilities.py 5@@ -24,6 +24,7 @@ 6 7 import functools 8 import math 9+import operator 10 import os 11 import sys 12 import time 13@@ -269,8 +270,20 @@ def state(obj): 14 15 16 def state_key_ordering(cls): 17- def tup(obj): 18- return tuple([getattr(obj, k) for k in cls.StateKeys]) 19+ # Use operator.attrgetter for efficiency if __slots__ is defined 20+ if hasattr(cls, '__slots__'): 21+ # attrgetter is faster for slotted classes 22+ getter = operator.attrgetter(*cls.StateKeys) 23+ if len(cls.StateKeys) == 1: 24+ # attrgetter returns scalar for single key, we need tuple 25+ def tup(obj): 26+ return (getter(obj),) 27+ else: 28+ tup = getter 29+ else: 30+ # Fallback for regular classes 31+ def tup(obj): 32+ return tuple([getattr(obj, k) for k in cls.StateKeys]) 33 34 def lt(a, b): 35 return tup(a) < tup(b) 36diff --git a/tensilelite/Tensile/Contractions.py b/tensilelite/Tensile/Contractions.py 37index c0d4e851b1..3f2c2e98c6 100644 38--- a/tensilelite/Tensile/Contractions.py 39+++ b/tensilelite/Tensile/Contractions.py 40@@ -37,9 +37,60 @@ from Tensile.Toolchain.Component import Assembler 41 from math import ceil 42 43 MIN_K_FOR_GSU = 32 44+ 45+# Interning helpers to reduce memory usage by reusing identical objects 46+_free_index_cache = {} 47+def intern_free_index(isA, i=None, c=None, d=None, a=None, b=None): 48+ key = (isA, i, c, d, a, b) 49+ if key not in _free_index_cache: 50+ obj = FreeIndex(isA, i, c, d) 51+ obj.a = a 52+ if b is not None: 53+ obj.b = b 54+ _free_index_cache[key] = obj 55+ return _free_index_cache[key] 56+ 57+_batch_index_cache = {} 58+def intern_batch_index(a=None, b=None, c=None, d=None): 59+ key = (a, b, c, d) 60+ if key not in _batch_index_cache: 61+ obj = BatchIndex(c=c, d=d) 62+ obj.a = a 63+ obj.b = b 64+ _batch_index_cache[key] = obj 65+ return _batch_index_cache[key] 66+ 67+_bound_index_cache = {} 68+def intern_bound_index(a=None, b=None, aMirror=False, bMirror=False): 69+ key = (a, b, aMirror, bMirror) 70+ if key not in _bound_index_cache: 71+ obj = BoundIndex(aMirror=aMirror, bMirror=bMirror) 72+ obj.a = a 73+ obj.b = b 74+ _bound_index_cache[key] = obj 75+ return _bound_index_cache[key] 76+ 77+_size_mapping_cache = {} 78+def intern_size_mapping(size_mapping): 79+ """Intern a SizeMapping instance to reduce redundancy.""" 80+ # Build hashable key from StateKeys, converting lists to tuples 81+ key_parts = [] 82+ for attr in size_mapping.StateKeys: 83+ val = getattr(size_mapping, attr) 84+ # Convert lists to tuples for hashing 85+ if isinstance(val, list): 86+ val = tuple(val) 87+ key_parts.append(val) 88+ key = tuple(key_parts) 89+ 90+ if key not in _size_mapping_cache: 91+ _size_mapping_cache[key] = size_mapping 92+ return _size_mapping_cache[key] 93+ 94 @state_key_ordering 95 class FreeIndex: 96 StateKeys = ['isA', 'i', 'c', 'd'] 97+ __slots__ = ['isA', 'i', 'c', 'd', 'a', 'b'] 98 99 def __init__(self, isA, i=None, c=None, d=None): 100 self.isA = isA 101@@ -50,6 +101,7 @@ class FreeIndex: 102 @state_key_ordering 103 class BatchIndex: 104 StateKeys = ['a', 'b', 'c', 'd'] 105+ __slots__ = ['a', 'b', 'c', 'd'] 106 def __init__(self, a=None, b=None, c=None, d=None): 107 self.a = a 108 self.b = b 109@@ -59,6 +111,7 @@ class BatchIndex: 110 @state_key_ordering 111 class BoundIndex: 112 StateKeys = ['a', 'b', 'aMirror', 'bMirror'] 113+ __slots__ = ['a', 'b', 'aMirror', 'bMirror'] 114 def __init__(self, a=None, b=None, aMirror=False, bMirror=False): 115 self.a = a 116 self.b = b 117@@ -107,6 +160,23 @@ class ProblemType: 118 for ib, ic in enumerate(d['IndexAssignmentsB']): 119 indices[ic].b = ib 120 121+ # Now intern all indices with their final state (including .a and .b) 122+ for i, idx in enumerate(indices): 123+ if isinstance(idx, FreeIndex): 124+ indices[i] = intern_free_index(idx.isA, idx.i, idx.c, idx.d, 125+ getattr(idx, 'a', None), getattr(idx, 'b', None)) 126+ elif isinstance(idx, BatchIndex): 127+ indices[i] = intern_batch_index(getattr(idx, 'a', None), getattr(idx, 'b', None), 128+ idx.c, idx.d) 129+ elif isinstance(idx, BoundIndex): 130+ indices[i] = intern_bound_index(getattr(idx, 'a', None), getattr(idx, 'b', None), 131+ idx.aMirror, idx.bMirror) 132+ 133+ # Update the lists with interned versions 134+ freeIndices = [idx for idx in indices if isinstance(idx, FreeIndex)] 135+ batchIndices = [idx for idx in indices if isinstance(idx, BatchIndex)] 136+ boundIndices = [idx for idx in indices if isinstance(idx, BoundIndex)] 137+ 138 for idx in indices: 139 assert idx is not None 140 idxState = state(idx) 141@@ -596,6 +666,7 @@ class SizeMapping: 142 'nonTemporalA', 143 'nonTemporalB', 144 ] 145+ __slots__ = StateKeys 146 147 @classmethod 148 def FromOriginalState(cls, d): 149@@ -751,7 +822,7 @@ class Solution: 150 info = cls.ReadOriginalInfo(d) 151 rv.libraryLogicIndex = int(info.get("SolutionIndex", -1)) 152 153- rv.sizeMapping = SizeMapping.FromOriginalState(d) 154+ rv.sizeMapping = intern_size_mapping(SizeMapping.FromOriginalState(d)) 155 156 rv.internalArgsSupport = InternalArgsSupport.FromOriginalState(d) 157 158diff --git a/tensilelite/Tensile/TensileCreateLibrary/Run.py b/tensilelite/Tensile/TensileCreateLibrary/Run.py 159index 730b6b1fff..b0068563a0 100644 160--- a/tensilelite/Tensile/TensileCreateLibrary/Run.py 161+++ b/tensilelite/Tensile/TensileCreateLibrary/Run.py 162@@ -104,7 +104,6 @@ class KernelCodeGenResult(NamedTuple): 163 src: str 164 header: Optional[str] 165 name: str 166- targetObjFilename: str 167 isa: IsaVersion 168 wavefrontSize: int 169 cuoccupancy: int 170@@ -127,10 +126,9 @@ def processKernelSource(kernelWriterAssembly, data, splitGSU, kernel) -> KernelC 171 asmFilename = getKernelFileBase(splitGSU, kernel) 172 err, src = kernelWriter.getSourceFileString(kernel) 173 header = kernelWriter.getHeaderFileString(kernel) 174- objFilename = kernel._state.get("codeObjectFile", None) 175 pgr = int(kernel["PrefetchGlobalRead"]) 176 return KernelCodeGenResult( 177- err, src, header, asmFilename, objFilename, tuple(kernel["ISA"]), \ 178+ err, src, header, asmFilename, tuple(kernel["ISA"]), \ 179 kernel["WavefrontSize"], kernel["CUOccupancy"], \ 180 pgr, kernel["MathClocksUnrolledLoop"] 181 )