A game about forced loneliness, made by TACStudios
1using System;
2using System.Collections.Generic;
3using System.IO;
4using System.Linq;
5using Burst.Compiler.IL.Syntax;
6using Mono.Cecil;
7using Mono.Cecil.Cil;
8using Mono.Cecil.Rocks;
9
namespace zzzUnity.Burst.CodeGen
{
    /// <summary>
    /// IL post-processing pass around Burst function pointers. Depending on the static switches it:
    /// adds <c>[AOT.MonoPInvokeCallback(typeof(TDelegate))]</c> to <c>[BurstCompile]</c> methods passed to
    /// <c>BurstCompiler.CompileFunctionPointer</c> (<see cref="enableInvokeAttribute"/>);
    /// adds <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to delegate types used with Burst
    /// function pointers (<see cref="enableUnmangedFunctionPointerInject"/>); and transforms a direct invoke on a
    /// Burst function pointer into a calli, avoiding the need to marshal the delegate back
    /// (<see cref="enableCalliOptimisation"/>).
    /// </summary>
    /// <remarks>
    /// Either call <see cref="Run"/> for a whole assembly, or call <see cref="Initialize"/>, then
    /// <see cref="CollectDelegateInvokesFromType"/> for each type, then <see cref="Finish"/>.
    /// <see cref="Finish"/> clears the collected state, so an instance can be reused for several assemblies.
    /// </remarks>
    internal class FunctionPointerInvokeTransform
    {
        /// <summary>
        /// A <c>FunctionPointer&lt;T&gt;.get_Invoke</c> call paired with a later callvirt on the delegate type,
        /// both found in the same basic block.
        /// </summary>
        private struct CaptureInformation
        {
            /// <summary>The delegate method called by the captured callvirt (presumably <c>Invoke</c>).</summary>
            public MethodReference Operand;

            /// <summary>[0] is the <c>get_Invoke</c> call, [1] is the callvirt on the delegate.</summary>
            public List<Instruction> Captured;
        }

        // Delegate type -> first method/instruction where it was used with a Burst function pointer (used for error reporting).
        private readonly Dictionary<TypeReference, (MethodDefinition method, Instruction instruction)> _needsNativeFunctionPointer;
        // [BurstCompile] implementation method -> delegate type passed to the MonoPInvokeCallback attribute (may be null).
        private readonly Dictionary<MethodDefinition, TypeReference> _needsIl2cppInvoke;
        // Method -> get_Invoke/callvirt pairs the calli optimisation may rewrite.
        private readonly Dictionary<MethodDefinition, List<CaptureInformation>> _capturedSets;
        private MethodDefinition _monoPInvokeAttributeCtorDef;
        private MethodDefinition _nativePInvokeAttributeCtorDef; // Only present on DOTS_PLAYER; never assigned in this class
        private MethodDefinition _unmanagedFunctionPointerAttributeCtorDef;
        private TypeReference _burstFunctionPointerType;
        private TypeReference _burstCompilerType;
        private TypeReference _systemType;
        private TypeReference _callingConventionType;

        private readonly LogDelegate _debugLog;
        private readonly int _logLevel;

        private readonly AssemblyResolver _loader;

        private readonly ErrorDiagnosticDelegate _errorReport;

        public static readonly bool enableInvokeAttribute = true;
        public static readonly bool enableCalliOptimisation = false; // For now only run the pass on dots player/tiny
        public static readonly bool enableUnmangedFunctionPointerInject = true;

        /// <summary>Creates the transform; type resolution is deferred to <see cref="Initialize"/>.</summary>
        /// <param name="loader">Resolver used to locate Unity.Burst, UnityEngine.CoreModule and the core library.</param>
        /// <param name="error">Callback used to report user-facing errors (e.g. incompatible calling conventions).</param>
        /// <param name="log">Optional debug log sink.</param>
        /// <param name="logLevel">Verbosity; messages are emitted at levels &gt; 1 and &gt; 2.</param>
        public FunctionPointerInvokeTransform(AssemblyResolver loader, ErrorDiagnosticDelegate error, LogDelegate log = null, int logLevel = 0)
        {
            _loader = loader;

            _needsNativeFunctionPointer = new Dictionary<TypeReference, (MethodDefinition, Instruction)>();
            _needsIl2cppInvoke = new Dictionary<MethodDefinition, TypeReference>();
            _capturedSets = new Dictionary<MethodDefinition, List<CaptureInformation>>();
            _monoPInvokeAttributeCtorDef = null;
            _unmanagedFunctionPointerAttributeCtorDef = null;
            _nativePInvokeAttributeCtorDef = null; // Only present on DOTS_PLAYER
            _burstFunctionPointerType = null;
            _burstCompilerType = null;
            _systemType = null;
            _callingConventionType = null;
            _debugLog = log;
            _logLevel = logLevel;
            _errorReport = error;
        }

        /// <summary>Resolves <paramref name="assemblyName"/> through <paramref name="loader"/>; returns null if it cannot be resolved.</summary>
        private AssemblyDefinition GetAsmDefinitionFromFile(AssemblyResolver loader, string assemblyName)
        {
            if (loader.TryResolve(AssemblyNameReference.Parse(assemblyName), out var result))
            {
                return result;
            }
            return null;
        }

        /// <summary>
        /// Resolves the Burst, core library and UnityEngine types this pass needs. The work is repeated
        /// on each call until the MonoPInvokeCallback constructor has been found.
        /// </summary>
        /// <remarks>
        /// Any type that cannot be resolved leaves the corresponding feature disabled instead of throwing.
        /// </remarks>
        /// <param name="loader">Resolver used to locate the referenced assemblies.</param>
        /// <param name="assemblyDefinition">The assembly being processed (not used by the resolution itself).</param>
        /// <param name="typeSystem">Type system whose core library is searched for System types.</param>
        public void Initialize(AssemblyResolver loader, AssemblyDefinition assemblyDefinition, TypeSystem typeSystem)
        {
            if (_monoPInvokeAttributeCtorDef != null)
                return;

            var burstAssembly = GetAsmDefinitionFromFile(loader, "Unity.Burst");
            if (burstAssembly == null)
            {
                _debugLog?.Invoke("FunctionPtrInvoke.Initialize: Unable to resolve Unity.Burst - function pointer transforms are disabled");
                return;
            }

            _burstFunctionPointerType = burstAssembly.MainModule.GetType("Unity.Burst.FunctionPointer`1");
            _burstCompilerType = burstAssembly.MainModule.GetType("Unity.Burst.BurstCompiler");

            var corLibrary = loader.Resolve(typeSystem.CoreLibrary as AssemblyNameReference);

            // If the corLibrary is a redirecting assembly, then the type isn't present in Types
            // and GetType() will therefore not find it, so instead we'll have to look it up in ExportedTypes
            TypeDefinition GetCorLibType(string name) =>
                corLibrary?.MainModule.GetType(name) ??
                corLibrary?.MainModule.ExportedTypes.FirstOrDefault(x => x.FullName == name)?.Resolve();

            _systemType = GetCorLibType("System.Type"); // Only needed for MonoPInvokeCallback constructor in Unity

            if (enableUnmangedFunctionPointerInject)
            {
                var unmanagedFunctionPointerAttribute = GetCorLibType("System.Runtime.InteropServices.UnmanagedFunctionPointerAttribute");
                _callingConventionType = GetCorLibType("System.Runtime.InteropServices.CallingConvention");
                if (unmanagedFunctionPointerAttribute != null && _callingConventionType != null)
                {
                    // Match the (CallingConvention) constructor by full type name: comparing MetadataType alone
                    // would accept any constructor taking a single value-type parameter.
                    _unmanagedFunctionPointerAttributeCtorDef = unmanagedFunctionPointerAttribute.GetConstructors().FirstOrDefault(c =>
                        c.Parameters.Count == 1 && c.Parameters[0].ParameterType.FullName == _callingConventionType.FullName);
                }
            }

            var asmDef = GetAsmDefinitionFromFile(loader, "UnityEngine.CoreModule");
            // bail if we can't find a reference, handled gracefully later
            if (asmDef == null)
                return;

            var monoPInvokeAttribute = asmDef.MainModule.GetType("AOT.MonoPInvokeCallbackAttribute");
            _monoPInvokeAttributeCtorDef = monoPInvokeAttribute?.GetConstructors().FirstOrDefault();
        }

        /// <summary>Runs the whole pass over <paramref name="assemblyDefinition"/>.</summary>
        /// <returns>True if the assembly (or a definition it resolved) was modified.</returns>
        public bool Run(AssemblyDefinition assemblyDefinition)
        {
            Initialize(_loader, assemblyDefinition, assemblyDefinition.MainModule.TypeSystem);

            // Snapshot the types first, since the fixups may add members while we walk them.
            var types = assemblyDefinition.MainModule.GetTypes().ToArray();
            foreach (var type in types)
            {
                CollectDelegateInvokesFromType(type);
            }

            return Finish();
        }

        /// <summary>Collects function pointer usages from every method with a body in <paramref name="type"/>.</summary>
        public void CollectDelegateInvokesFromType(TypeDefinition type)
        {
            foreach (var m in type.Methods)
            {
                if (m.HasBody)
                {
                    CollectDelegateInvokes(m);
                }
            }
        }

        /// <summary>
        /// Adds <c>[UnmanagedFunctionPointer(CallingConvention.Cdecl)]</c> to each collected delegate type. If the
        /// attribute is already present with a different calling convention (or with the parameterless
        /// constructor), an error is reported instead.
        /// </summary>
        /// <returns>True if at least one attribute was added.</returns>
        private bool ProcessUnmanagedAttributeFixups()
        {
            if (_unmanagedFunctionPointerAttributeCtorDef == null)
                return false;

            bool modified = false;

            foreach (var kp in _needsNativeFunctionPointer)
            {
                var delegateType = kp.Key;
                var instruction = kp.Value.instruction;
                var method = kp.Value.method;
                var delegateDef = delegateType.Resolve();

                var hasAttributeAlready = delegateDef.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == _unmanagedFunctionPointerAttributeCtorDef.DeclaringType.FullName);

                if (hasAttributeAlready == null)
                {
                    var attribute = new CustomAttribute(delegateType.Module.ImportReference(_unmanagedFunctionPointerAttributeCtorDef));
                    attribute.ConstructorArguments.Add(new CustomAttributeArgument(delegateType.Module.ImportReference(_callingConventionType), System.Runtime.InteropServices.CallingConvention.Cdecl));
                    delegateDef.CustomAttributes.Add(attribute);
                    modified = true;
                    continue;
                }

                if (hasAttributeAlready.ConstructorArguments.Count == 1)
                {
                    var cc = (System.Runtime.InteropServices.CallingConvention)hasAttributeAlready.ConstructorArguments[0].Value;
                    if (cc == System.Runtime.InteropServices.CallingConvention.Cdecl)
                    {
                        if (_logLevel > 2) _debugLog?.Invoke($"UnmanagedAttributeFixups Skipping appending unmanagedFunctionPointerAttribute as already present aand calling convention matches");
                    }
                    else
                    {
                        // constructor with non cdecl calling convention
                        _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer(CallingConvention.{ Enum.GetName(typeof(System.Runtime.InteropServices.CallingConvention), cc) })]` please remove the attribute if you wish to use this function with Burst.");
                    }
                }
                else
                {
                    // Empty constructor which defaults to Winapi, which is incompatible
                    _errorReport(method, instruction, $"BurstCompiler.CompileFunctionPointer is only compatible with cdecl calling convention, this delegate type already has `[UnmanagedFunctionPointer]` please remove the attribute if you wish to use this function with Burst.");
                }
            }

            return modified;
        }

        /// <summary>
        /// Adds <c>[AOT.MonoPInvokeCallback(typeof(TDelegate))]</c> to each collected implementation method.
        /// </summary>
        /// <returns>True if at least one attribute was added.</returns>
        private bool ProcessIl2cppInvokeFixups()
        {
            // Both the attribute constructor and System.Type (its parameter type) are required.
            if (_monoPInvokeAttributeCtorDef == null || _systemType == null)
                return false;

            bool modified = false;
            foreach (var invokeNeeded in _needsIl2cppInvoke)
            {
                var declaringType = invokeNeeded.Value;
                var implementationMethod = invokeNeeded.Key;

                // Unity requires a type parameter for the attributecallback
                if (declaringType == null)
                {
                    _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing declaringType for {implementationMethod}");
                    continue;
                }

                var module = implementationMethod.Module;
                var attribute = new CustomAttribute(module.ImportReference(_monoPInvokeAttributeCtorDef));
                attribute.ConstructorArguments.Add(new CustomAttributeArgument(module.ImportReference(_systemType), module.ImportReference(declaringType)));
                implementationMethod.CustomAttributes.Add(attribute);
                modified = true;

                if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Added InvokeCallbackAttribute to {implementationMethod}");
            }

            return modified;
        }

        /// <summary>
        /// Rewrites each captured <c>get_Invoke</c> + delegate callvirt pair into
        /// <c>get_Value</c> → stloc, ..., ldloc → calli (cdecl).
        /// </summary>
        /// <returns>True if at least one call site was rewritten.</returns>
        private bool ProcessFunctionPointerInvokes()
        {
            var madeChange = false;
            foreach (var capturedData in _capturedSets)
            {
                var latePatchMethod = capturedData.Key;
                var capturedList = capturedData.Value;
                var module = latePatchMethod.Module;

                latePatchMethod.Body.SimplifyMacros(); // De-optimise short branches, since we will end up inserting instructions

                foreach (var capturedInfo in capturedList)
                {
                    var captured = capturedInfo.Captured;
                    var operand = capturedInfo.Operand;

                    if (captured.Count != 2)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.Finish: expected 2 instructions - Unable to optimise this reference");
                        continue;
                    }

                    if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish:{Environment.NewLine} latePatchMethod:{latePatchMethod}{Environment.NewLine} captureList:{capturedList}{Environment.NewLine} capture0:{captured[0]}{Environment.NewLine} operand:{operand}");

                    var processor = latePatchMethod.Body.GetILProcessor();

                    var genericContext = GenericContext.From(operand, operand.DeclaringType);
                    CallSite callsite;
                    try
                    {
                        // Types in the call-site signature must belong to the patched module; Cecil will not
                        // write references to types owned by another module.
                        callsite = new CallSite(module.ImportReference(genericContext.Resolve(operand.ReturnType), latePatchMethod))
                        {
                            CallingConvention = MethodCallingConvention.C
                        };

                        foreach (var param in operand.Parameters)
                        {
                            var ty = module.ImportReference(genericContext.Resolve(param.ParameterType), latePatchMethod);
                            callsite.Parameters.Add(new ParameterDefinition(param.Name, param.Attributes, ty));
                        }
                    }
                    catch (Exception e) when (e is NullReferenceException || e is InvalidOperationException)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Failed to resolve the generic context of `{operand}`");
                        continue;
                    }

                    // Make sure everything is in order before we make a change
                    var originalGetInvoke = captured[0];
                    if (!(originalGetInvoke.Operand is MethodReference mmr))
                        continue;

                    var genericInstanceType = mmr.DeclaringType as GenericInstanceType;
                    var genericInstanceDef = genericInstanceType?.Resolve();

                    // Locate the get_Value method - we already know at this point we have an instance of FunctionPointer<T>
                    var mr = genericInstanceDef?.Methods.FirstOrDefault(x => x.Name == "get_Value");
                    if (mr == null)
                    {
                        _debugLog?.Invoke($"FunctionPtrInvoke.Finish: failed to locate get_Value method on {genericInstanceDef} - Unable to optimise this reference");
                        continue;
                    }

                    // Rebuild get_Value against the closed generic instance so the call targets FunctionPointer<TDelegate>.
                    var newGenericRef = new MethodReference(mr.Name, mr.ReturnType, genericInstanceType)
                    {
                        HasThis = mr.HasThis,
                        ExplicitThis = mr.ExplicitThis,
                        CallingConvention = mr.CallingConvention
                    };
                    foreach (var param in mr.Parameters)
                        newGenericRef.Parameters.Add(new ParameterDefinition(param.ParameterType));
                    foreach (var gparam in mr.GenericParameters)
                        newGenericRef.GenericParameters.Add(new GenericParameter(gparam.Name, newGenericRef));
                    var importRef = module.ImportReference(newGenericRef);
                    var newMethodCall = processor.Create(OpCodes.Call, importRef);

                    // Replace get_invoke with get_Value - Don't use replace though as if the original call is target of a branch
                    // the branch doesn't get updated.
                    originalGetInvoke.OpCode = newMethodCall.OpCode;
                    originalGetInvoke.Operand = newMethodCall.Operand;

                    // Add local to capture result (imported, since mr belongs to the Burst module)
                    var newLocal = new VariableDefinition(module.ImportReference(mr.ReturnType));
                    latePatchMethod.Body.Variables.Add(newLocal);

                    // Store result of get_Value
                    var storeInst = processor.Create(OpCodes.Stloc, newLocal);
                    processor.InsertAfter(originalGetInvoke, storeInst);

                    // Swap invoke with calli
                    var calli = processor.Create(OpCodes.Calli, callsite);
                    // We can use replace here, since we already checked this is in the same Basic Block, and thus can't be target of a branch
                    processor.Replace(captured[1], calli);

                    // Insert load local prior to calli
                    var loadValue = processor.Create(OpCodes.Ldloc, newLocal);
                    processor.InsertBefore(calli, loadValue);

                    if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.Finish: Optimised {originalGetInvoke} with {newMethodCall}");

                    madeChange = true;
                }

                latePatchMethod.Body.OptimizeMacros(); // Re-optimise branches
            }
            return madeChange;
        }

        /// <summary>
        /// Applies every enabled fixup to the usages collected so far, then clears the collected state.
        /// </summary>
        /// <returns>True if any definition was modified.</returns>
        public bool Finish()
        {
            bool madeChange = false;

            if (enableInvokeAttribute)
            {
                madeChange |= ProcessIl2cppInvokeFixups();
            }

            if (enableUnmangedFunctionPointerInject)
            {
                madeChange |= ProcessUnmanagedAttributeFixups();
            }

            if (enableCalliOptimisation)
            {
                madeChange |= ProcessFunctionPointerInvokes();
            }

            // Reset per-run state: without this, reusing the instance for another assembly re-applies the
            // previous run's fixups (e.g. duplicate MonoPInvokeCallback attributes) and misreports changes.
            _needsNativeFunctionPointer.Clear();
            _needsIl2cppInvoke.Clear();
            _capturedSets.Clear();

            return madeChange;
        }

        /// <summary>
        /// True when <paramref name="methodRef"/> is the method named <paramref name="method"/> on a generic instance
        /// of <c>Unity.Burst.FunctionPointer`1</c>. <paramref name="methodInstance"/> receives the declaring type
        /// whenever it is a generic instance.
        /// </summary>
        private bool IsBurstFunctionPointerMethod(MethodReference methodRef, string method, out GenericInstanceType methodInstance)
        {
            methodInstance = methodRef?.DeclaringType as GenericInstanceType;
            return methodInstance != null
                && _burstFunctionPointerType != null
                && methodInstance.ElementType.FullName == _burstFunctionPointerType.FullName
                && methodRef.Name == method;
        }

        /// <summary>True when <paramref name="methodRef"/> is the generic method named <paramref name="method"/> on <c>Unity.Burst.BurstCompiler</c>.</summary>
        private bool IsBurstCompilerMethod(GenericInstanceMethod methodRef, string method)
        {
            var declaringType = methodRef?.DeclaringType;
            return declaringType != null
                && _burstCompilerType != null
                && declaringType.FullName == _burstCompilerType.FullName
                && methodRef.Name == method;
        }

        /// <summary>
        /// Inspects one instruction of <paramref name="m"/>. It records the delegate type of
        /// <c>BurstCompiler.CompileFunctionPointer&lt;T&gt;</c> and <c>FunctionPointer&lt;T&gt;.get_Invoke</c> calls.
        /// For CompileFunctionPointer calls it also records the <c>[BurstCompile]</c> implementation method that
        /// needs a MonoPInvokeCallback attribute.
        /// </summary>
        private void LocateFunctionPointerTCreation(MethodDefinition m, Instruction i)
        {
            if (i.OpCode != OpCodes.Call)
                return;

            var genInstMethod = i.Operand as GenericInstanceMethod;

            var isBurstCompilerCompileFunctionPointer = IsBurstCompilerMethod(genInstMethod, "CompileFunctionPointer");
            var isBurstFunctionPointerGetInvoke = IsBurstFunctionPointerMethod(i.Operand as MethodReference, "get_Invoke", out var instanceType);
            if (!(isBurstCompilerCompileFunctionPointer || isBurstFunctionPointerGetInvoke))
                return;

            if (enableUnmangedFunctionPointerInject)
            {
                var delegateType = isBurstCompilerCompileFunctionPointer ? genInstMethod.GenericArguments[0].Resolve() : instanceType.GenericArguments[0].Resolve();
                // We check for null, since unfortunately it is possible that the call is wrapped inside
                // another open delegate and we cannot determine the delegate type
                if (delegateType != null && !_needsNativeFunctionPointer.ContainsKey(delegateType))
                {
                    _needsNativeFunctionPointer.Add(delegateType, (m, i));
                }
            }

            // No need to process further if its not a CompileFunctionPointer method
            if (!isBurstCompilerCompileFunctionPointer || !enableInvokeAttribute)
                return;

            // Initialize leaves this null when UnityEngine.CoreModule is unavailable. The attribute could not be
            // added later anyway, and the duplicate-attribute check below needs its declaring type.
            if (_monoPInvokeAttributeCtorDef == null)
            {
                _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute as AOT.MonoPInvokeCallbackAttribute could not be located");
                return;
            }

            // Currently only handles the following pre-pattern (which should cover most common uses)
            // ldftn ...
            // newobj ...
            if (i.Previous?.OpCode != OpCodes.Newobj)
            {
                _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding NewObj {i.Previous}");
                return;
            }

            var newObj = i.Previous;
            if (newObj.Previous?.OpCode != OpCodes.Ldftn)
            {
                _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to not finding LdFtn {newObj.Previous}");
                return;
            }

            var ldFtn = newObj.Previous;

            // Determine the delegate type from the constructor invoked by newobj (null if not defined in this module)
            var methodDefinition = newObj.Operand as MethodDefinition;
            var declaringType = methodDefinition?.DeclaringType;

            // Fetch the implementation method
            var implementationMethod = ldFtn.Operand as MethodDefinition;
            if (implementationMethod == null)
            {
                _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing method from {ldFtn} {ldFtn.Operand}");
                return;
            }

            var hasInvokeAlready = implementationMethod.CustomAttributes.FirstOrDefault(x =>
                (x.AttributeType.FullName == _monoPInvokeAttributeCtorDef.DeclaringType.FullName)
                || (_nativePInvokeAttributeCtorDef != null && x.AttributeType.FullName == _nativePInvokeAttributeCtorDef.DeclaringType.FullName));

            if (hasInvokeAlready != null)
            {
                if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Skipping appending Callback Attribute as already present {hasInvokeAlready}");
                return;
            }

            if (!implementationMethod.CustomAttributes.Any(x => x.Constructor.DeclaringType.Name == "BurstCompileAttribute"))
            {
                _debugLog?.Invoke($"FunctionPtrInvoke.LocateFunctionPointerTCreation: Unable to automatically append CallbackAttribute due to missing burst attribute from {implementationMethod}");
                return;
            }

            // Need to add the custom attribute
            if (!_needsIl2cppInvoke.ContainsKey(implementationMethod))
            {
                _needsIl2cppInvoke.Add(implementationMethod, declaringType);
            }
        }

        [Obsolete("Will be removed in a future Burst verison")]
        public bool IsInstructionForFunctionPointerInvoke(MethodDefinition m, Instruction i)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Walks the instructions of <paramref name="m"/>. It records function pointer creations/usages and, when the
        /// calli optimisation is enabled, captures a <c>get_Invoke</c> followed by a callvirt on the same delegate
        /// type within one basic block.
        /// </summary>
        private void CollectDelegateInvokes(MethodDefinition m)
        {
            if (!(enableCalliOptimisation || enableInvokeAttribute || enableUnmangedFunctionPointerInject))
                return;

            bool hitGetInvoke = false;
            TypeDefinition delegateType = null;
            List<Instruction> captured = null;

            foreach (var inst in m.Body.Instructions)
            {
                if (_logLevel > 2) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: CurrentInstruction {inst} {inst.Operand}");

                // Check for a FunctionPointerT creation
                if (enableUnmangedFunctionPointerInject || enableInvokeAttribute)
                {
                    LocateFunctionPointerTCreation(m, inst);
                }

                if (!enableCalliOptimisation)
                    continue;

                if (!hitGetInvoke)
                {
                    if (inst.OpCode != OpCodes.Call) continue;
                    if (!IsBurstFunctionPointerMethod(inst.Operand as MethodReference, "get_Invoke", out var methodInstance)) continue;

                    // At this point we have a call to a FunctionPointer.Invoke
                    hitGetInvoke = true;
                    delegateType = methodInstance.GenericArguments[0].Resolve();

                    // Capture the get_invoke, we will swap this for get_value and a store to local
                    captured = new List<Instruction> { inst };
                    continue;
                }

                if (!(inst.OpCode.FlowControl == FlowControl.Next || inst.OpCode.FlowControl == FlowControl.Call))
                {
                    // Don't perform transform across blocks
                    hitGetInvoke = false;
                    continue;
                }

                if (inst.OpCode != OpCodes.Callvirt)
                    continue;

                if (!(inst.Operand is MethodReference mref))
                {
                    hitGetInvoke = false;
                    continue;
                }

                // Resolve() returns null when the callee's assembly cannot be resolved; such a call cannot match.
                var method = mref.Resolve();
                if (method == null || method.DeclaringType != delegateType)
                    continue;

                hitGetInvoke = false;

                if (!_capturedSets.TryGetValue(m, out var storage))
                {
                    storage = new List<CaptureInformation>();
                    _capturedSets.Add(m, storage);
                }

                // Capture the invoke - which we will swap for a load local (stored from the get_value) and a calli
                captured.Add(inst);
                var captureInfo = new CaptureInformation { Captured = captured, Operand = mref };
                if (_logLevel > 1) _debugLog?.Invoke($"FunctionPtrInvoke.CollectDelegateInvokes: captureInfo:{captureInfo}{Environment.NewLine}capture0{captured[0]}");
                storage.Add(captureInfo);
            }
        }
    }
}