A game about forced loneliness, made by TACStudios
at master 537 lines 26 kB view raw
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Burst.Compiler.IL.Syntax;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Cecil.Rocks;

namespace zzzUnity.Burst.CodeGen
{
    /// <summary>
    /// Transforms a direct invoke on a burst function pointer into a calli, avoiding the need to marshal the delegate back.
    /// </summary>
    /// <remarks>
    /// Besides the calli rewrite (gated by <see cref="enableCalliOptimisation"/>), the pass also:
    /// <list type="bullet">
    /// <item>appends <c>[AOT.MonoPInvokeCallback(typeof(TDelegate))]</c> to methods passed to
    /// <c>BurstCompiler.CompileFunctionPointer</c> (gated by <see cref="enableInvokeAttribute"/>), and</item>
    /// <item>appends <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to delegate types used with Burst
    /// function pointers (gated by <see cref="enableUnmangedFunctionPointerInject"/>).</item>
    /// </list>
    /// Work is split into a collection phase (<see cref="CollectDelegateInvokesFromType"/>) and an apply phase
    /// (<see cref="Finish"/>); <see cref="Run"/> performs both for a whole assembly.
    /// </remarks>
    internal class FunctionPointerInvokeTransform
    {
        /// <summary>A delegate invoke site recorded for the calli rewrite.</summary>
        private struct CaptureInformation
        {
            /// <summary>
            /// Method reference called by the captured <c>callvirt</c>; a method declared on the function pointer's
            /// delegate type (presumably <c>Invoke</c>).
            /// </summary>
            public MethodReference Operand;

            /// <summary>
            /// Instructions to rewrite: [0] is the <c>FunctionPointer&lt;T&gt;.get_Invoke</c> call, [1] is the
            /// <c>callvirt</c> on the delegate.
            /// </summary>
            public List<Instruction> Captured;
        }

        // Delegate types seen with CompileFunctionPointer / FunctionPointer<T>.get_Invoke, mapped to the first
        // method + instruction that referenced them (used as the location for error reports).
        private readonly Dictionary<TypeReference, (MethodDefinition method, Instruction instruction)> _needsNativeFunctionPointer;

        // Methods handed to BurstCompiler.CompileFunctionPointer that still need [MonoPInvokeCallback], mapped to the
        // delegate type taken from the preceding newobj (may be null when it could not be determined).
        private readonly Dictionary<MethodDefinition, TypeReference> _needsIl2cppInvoke;

        // Per-method list of get_Invoke/Invoke instruction pairs eligible for the calli rewrite.
        private readonly Dictionary<MethodDefinition, List<CaptureInformation>> _capturedSets;

        private MethodDefinition _monoPInvokeAttributeCtorDef;
        private MethodDefinition _nativePInvokeAttributeCtorDef;
        private MethodDefinition _unmanagedFunctionPointerAttributeCtorDef;
        private TypeReference _burstFunctionPointerType;
        private TypeReference _burstCompilerType;
        private TypeReference _systemType;
        private TypeReference _callingConventionType;

        private readonly LogDelegate _debugLog;
        private readonly int _logLevel;

        private readonly AssemblyResolver _loader;

        private readonly ErrorDiagnosticDelegate _errorReport;

        /// <summary>When true, append <c>[MonoPInvokeCallback]</c> to methods passed to <c>BurstCompiler.CompileFunctionPointer</c>.</summary>
        public static readonly bool enableInvokeAttribute = true;

        /// <summary>When true, rewrite <c>FunctionPointer&lt;T&gt;.Invoke(...)</c> call sites into a <c>calli</c>.</summary>
        public static readonly bool enableCalliOptimisation = false; // For now only run the pass on dots player/tiny

        /// <summary>When true, append <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to delegate types used with Burst function pointers.</summary>
        public static readonly bool enableUnmangedFunctionPointerInject = true;

        /// <summary>
        /// Creates the transform. Cecil lookups are deferred until <see cref="Initialize"/> (called by <see cref="Run"/>).
        /// </summary>
        /// <param name="loader">Resolver used by <see cref="Run"/> to locate Unity.Burst, UnityEngine.CoreModule and the core library.</param>
        /// <param name="error">Callback used to report user-facing errors against a method/instruction.</param>
        /// <param name="log">Optional debug log sink.</param>
        /// <param name="logLevel">Debug verbosity; some messages are only emitted when this is above 1 or above 2.</param>
        public FunctionPointerInvokeTransform(AssemblyResolver loader,ErrorDiagnosticDelegate error, LogDelegate log = null, int logLevel = 0)
        {
            _loader = loader;

            _needsNativeFunctionPointer = new Dictionary<TypeReference, (MethodDefinition, Instruction)>();
            _needsIl2cppInvoke = new Dictionary<MethodDefinition, TypeReference>();
            _capturedSets = new Dictionary<MethodDefinition, List<CaptureInformation>>();
            _monoPInvokeAttributeCtorDef = null;
            _unmanagedFunctionPointerAttributeCtorDef = null;
            _nativePInvokeAttributeCtorDef = null; // Only present on DOTS_PLAYER
            _burstFunctionPointerType = null;
            _burstCompilerType = null;
            _systemType = null;
            _callingConventionType = null;
            _debugLog = log;
            _logLevel = logLevel;
            _errorReport = error;
        }

        /// <summary>Resolves an assembly by simple name through <paramref name="loader"/>.</summary>
        /// <returns>The resolved assembly, or null when it cannot be resolved.</returns>
        private AssemblyDefinition GetAsmDefinitionFromFile(AssemblyResolver loader, string assemblyName)
        {
            if (loader.TryResolve(AssemblyNameReference.Parse(assemblyName), out var result))
            {
                return result;
            }
            return null;
        }

        /// <summary>
        /// Resolves the Cecil types and constructors this pass needs: Burst's <c>FunctionPointer`1</c> and
        /// <c>BurstCompiler</c>; corlib's <c>System.Type</c>, <c>UnmanagedFunctionPointerAttribute</c> and
        /// <c>CallingConvention</c>; and <c>AOT.MonoPInvokeCallbackAttribute</c> from UnityEngine.CoreModule.
        /// The work is skipped once the MonoPInvokeCallback constructor has been found.
        /// </summary>
        /// <param name="loader">Resolver used to find the referenced assemblies.</param>
        /// <param name="assemblyDefinition">The assembly being processed (not used by the current implementation).</param>
        /// <param name="typeSystem">Type system of the processed module; used to locate its core library.</param>
        /// <remarks>
        /// If UnityEngine.CoreModule cannot be resolved, the MonoPInvokeCallback constructor stays null and the
        /// corresponding fixup is skipped. NOTE(review): an unresolvable Unity.Burst is not handled here and surfaces as
        /// a NullReferenceException.
        /// </remarks>
        public void Initialize(AssemblyResolver loader, AssemblyDefinition assemblyDefinition, TypeSystem typeSystem)
        {
            if (_monoPInvokeAttributeCtorDef == null)
            {
                var burstAssembly = GetAsmDefinitionFromFile(loader, "Unity.Burst");

                _burstFunctionPointerType = burstAssembly.MainModule.GetType("Unity.Burst.FunctionPointer`1");
                _burstCompilerType = burstAssembly.MainModule.GetType("Unity.Burst.BurstCompiler");

                var corLibrary = loader.Resolve(typeSystem.CoreLibrary as AssemblyNameReference);
                // If the corLibrary is a redirecting assembly, then the type isn't present in Types
                // and GetType() will therefore not find it, so instead we'll have to look it up in ExportedTypes
                Func<string, TypeDefinition> getCorLibTy = (name) =>
                {
                    return corLibrary.MainModule.GetType(name) ??
                           corLibrary.MainModule.ExportedTypes.FirstOrDefault(x => x.FullName == name)?.Resolve();
                };
                _systemType = getCorLibTy("System.Type"); // Only needed for MonoPInvokeCallback constructor in Unity

                if (enableUnmangedFunctionPointerInject)
                {
                    var unmanagedFunctionPointerAttribute = getCorLibTy("System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute");
                    _callingConventionType = getCorLibTy("System.Runtime.InteropServices.CallingConvention");
                    _unmanagedFunctionPointerAttributeCtorDef = unmanagedFunctionPointerAttribute.GetConstructors().Single(c => c.Parameters.Count == 1 && c.Parameters[0].ParameterType.MetadataType == _callingConventionType.MetadataType);
                }

                var asmDef = GetAsmDefinitionFromFile(loader, "UnityEngine.CoreModule");
                // bail if we can't find a reference, handled gracefully later
                if (asmDef == null)
                    return;

                var monoPInvokeAttribute = asmDef.MainModule.GetType("AOT.MonoPInvokeCallbackAttribute");
                _monoPInvokeAttributeCtorDef = monoPInvokeAttribute.GetConstructors().First();
            }
        }

        /// <summary>
        /// Runs the whole pass over <paramref name="assemblyDefinition"/>: initializes, collects from every type in
        /// the main module, then applies the collected changes via <see cref="Finish"/>.
        /// </summary>
        /// <returns>True if the assembly was modified.</returns>
        public bool Run(AssemblyDefinition assemblyDefinition)
        {
            Initialize(_loader, assemblyDefinition, assemblyDefinition.MainModule.TypeSystem);

            var types = assemblyDefinition.MainModule.GetTypes().ToArray();
            foreach (var type in types)
            {
                CollectDelegateInvokesFromType(type);
            }

            return Finish();
        }

        /// <summary>Collects function-pointer usages from every method with a body on <paramref name="type"/>.</summary>
        public void CollectDelegateInvokesFromType(TypeDefinition type)
        {
            foreach (var m in type.Methods)
            {
                if (m.HasBody)
                {
                    CollectDelegateInvokes(m);
                }
            }
        }

        /// <summary>
        /// Adds <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to every collected delegate type. If the
        /// attribute is already present with a non-Cdecl (or default) calling convention, an error is reported instead.
        /// </summary>
        /// <returns>True if any attribute was added.</returns>
        private bool ProcessUnmanagedAttributeFixups()
        {
            if (_unmanagedFunctionPointerAttributeCtorDef == null)
                return false;

            bool modified = false;

            foreach (var kp in _needsNativeFunctionPointer)
            {
                var delegateType = kp.Key;
                var instruction = kp.Value.instruction;
                var method = kp.Value.method;
                var delegateDef = delegateType.Resolve();

                var hasAttributeAlready = delegateDef.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == _unmanagedFunctionPointerAttributeCtorDef.DeclaringType.FullName);

                // If there is already an attribute present
                if (hasAttributeAlready != null)
                {
                    if (hasAttributeAlready.ConstructorArguments.Count == 1)
                    {
                        var cc = (System.Runtime.InteropServices.CallingConvention)hasAttributeAlready.ConstructorArguments[0].Value;
                        if (cc == System.Runtime.InteropServices.CallingConvention.Cdecl)
                        {
                            if (_logLevel > 2) _debugLog?.Invoke($"UnmanagedAttributeFixups Skipping appending unmanagedFunctionPointerAttribute as already present and calling convention matches");
                        }
                        else
                        {
                            // constructor with non cdecl calling convention
                            _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer(CallingConvention.{ Enum.GetName(typeof(System.Runtime.InteropServices.CallingConvention), cc) })]` please remove the attribute if you wish to use this function with Burst.");
                        }
                    }
                    else
                    {
                        // Empty constructor which defaults to Winapi which is incompatible
                        _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer]` please remove the attribute if you wish to use this function with Burst.");
                    }
                    continue;
                }

                var attribute = new CustomAttribute(delegateType.Module.ImportReference(_unmanagedFunctionPointerAttributeCtorDef));
                attribute.ConstructorArguments.Add(new CustomAttributeArgument(delegateType.Module.ImportReference(_callingConventionType), System.Runtime.InteropServices.CallingConvention.Cdecl));
                delegateDef.CustomAttributes.Add(attribute);
                modified = true;
            }

            return modified;
        }

        /// <summary>
        /// Adds <c>[AOT.MonoPInvokeCallback(typeof(TDelegate))]</c> to every collected implementation method whose
        /// delegate type is known. Skipped entirely when the attribute could not be resolved.
        /// </summary>
        /// <returns>True if any attribute was added.</returns>
        private bool ProcessIl2cppInvokeFixups()
        {
            if (_monoPInvokeAttributeCtorDef == null)
                return false;

            bool modified = false;
            foreach (var invokeNeeded in _needsIl2cppInvoke)
            {
                var declaringType = invokeNeeded.Value;
                var implementationMethod = invokeNeeded.Key;

                // Unity requires a type parameter for the attributecallback
                if (declaringType == null)
                {
                    _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing declaringType for {implementationMethod}");
                    continue;
                }

                var attribute = new CustomAttribute(implementationMethod.Module.ImportReference(_monoPInvokeAttributeCtorDef));
                attribute.ConstructorArguments.Add(new CustomAttributeArgument(implementationMethod.Module.ImportReference(_systemType), implementationMethod.Module.ImportReference(declaringType)));
                implementationMethod.CustomAttributes.Add(attribute);
                modified = true;

                if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Added InvokeCallbackAttribute to {implementationMethod}");
            }

            return modified;
        }

        /// <summary>
        /// Rewrites each captured <c>fp.Invoke(args)</c> sequence
        /// (<c>call FunctionPointer&lt;T&gt;::get_Invoke</c> ... <c>callvirt T::Invoke</c>) into
        /// <c>call FunctionPointer&lt;T&gt;::get_Value; stloc tmp</c> ... <c>ldloc tmp; calli cdecl</c>.
        /// </summary>
        /// <returns>True if any call site was rewritten.</returns>
        private bool ProcessFunctionPointerInvokes()
        {
            var madeChange = false;
            foreach (var capturedData in _capturedSets)
            {
                var latePatchMethod = capturedData.Key;
                var capturedList = capturedData.Value;

                latePatchMethod.Body.SimplifyMacros(); // De-optimise short branches, since we will end up inserting instructions

                foreach (var capturedInfo in capturedList)
                {
                    var captured = capturedInfo.Captured;
                    var operand = capturedInfo.Operand;

                    if (captured.Count != 2)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.Finish: expected 2 instructions - Unable to optimise this reference");
                        continue;
                    }

                    if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish:{Environment.NewLine} latePatchMethod:{latePatchMethod}{Environment.NewLine} captureList:{capturedList}{Environment.NewLine} capture0:{captured[0]}{Environment.NewLine} operand:{operand}");

                    var processor = latePatchMethod.Body.GetILProcessor();

                    // Build the unmanaged (cdecl) call site signature from the delegate method's signature.
                    var genericContext = GenericContext.From(operand, operand.DeclaringType);
                    CallSite callsite;
                    try
                    {
                        callsite = new CallSite(genericContext.Resolve(operand.ReturnType))
                        {
                            CallingConvention = MethodCallingConvention.C
                        };

                        for (int oo = 0; oo < operand.Parameters.Count; oo++)
                        {
                            var param = operand.Parameters[oo];
                            var ty = genericContext.Resolve(param.ParameterType);
                            callsite.Parameters.Add(new ParameterDefinition(param.Name, param.Attributes, ty));
                        }
                    }
                    catch (NullReferenceException)
                    {
                        // Resolution failure surfaces as an NRE from the generic-context helpers; skip this site.
                        _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Failed to resolve the generic context of `{operand}`");
                        continue;
                    }

                    // Make sure everything is in order before we make a change

                    var originalGetInvoke = captured[0];

                    if (originalGetInvoke.Operand is MethodReference mmr)
                    {
                        // get_Invoke was matched on an instance of FunctionPointer<T> during collection, so this should be
                        // a generic instance; skip rather than crash if it is not, or if it cannot be resolved.
                        var genericInstanceType = mmr.DeclaringType as GenericInstanceType;
                        if (genericInstanceType == null)
                        {
                            _debugLog?.Invoke($"FunctionPtrInvoke.Finish: expected a generic instance declaring type but found {mmr.DeclaringType} - Unable to optimise this reference");
                            continue;
                        }
                        var genericInstanceDef = genericInstanceType.Resolve();

                        // Locate the correct instance method - we know already at this point we have an instance of Function
                        var getValue = genericInstanceDef?.Methods.FirstOrDefault(m => m.Name == "get_Value");
                        if (getValue == null)
                        {
                            _debugLog?.Invoke($"FunctionPtrInvoke.Finish: failed to locate get_Value method on {genericInstanceDef} - Unable to optimise this reference");
                            continue;
                        }

                        // Re-create get_Value against the closed generic instance (FunctionPointer<TDelegate>).
                        var newGenericRef = new MethodReference(getValue.Name, getValue.ReturnType, genericInstanceType)
                        {
                            HasThis = getValue.HasThis,
                            ExplicitThis = getValue.ExplicitThis,
                            CallingConvention = getValue.CallingConvention
                        };
                        foreach (var param in getValue.Parameters)
                            newGenericRef.Parameters.Add(new ParameterDefinition(param.ParameterType));
                        foreach (var gparam in getValue.GenericParameters)
                            newGenericRef.GenericParameters.Add(new GenericParameter(gparam.Name, newGenericRef));
                        var importRef = latePatchMethod.Module.ImportReference(newGenericRef);
                        var newMethodCall = processor.Create(OpCodes.Call, importRef);

                        // Remember what we are replacing before the instruction is mutated in place.
                        var originalGetInvokeText = _logLevel > 1 ? originalGetInvoke.ToString() : null;

                        // Replace get_invoke with get_Value - Don't use replace though as if the original call is target of a branch
                        //the branch doesn't get updated.
                        originalGetInvoke.OpCode = newMethodCall.OpCode;
                        originalGetInvoke.Operand = newMethodCall.Operand;

                        // Add local to capture result. The return type belongs to the Unity.Burst module, so it must be
                        // imported into the module being patched.
                        var newLocal = new VariableDefinition(latePatchMethod.Module.ImportReference(getValue.ReturnType));
                        latePatchMethod.Body.Variables.Add(newLocal);

                        // Store result of get_Value
                        var storeInst = processor.Create(OpCodes.Stloc, newLocal);
                        processor.InsertAfter(originalGetInvoke, storeInst);

                        // Swap invoke with calli
                        var calli = processor.Create(OpCodes.Calli, callsite);
                        // We can use replace here, since we already checked this is in the same Basic Block, and thus can't be target of a branch
                        processor.Replace(captured[1], calli);

                        // Insert load local prior to calli
                        var loadValue = processor.Create(OpCodes.Ldloc, newLocal);
                        processor.InsertBefore(calli, loadValue);

                        if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Optimised {originalGetInvokeText} with {newMethodCall}");

                        madeChange = true;
                    }
                }

                latePatchMethod.Body.OptimizeMacros(); // Re-optimise branches
            }
            return madeChange;
        }

        /// <summary>
        /// Applies every enabled fixup to the usages collected so far, then forgets them so that a later
        /// <see cref="Run"/>/<see cref="Finish"/> on the same instance does not apply them a second time.
        /// </summary>
        /// <returns>True if any IL or metadata was modified.</returns>
        public bool Finish()
        {
            bool madeChange = false;

            if (enableInvokeAttribute)
            {
                madeChange |= ProcessIl2cppInvokeFixups();
            }

            if (enableUnmangedFunctionPointerInject)
            {
                madeChange |= ProcessUnmanagedAttributeFixups();
            }

            if (enableCalliOptimisation)
            {
                madeChange |= ProcessFunctionPointerInvokes();
            }

            // Re-applying would duplicate [MonoPInvokeCallback] attributes and re-patch already rewritten IL.
            _needsNativeFunctionPointer.Clear();
            _needsIl2cppInvoke.Clear();
            _capturedSets.Clear();

            return madeChange;
        }

        /// <summary>
        /// True when <paramref name="methodRef"/> is named <paramref name="method"/> and declared on a generic
        /// instance of Burst's <c>FunctionPointer&lt;T&gt;</c>.
        /// </summary>
        /// <param name="methodInstance">The declaring generic instance type, or null when it is not generic.</param>
        private bool IsBurstFunctionPointerMethod(MethodReference methodRef, string method, out GenericInstanceType methodInstance)
        {
            methodInstance = methodRef?.DeclaringType as GenericInstanceType;
            return (methodInstance != null && methodInstance.ElementType.FullName == _burstFunctionPointerType.FullName && methodRef.Name == method);
        }

        /// <summary>True when <paramref name="methodRef"/> is the generic method <paramref name="method"/> on <c>Unity.Burst.BurstCompiler</c>.</summary>
        private bool IsBurstCompilerMethod(GenericInstanceMethod methodRef, string method)
        {
            var declaringType = methodRef?.DeclaringType;
            return (declaringType != null && declaringType.FullName == _burstCompilerType.FullName && methodRef.Name == method);
        }

        /// <summary>
        /// Inspects one instruction of <paramref name="m"/> for <c>BurstCompiler.CompileFunctionPointer&lt;T&gt;</c>
        /// or <c>FunctionPointer&lt;T&gt;.get_Invoke</c> calls and records the delegate type and, for
        /// CompileFunctionPointer, the implementation method that needs <c>[MonoPInvokeCallback]</c>.
        /// </summary>
        private void LocateFunctionPointerTCreation(MethodDefinition m, Instruction i)
        {
            if (i.OpCode == OpCodes.Call)
            {
                var genInstMethod = i.Operand as GenericInstanceMethod;

                var isBurstCompilerCompileFunctionPointer = IsBurstCompilerMethod(genInstMethod, "CompileFunctionPointer");
                var isBurstFunctionPointerGetInvoke = IsBurstFunctionPointerMethod(i.Operand as MethodReference, "get_Invoke", out var instanceType);
                if (!(isBurstCompilerCompileFunctionPointer || isBurstFunctionPointerGetInvoke)) return;

                if (enableUnmangedFunctionPointerInject)
                {
                    var delegateType = isBurstCompilerCompileFunctionPointer ? genInstMethod.GenericArguments[0].Resolve() : instanceType.GenericArguments[0].Resolve();
                    // We check for null, since unfortunately it is possible that the call is wrapped inside
                    //another open delegate and we cannot determine the delegate type
                    if (delegateType != null && !_needsNativeFunctionPointer.ContainsKey(delegateType))
                    {
                        _needsNativeFunctionPointer.Add(delegateType, (m, i));
                    }
                }

                // No need to process further if its not a CompileFunctionPointer method
                if (!isBurstCompilerCompileFunctionPointer) return;

                if (enableInvokeAttribute)
                {
                    // Currently only handles the following pre-pattern (which should cover most common uses)
                    // ldftn ...
                    // newobj ...

                    if (i.Previous?.OpCode != OpCodes.Newobj)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding NewObj {i.Previous}");
                        return;
                    }

                    var newObj = i.Previous;
                    if (newObj.Previous?.OpCode != OpCodes.Ldftn)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding LdFtn {newObj.Previous}");
                        return;
                    }

                    var ldFtn = newObj.Previous;

                    // Determine the delegate type
                    var methodDefinition = newObj.Operand as MethodDefinition;
                    var declaringType = methodDefinition?.DeclaringType;

                    // Fetch the implementation method
                    var implementationMethod = ldFtn.Operand as MethodDefinition;

                    // Either attribute ctor may be unresolved (e.g. UnityEngine.CoreModule not found in Initialize),
                    // so guard both before dereferencing.
                    var hasInvokeAlready = implementationMethod?.CustomAttributes.FirstOrDefault(x =>
                        (_monoPInvokeAttributeCtorDef != null && x.AttributeType.FullName == _monoPInvokeAttributeCtorDef.DeclaringType.FullName)
                        || (_nativePInvokeAttributeCtorDef != null && x.AttributeType.FullName == _nativePInvokeAttributeCtorDef.DeclaringType.FullName));

                    if (hasInvokeAlready != null)
                    {
                        if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Skipping appending Callback Attribute as already present {hasInvokeAlready}");
                        return;
                    }

                    if (implementationMethod == null)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing method from {ldFtn} {ldFtn.Operand}");
                        return;
                    }

                    if (implementationMethod.CustomAttributes.FirstOrDefault(x => x.Constructor.DeclaringType.Name == "BurstCompileAttribute") == null)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing burst attribute from {implementationMethod}");
                        return;
                    }

                    // Need to add the custom attribute
                    if (!_needsIl2cppInvoke.ContainsKey(implementationMethod))
                    {
                        _needsIl2cppInvoke.Add(implementationMethod, declaringType);
                    }
                }
            }
        }

        [Obsolete("Will be removed in a future Burst version")]
        public bool IsInstructionForFunctionPointerInvoke(MethodDefinition m, Instruction i)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Scans the body of <paramref name="m"/>: records function-pointer creations/uses for the attribute fixups
        /// and, when the calli optimisation is enabled, get_Invoke ... delegate callvirt pairs that occur within a
        /// single basic block.
        /// </summary>
        private void CollectDelegateInvokes(MethodDefinition m)
        {
            if (!(enableCalliOptimisation || enableInvokeAttribute || enableUnmangedFunctionPointerInject))
                return;

            bool hitGetInvoke = false;
            TypeDefinition delegateType = null;
            List<Instruction> captured = null;

            foreach (var inst in m.Body.Instructions)
            {
                if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: CurrentInstruction {inst} {inst.Operand}");

                // Check for a FunctionPointerT creation
                if (enableUnmangedFunctionPointerInject || enableInvokeAttribute)
                {
                    LocateFunctionPointerTCreation(m, inst);
                }

                if (enableCalliOptimisation)
                {
                    if (!hitGetInvoke)
                    {
                        if (inst.OpCode != OpCodes.Call) continue;
                        if (!IsBurstFunctionPointerMethod(inst.Operand as MethodReference, "get_Invoke", out var methodInstance)) continue;

                        // At this point we have a call to a FunctionPointer.Invoke
                        hitGetInvoke = true;

                        delegateType = methodInstance.GenericArguments[0].Resolve();

                        captured = new List<Instruction>();

                        captured.Add(inst); // Capture the get_invoke, we will swap this for get_value and a store to local
                    }
                    else
                    {
                        if (!(inst.OpCode.FlowControl == FlowControl.Next || inst.OpCode.FlowControl == FlowControl.Call))
                        {
                            // Don't perform transform across blocks
                            hitGetInvoke = false;
                        }
                        else
                        {
                            if (inst.OpCode == OpCodes.Callvirt)
                            {
                                if (inst.Operand is MethodReference mref)
                                {
                                    // Resolve can fail (e.g. unresolvable reference); such a call can't be the invoke we want.
                                    var method = mref.Resolve();

                                    if (method != null && method.DeclaringType == delegateType)
                                    {
                                        hitGetInvoke = false;

                                        if (!_capturedSets.TryGetValue(m, out var storage))
                                        {
                                            storage = new List<CaptureInformation>();
                                            _capturedSets.Add(m, storage);
                                        }

                                        // Capture the invoke - which we will swap for a load local (stored from the get_value) and a calli
                                        captured.Add(inst);
                                        var captureInfo = new CaptureInformation { Captured = captured, Operand = mref };
                                        if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: captureInfo:{captureInfo}{Environment.NewLine}capture0{captured[0]}");
                                        storage.Add(captureInfo);
                                    }
                                }
                                else
                                {
                                    hitGetInvoke = false;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}