A game about forced loneliness, made by TACStudios
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using Mono.Cecil;
6using Mono.Cecil.Cil;
7
8namespace zzzUnity.Burst.CodeGen
9{
/// <summary>
/// Main class for post-processing assemblies. The post-processing currently performs:
/// - Replacing C# calls to functions marked with [BurstCompile] with a call to the compiled Burst function,
///   in both editor and standalone scenarios. For DOTS Runtime, this is done differently at BclApp level by patching
///   DllImport.
/// - Replacing calls to `SharedStatic.GetOrCreate` with `SharedStatic.GetOrCreateUnsafe`, and calculating the hashes during ILPP time
///   rather than in static constructors at runtime.
/// </summary>
18 internal class ILPostProcessing
19 {
20 private AssemblyDefinition _burstAssembly;
21 private MethodReference _burstCompilerIsEnabledMethodDefinition;
22 private MethodReference _burstCompilerCompileFunctionPointer;
23 private FieldReference _burstCompilerOptionsField;
24 private TypeReference _burstCompilerOptionsType;
25 private TypeReference _functionPointerType;
26 private MethodReference _functionPointerGetValue;
27 private MethodReference _burstDiscardAttributeConstructor;
28 private TypeSystem _typeSystem;
29 private TypeReference _systemDelegateType;
30 private TypeReference _systemASyncCallbackType;
31 private TypeReference _systemIASyncResultType;
32 private AssemblyDefinition _assemblyDefinition;
33 private bool _modified;
34#if !UNITY_DOTSPLAYER
35 private bool _containsDirectCall;
36#endif
37 private readonly StringBuilder _builder = new StringBuilder(1024);
38 private readonly List<Instruction> _instructionsToReplace = new List<Instruction>(4);
39
40 public const string PostfixManaged = "$BurstManaged";
41 private const string PostfixBurstDirectCall = "$BurstDirectCall";
42 private const string PostfixBurstDelegate = "$PostfixBurstDelegate";
43 private const string GetFunctionPointerName = "GetFunctionPointer";
44 private const string GetFunctionPointerDiscardName = "GetFunctionPointerDiscard";
45 private const string InvokeName = "Invoke";
46
/// <summary>
/// Creates a post processor that resolves assemblies through <paramref name="loader"/>.
/// </summary>
/// <param name="loader">Resolver stored in <c>Loader</c>.</param>
/// <param name="isForEditor">Stored in <c>IsForEditor</c>.</param>
/// <param name="error">Accepted for signature compatibility; not read by this constructor.</param>
/// <param name="log">Accepted for signature compatibility; not read by this constructor.</param>
/// <param name="logLevel">Accepted for signature compatibility; not read by this constructor.</param>
/// <param name="skipInitializeOnLoad">Stored in <c>_skipInitializeOnLoad</c>.</param>
public ILPostProcessing(AssemblyResolver loader, bool isForEditor, ErrorDiagnosticDelegate error, LogDelegate log = null, int logLevel = 0, bool skipInitializeOnLoad = false)
{
    Loader = loader;
    IsForEditor = isForEditor;
    _skipInitializeOnLoad = skipInitializeOnLoad;
}
53
54 public bool _skipInitializeOnLoad;
55
56 public bool IsForEditor { get; private set; }
57
58 private AssemblyResolver Loader { get; }
59
/// <summary>
/// Post-processes every type of <paramref name="assemblyDefinition"/> and, when at least one
/// direct call was generated, adds the initialize-on-load helper to the assembly.
/// </summary>
/// <param name="assemblyDefinition">The assembly to patch in place.</param>
/// <returns><c>true</c> if the assembly was modified; otherwise <c>false</c>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="assemblyDefinition"/> is <c>null</c>.</exception>
public bool Run(AssemblyDefinition assemblyDefinition)
{
    if (assemblyDefinition == null)
    {
        throw new ArgumentNullException(nameof(assemblyDefinition));
    }

    _assemblyDefinition = assemblyDefinition;
    _typeSystem = assemblyDefinition.MainModule.TypeSystem;

    // Reset all per-run state, not only `_modified`: the cached Burst references are imported into
    // the main module of the assembly being processed, so they must not leak into a later run.
    _modified = false;
    _burstAssembly = null;
#if !UNITY_DOTSPLAYER
    _containsDirectCall = false;
#endif

    // Iterate over a snapshot: processing a type can add nested types to the module.
    var types = assemblyDefinition.MainModule.GetTypes().ToArray();
    foreach (var type in types)
    {
        ProcessType(type);
    }

#if !UNITY_DOTSPLAYER
    if (_containsDirectCall)
    {
        GenerateInitializeOnLoadMethod();
    }
#endif

    return _modified;
}
81
/// <summary>
/// Adds a <c>$BurstDirectCallInitializer</c> class with a single <c>Initialize</c> method to the
/// processed assembly. The method only reads <c>BurstCompiler.Options</c>.
/// </summary>
private void GenerateInitializeOnLoadMethod()
{
    // The generated method makes sure BurstCompiler.Options is initialized on the main thread,
    // before any direct call methods are called on a background thread. It is equivalent to:
    //
    // [UnityEngine.RuntimeInitializeOnLoadMethod(UnityEngine.RuntimeInitializeLoadType.AfterAssembliesLoaded)]
    // [UnityEditor.InitializeOnLoadMethod] // Only for editor assemblies
    // private static void Initialize()
    // {
    //     var _ = BurstCompiler.Options;
    // }
    const string initializerTypeName = "$BurstDirectCallInitializer";
    const TypeAttributes initializerTypeAttributes =
        TypeAttributes.NotPublic |
        TypeAttributes.AutoLayout |
        TypeAttributes.AnsiClass |
        TypeAttributes.Abstract |
        TypeAttributes.Sealed |
        TypeAttributes.BeforeFieldInit;
    const MethodAttributes initializeMethodAttributes =
        MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static;

    var module = _assemblyDefinition.MainModule;

    // A type with this name can only exist if the assembly was already post-processed once;
    // drop the stale copy instead of failing.
    var staleInitializer = module.Types.FirstOrDefault(candidate => candidate.Name == initializerTypeName);
    if (staleInitializer != null)
    {
        module.Types.Remove(staleInitializer);
    }

    var initializerType = new TypeDefinition("", initializerTypeName, initializerTypeAttributes)
    {
        BaseType = _typeSystem.Object
    };
    module.Types.Add(initializerType);

    var initializeMethod = new MethodDefinition("Initialize", initializeMethodAttributes, _typeSystem.Void)
    {
        ImplAttributes = MethodImplAttributes.IL | MethodImplAttributes.Managed,
        DeclaringType = initializerType
    };

    // Local 0 receives the value of BurstCompiler.Options.
    initializeMethod.Body.Variables.Add(new VariableDefinition(_burstCompilerOptionsType));

    var il = initializeMethod.Body.GetILProcessor();
    il.Emit(OpCodes.Ldsfld, _burstCompilerOptionsField);
    il.Emit(OpCodes.Stloc_0);
    il.Emit(OpCodes.Ret);
    initializerType.Methods.Add(FixDebugInformation(initializeMethod));

    var runtimeInitializeAttribute = new CustomAttribute(_unityEngineInitializeOnLoadAttributeCtor);
    runtimeInitializeAttribute.ConstructorArguments.Add(
        new CustomAttributeArgument(_unityEngineRuntimeInitializeLoadType, _unityEngineRuntimeInitializeLoadAfterAssemblies.Constant));
    initializeMethod.CustomAttributes.Add(runtimeInitializeAttribute);

    if (!IsForEditor || _skipInitializeOnLoad)
    {
        return;
    }

    // The editor attribute is needed as well, otherwise edit mode tests will not call Initialize.
    var editorInitializeAttribute = new CustomAttribute(_unityEditorInitilizeOnLoadAttributeCtor);
    initializeMethod.CustomAttributes.Add(editorInitializeAttribute);
}
140
/// <summary>
/// Tells whether the hash of <paramref name="typeRef"/> may be computed during post-processing.
/// </summary>
/// <param name="typeRef">The type reference to inspect.</param>
/// <returns>
/// <c>false</c> when the type contains a generic parameter, when no assembly name can be found
/// for it, or when that assembly is <c>netstandard</c> or <c>mscorlib</c>; otherwise <c>true</c>.
/// </returns>
private static bool CanComputeCompileTimeHash(TypeReference typeRef)
{
    if (typeRef.ContainsGenericParameter)
    {
        return false;
    }

    // Prefer the assembly the reference is scoped to, then fall back to the assembly of the owning module.
    var owningAssembly = typeRef.Scope as AssemblyNameReference;
    if (owningAssembly == null)
    {
        owningAssembly = typeRef.Module.Assembly?.Name;
    }

    if (owningAssembly == null)
    {
        return false;
    }

    var assemblyName = owningAssembly.Name;
    return assemblyName != "netstandard" && assemblyName != "mscorlib";
}
164
/// <summary>
/// Post-processes a single type: patches its static <c>[BurstCompile]</c> methods for direct call
/// (non-generic types carrying <c>[BurstCompile]</c> only) and rewrites the
/// <c>SharedStatic.GetOrCreate</c> calls of its static constructor (types with a SharedStatic field only).
/// </summary>
/// <param name="type">The type to process.</param>
private void ProcessType(TypeDefinition type)
{
    if (!type.HasGenericParameters && TryGetBurstCompileAttribute(type, out _))
    {
        ProcessBurstCompiledMethods(type);
    }

    if (TypeHasSharedStaticInIt(type))
    {
        foreach (var method in type.Methods)
        {
            // Only the static constructor is rewritten.
            if (method.Name == ".cctor")
            {
                ReplaceSharedStaticGetOrCreateCalls(method);
            }
        }
    }
}

/// <summary>
/// Patches every static, non-generic <c>[BurstCompile]</c> method of <paramref name="type"/> for
/// direct call, unless direct call is disabled for that method.
/// </summary>
private void ProcessBurstCompiledMethods(TypeDefinition type)
{
    // Snapshot the count: patching a method appends new methods to the same collection.
    var methodCount = type.Methods.Count;
    for (var j = 0; j < methodCount; j++)
    {
        var method = type.Methods[j];
        if (!method.IsStatic || method.HasGenericParameters || !TryGetBurstCompileAttribute(method, out var methodBurstCompileAttribute))
        {
            continue;
        }

        if (IsDirectCallDisabled(method, methodBurstCompileAttribute))
        {
            continue;
        }

#if !UNITY_DOTSPLAYER // Direct call is not supported for DOTS runtime via this post-processor, it is handled elsewhere; this code assumes a Unity Editor based Burst
        if (_burstAssembly == null)
        {
            var resolved = methodBurstCompileAttribute.Constructor.DeclaringType.Resolve();
            InitializeBurstAssembly(resolved.Module.Assembly);
        }

        ProcessMethodForDirectCall(method);
        _modified = true;
        _containsDirectCall = true;
#endif
    }
}

/// <summary>
/// Determines whether direct call is disabled for <paramref name="method"/>. The method level
/// <c>DisableDirectCall</c> property is used when present, otherwise the one of an assembly level
/// <c>[BurstCompile]</c>; <c>[UnmanagedCallersOnly]</c> methods always have direct call disabled.
/// </summary>
private static bool IsDirectCallDisabled(MethodDefinition method, CustomAttribute methodBurstCompileAttribute)
{
    // If the method attribute doesn't specify the property, do one last check at the assembly level.
    if (!TryGetDisableDirectCall(methodBurstCompileAttribute, out var isDirectCallDisabled)
        && TryGetBurstCompileAttribute(method.Module.Assembly, out var assemblyBurstCompileAttribute))
    {
        TryGetDisableDirectCall(assemblyBurstCompileAttribute, out isDirectCallDisabled);
    }

    foreach (var customAttribute in method.CustomAttributes)
    {
        if (customAttribute.AttributeType.FullName == "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute")
        {
            // Can't / shouldn't enable direct call for [UnmanagedCallersOnly] methods -
            // these can't be called from managed code.
            return true;
        }
    }

    return isDirectCallDisabled;
}

/// <summary>
/// Reads the <c>DisableDirectCall</c> named property of a <c>[BurstCompile]</c> attribute.
/// </summary>
/// <returns><c>true</c> if the property is present; <paramref name="isDisabled"/> is <c>false</c> otherwise.</returns>
private static bool TryGetDisableDirectCall(CustomAttribute burstCompileAttribute, out bool isDisabled)
{
    if (burstCompileAttribute.HasProperties)
    {
        foreach (var property in burstCompileAttribute.Properties)
        {
            if (property.Name == "DisableDirectCall")
            {
                isDisabled = (bool)property.Argument.Value;
                return true;
            }
        }
    }

    isDisabled = false;
    return false;
}

/// <summary>
/// Replaces the eligible <c>SharedStatic.GetOrCreate</c> calls of <paramref name="method"/> with
/// calls that receive hashes computed here, and flags the assembly as modified when any call is replaced.
/// </summary>
private void ReplaceSharedStaticGetOrCreateCalls(MethodDefinition method)
{
    try
    {
#if DEBUG
        if (_instructionsToReplace.Count != 0)
        {
            throw new InvalidOperationException("Instructions to replace wasn't cleared properly!");
        }
#endif

        // Collect first, patch afterwards: the instruction list must not change while it is enumerated.
        foreach (var instruction in method.Body.Instructions)
        {
            if (IsReplaceableGetOrCreateCall(instruction))
            {
                _instructionsToReplace.Add(instruction);
            }
        }

        if (_instructionsToReplace.Count > 0)
        {
            _modified = true;
        }

        foreach (var instruction in _instructionsToReplace)
        {
            ReplaceGetOrCreateCall(method, instruction);
        }
    }
    finally
    {
        _instructionsToReplace.Clear();
    }
}

/// <summary>
/// Tells whether <paramref name="instruction"/> is a call to a generic, single parameter
/// <c>SharedStatic.GetOrCreate</c> for which at least one hash can be computed at post-processing time.
/// </summary>
private static bool IsReplaceableGetOrCreateCall(Instruction instruction)
{
    // Skip anything that isn't a call.
    if (instruction.OpCode != OpCodes.Call)
    {
        return false;
    }

    var calledMethod = (MethodReference)instruction.Operand;

    if (calledMethod.Name != "GetOrCreate")
    {
        return false;
    }

    // Skip anything that isn't member of the `SharedStatic` class.
    if (!TypeIsSharedStatic(calledMethod.DeclaringType))
    {
        return false;
    }

    // We only handle the `GetOrCreate` calls with a single parameter (the alignment).
    if (calledMethod.Parameters.Count != 1)
    {
        return false;
    }

    // We only post-process the generic versions of `GetOrCreate`.
    if (!(calledMethod is GenericInstanceMethod genericInstanceMethod))
    {
        return false;
    }

    // We cannot post-process a shared static with all arguments being open generic,
    // nor one where all of its types are in core libraries.
    foreach (var genericArgument in genericInstanceMethod.GenericArguments)
    {
        if (CanComputeCompileTimeHash(genericArgument))
        {
            return true;
        }
    }

    return false;
}

/// <summary>
/// Rewrites one collected <c>GetOrCreate</c> call: pushes the hashes that could be computed and calls
/// <c>GetOrCreateUnsafe</c>, or one of the <c>GetOrCreatePartiallyUnsafe*</c> variants when one of the
/// two hashes cannot be computed here.
/// </summary>
private void ReplaceGetOrCreateCall(MethodDefinition method, Instruction instruction)
{
    var calledMethod = (GenericInstanceMethod)instruction.Operand;

    var hashCode64 = CalculateHashCode64(calledMethod.GenericArguments[0]);

    long subHashCode64 = 0;

    var useCalculatedHashCode = true;
    var useCalculatedSubHashCode = true;

    if (calledMethod.GenericArguments.Count == 2)
    {
        subHashCode64 = CalculateHashCode64(calledMethod.GenericArguments[1]);

        useCalculatedHashCode = CanComputeCompileTimeHash(calledMethod.GenericArguments[0]);
        useCalculatedSubHashCode = CanComputeCompileTimeHash(calledMethod.GenericArguments[1]);
    }

#if DEBUG
    if (!useCalculatedHashCode && !useCalculatedSubHashCode)
    {
        throw new InvalidOperationException("Cannot replace when both hashes are invalid!");
    }
#endif

    // The type whose hash cannot be computed here stays a generic argument of the replacement method.
    var methodToCall = "GetOrCreateUnsafe";
    TypeReference genericArgument = null;

    if (!useCalculatedHashCode)
    {
        methodToCall = "GetOrCreatePartiallyUnsafeWithSubHashCode";
        genericArgument = calledMethod.GenericArguments[0];
    }
    else if (!useCalculatedSubHashCode)
    {
        methodToCall = "GetOrCreatePartiallyUnsafeWithHashCode";
        genericArgument = calledMethod.GenericArguments[1];
    }

    var getOrCreateUnsafe = _assemblyDefinition.MainModule.ImportReference(
        calledMethod.DeclaringType.Resolve().Methods.First(m => m.Name == methodToCall));

    // Keep the (generic instance) declaring type of the original call.
    getOrCreateUnsafe.DeclaringType = calledMethod.DeclaringType;

    if (genericArgument != null)
    {
        var genericInstanceMethod = new GenericInstanceMethod(getOrCreateUnsafe);

        genericInstanceMethod.GenericArguments.Add(genericArgument);

        getOrCreateUnsafe = genericInstanceMethod;
    }

    var processor = method.Body.GetILProcessor();

    // The hashes are pushed right before the call, i.e. after the already pushed alignment argument.
    if (useCalculatedHashCode)
    {
        processor.InsertBefore(instruction, processor.Create(OpCodes.Ldc_I8, hashCode64));
    }

    if (useCalculatedSubHashCode)
    {
        processor.InsertBefore(instruction, processor.Create(OpCodes.Ldc_I8, subHashCode64));
    }

    processor.Replace(instruction, processor.Create(OpCodes.Call, getOrCreateUnsafe));
}
388
389 // WARNING: This **must** be kept in sync with the definition in BurstRuntime.cs!
/// <summary>
/// Computes a 64-bit FNV-1a hash of <paramref name="text"/>, feeding the low byte and then the
/// high byte of every UTF-16 code unit.
/// </summary>
/// <param name="text">The text to hash.</param>
/// <returns>The hash, reinterpreted as a signed 64-bit integer.</returns>
private static long HashStringWithFNV1A64(string text)
{
    // Using http://www.isthe.com/chongo/tech/comp/fnv/index.html#FNV-1a
    // with basis and prime:
    const ulong offsetBasis = 14695981039346656037;
    const ulong prime = 1099511628211;

    // The algorithm relies on 64-bit wraparound, both for the multiplications and for the final
    // conversion to long. Make that explicit so the result does not depend on whether the
    // project is compiled with overflow checking enabled.
    unchecked
    {
        ulong result = offsetBasis;

        foreach (var c in text)
        {
            result = prime * (result ^ (byte)(c & 255));
            result = prime * (result ^ (byte)(c >> 8));
        }

        return (long)result;
    }
}
407
/// <summary>
/// Hashes the text that <c>BuildAssemblyQualifiedName</c> writes for <paramref name="type"/>.
/// The shared string builder is emptied again on every exit path.
/// </summary>
/// <param name="type">The type to hash.</param>
/// <returns>The 64-bit hash of the built name.</returns>
private long CalculateHashCode64(TypeReference type)
{
    try
    {
#if DEBUG
        if (_builder.Length > 0)
        {
            throw new InvalidOperationException("StringBuilder wasn't cleared properly!");
        }
#endif

        type.BuildAssemblyQualifiedName(_builder);

        var qualifiedName = _builder.ToString();
        return HashStringWithFNV1A64(qualifiedName);
    }
    finally
    {
        _builder.Length = 0;
    }
}
427
/// <summary>
/// Tells whether <paramref name="typeRef"/> is named <c>SharedStatic`1</c> in the
/// <c>Unity.Burst</c> namespace.
/// </summary>
private static bool TypeIsSharedStatic(TypeReference typeRef)
{
    return typeRef.Namespace == "Unity.Burst"
        && typeRef.Name == "SharedStatic`1";
}
442
/// <summary>
/// Tells whether at least one field of <paramref name="typeDef"/> has a SharedStatic field type.
/// </summary>
private static bool TypeHasSharedStaticInIt(TypeDefinition typeDef)
{
    return typeDef.Fields.Any(field => TypeIsSharedStatic(field.FieldType));
}
455
/// <summary>
/// Adds a nested delegate type to <paramref name="declaringType"/> whose <c>Invoke</c> has the
/// return type and parameters of <paramref name="managed"/>.
/// </summary>
/// <param name="declaringType">The type receiving the nested delegate.</param>
/// <param name="originalName">First part of the delegate type name.</param>
/// <param name="managed">The method whose return type and parameters are mirrored.</param>
/// <param name="uniqueSuffix">Suffix inserted between the name and the delegate postfix.</param>
/// <returns>The injected delegate type.</returns>
private TypeDefinition InjectDelegate(TypeDefinition declaringType, string originalName, MethodDefinition managed, string uniqueSuffix)
{
    const TypeAttributes delegateTypeAttributes =
        TypeAttributes.NestedPublic |
        TypeAttributes.AutoLayout |
        TypeAttributes.AnsiClass |
        TypeAttributes.Sealed;
    const MethodAttributes constructorAttributes =
        MethodAttributes.Public |
        MethodAttributes.HideBySig |
        MethodAttributes.SpecialName |
        MethodAttributes.RTSpecialName;
    const MethodAttributes virtualMethodAttributes =
        MethodAttributes.Public |
        MethodAttributes.HideBySig |
        MethodAttributes.NewSlot |
        MethodAttributes.Virtual;

    var delegateTypeName = $"{originalName}{uniqueSuffix}{PostfixBurstDelegate}";
    var injectedDelegateType = new TypeDefinition(declaringType.Namespace, delegateTypeName, delegateTypeAttributes)
    {
        DeclaringType = declaringType,
        BaseType = _systemDelegateType
    };

    declaringType.NestedTypes.Add(injectedDelegateType);

    // All four delegate members are runtime-implemented instance methods without an IL body.
    MethodDefinition CreateRuntimeMethod(string name, MethodAttributes attributes, TypeReference returnType)
    {
        return new MethodDefinition(name, attributes, returnType)
        {
            HasThis = true,
            IsManaged = true,
            IsRuntime = true,
            DeclaringType = injectedDelegateType
        };
    }

    // .ctor(object, IntPtr)
    var constructor = CreateRuntimeMethod(".ctor", constructorAttributes, _typeSystem.Void);
    constructor.Parameters.Add(new ParameterDefinition(_typeSystem.Object));
    constructor.Parameters.Add(new ParameterDefinition(_typeSystem.IntPtr));
    injectedDelegateType.Methods.Add(constructor);

    // Invoke(...same parameters as the managed method...)
    var invoke = CreateRuntimeMethod("Invoke", virtualMethodAttributes, managed.ReturnType);
    foreach (var managedParameter in managed.Parameters)
    {
        invoke.Parameters.Add(managedParameter);
    }
    injectedDelegateType.Methods.Add(invoke);

    // BeginInvoke(...same parameters..., AsyncCallback, object)
    var beginInvoke = CreateRuntimeMethod("BeginInvoke", virtualMethodAttributes, _systemIASyncResultType);
    foreach (var managedParameter in managed.Parameters)
    {
        beginInvoke.Parameters.Add(managedParameter);
    }
    beginInvoke.Parameters.Add(new ParameterDefinition(_systemASyncCallbackType));
    beginInvoke.Parameters.Add(new ParameterDefinition(_typeSystem.Object));
    injectedDelegateType.Methods.Add(beginInvoke);

    // EndInvoke(IAsyncResult)
    var endInvoke = CreateRuntimeMethod("EndInvoke", virtualMethodAttributes, managed.ReturnType);
    endInvoke.Parameters.Add(new ParameterDefinition(_systemIASyncResultType));
    injectedDelegateType.Methods.Add(endInvoke);

    return injectedDelegateType;
}
558
/// <summary>
/// Adds the <c>GetFunctionPointerDiscard</c> method to <paramref name="cls"/>. The method compiles
/// <paramref name="targetMethod"/> into <paramref name="pointerField"/> when that field is still
/// zero, then writes the field through its by-ref argument.
/// </summary>
/// <param name="cls">The direct call class receiving the method.</param>
/// <param name="pointerField">Static field caching the function pointer.</param>
/// <param name="targetMethod">The method handed to <c>BurstCompiler.CompileFunctionPointer</c>.</param>
/// <param name="injectedDelegate">Delegate type used to wrap <paramref name="targetMethod"/>.</param>
/// <returns>The created method.</returns>
private MethodDefinition CreateGetFunctionPointerDiscardMethod(TypeDefinition cls, FieldDefinition pointerField, MethodDefinition targetMethod, TypeDefinition injectedDelegate)
{
    // BurstCompiler.CompileFunctionPointer<InjectedDelegate>
    var compileFunctionPointer = new GenericInstanceMethod(_burstCompilerCompileFunctionPointer);
    compileFunctionPointer.GenericArguments.Add(injectedDelegate);

    // FunctionPointer<InjectedDelegate>
    var functionPointerOfDelegate = new GenericInstanceType(_functionPointerType);
    functionPointerOfDelegate.GenericArguments.Add(injectedDelegate);

    // FunctionPointer<InjectedDelegate>.get_Value, rebuilt on the generic instance type.
    var getValue = new MethodReference(_functionPointerGetValue.Name, _functionPointerGetValue.ReturnType, functionPointerOfDelegate);
    foreach (var source in _functionPointerGetValue.Parameters)
    {
        getValue.Parameters.Add(new ParameterDefinition(source.Name, source.Attributes, source.ParameterType));
    }
    getValue.HasThis = _functionPointerGetValue.HasThis;
    getValue.MetadataToken = _functionPointerGetValue.MetadataToken;

    // Create GetFunctionPointerDiscard method:
    //
    // [BurstDiscard]
    // public static void GetFunctionPointerDiscard(ref IntPtr ptr) {
    //   if (Pointer == null) {
    //     Pointer = BurstCompiler.CompileFunctionPointer<InjectedDelegate>(d);
    //   }
    //
    //   ptr = Pointer
    // }
    var discardMethod = new MethodDefinition(GetFunctionPointerDiscardName, MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static, _typeSystem.Void)
    {
        ImplAttributes = MethodImplAttributes.IL | MethodImplAttributes.Managed,
        DeclaringType = cls
    };

    discardMethod.Body.Variables.Add(new VariableDefinition(functionPointerOfDelegate));
    discardMethod.Parameters.Add(new ParameterDefinition(new ByReferenceType(_typeSystem.IntPtr)));

    var il = discardMethod.Body.GetILProcessor();

    // First instruction of the `ptr = Pointer` tail; the branch below skips the compilation to it.
    var loadOutAddress = il.Create(OpCodes.Ldarg_0);

    // if (Pointer != 0) goto tail;
    il.Emit(OpCodes.Ldsfld, pointerField);
    il.Emit(OpCodes.Brtrue, loadOutAddress);

    // local0 = BurstCompiler.CompileFunctionPointer(new InjectedDelegate(null, &targetMethod));
    il.Emit(OpCodes.Ldnull);
    il.Emit(OpCodes.Ldftn, targetMethod);
    il.Emit(OpCodes.Newobj, injectedDelegate.Methods.First(md => md.IsConstructor && md.Parameters.Count == 2));
    il.Emit(OpCodes.Call, compileFunctionPointer);
    il.Emit(OpCodes.Stloc_0);

    // Pointer = local0.Value;
    il.Emit(OpCodes.Ldloca, 0);
    il.Emit(OpCodes.Call, getValue);
    il.Emit(OpCodes.Stsfld, pointerField);

    // tail: ptr = Pointer; return;
    il.Append(loadOutAddress);
    il.Emit(OpCodes.Ldsfld, pointerField);
    il.Emit(OpCodes.Stind_I);
    il.Emit(OpCodes.Ret);

    cls.Methods.Add(FixDebugInformation(discardMethod));

    discardMethod.CustomAttributes.Add(new CustomAttribute(_burstDiscardAttributeConstructor));

    return discardMethod;
}
630
/// <summary>
/// Adds the <c>GetFunctionPointer</c> method to <paramref name="cls"/>. It zero-initializes a
/// local, passes it by reference to <paramref name="getFunctionPointerDiscardMethod"/> and
/// returns it.
/// </summary>
/// <param name="cls">The direct call class receiving the method.</param>
/// <param name="getFunctionPointerDiscardMethod">The method that fills in the pointer.</param>
/// <returns>The created method.</returns>
private MethodDefinition CreateGetFunctionPointerMethod(TypeDefinition cls, MethodDefinition getFunctionPointerDiscardMethod)
{
    // Create GetFunctionPointer method:
    //
    // private static IntPtr GetFunctionPointer() {
    //   var ptr = default(IntPtr);
    //   GetFunctionPointerDiscard(ref ptr);
    //   return ptr;
    // }
    const MethodAttributes attributes = MethodAttributes.Private | MethodAttributes.HideBySig | MethodAttributes.Static;
    var method = new MethodDefinition(GetFunctionPointerName, attributes, _typeSystem.IntPtr)
    {
        ImplAttributes = MethodImplAttributes.IL | MethodImplAttributes.Managed,
        DeclaringType = cls
    };

    var body = method.Body;
    body.Variables.Add(new VariableDefinition(_typeSystem.IntPtr));
    body.InitLocals = true;

    var il = body.GetILProcessor();

    // ptr = 0
    il.Emit(OpCodes.Ldc_I4_0);
    il.Emit(OpCodes.Conv_I);
    il.Emit(OpCodes.Stloc_0);

    // GetFunctionPointerDiscard(ref ptr)
    il.Emit(OpCodes.Ldloca_S, (byte)0);
    il.Emit(OpCodes.Call, getFunctionPointerDiscardMethod);

    // return ptr
    il.Emit(OpCodes.Ldloc_0);
    il.Emit(OpCodes.Ret);

    cls.Methods.Add(FixDebugInformation(method));

    return method;
}
664
/// <summary>
/// Patches a <c>[BurstCompile]</c> method for direct call. The original body is moved into a
/// <c>$BurstManaged</c> copy, a nested <c>$BurstDirectCall</c> class with an <c>Invoke</c>
/// dispatcher is generated, and the original method is rewritten to call that dispatcher.
/// </summary>
/// <param name="burstCompileMethod">The static method to patch in place.</param>
private void ProcessMethodForDirectCall(MethodDefinition burstCompileMethod)
{
    var declaringType = burstCompileMethod.DeclaringType;

    // The metadata row id keeps generated names unique between overloads sharing a name.
    var uniqueSuffix = $"_{burstCompileMethod.MetadataToken.RID:X8}";

    var injectedDelegate = InjectDelegate(declaringType, burstCompileMethod.Name, burstCompileMethod, uniqueSuffix);

    // Create a copy of the original method that will be the actual managed method
    // The original method is patched at the end of this method to call
    // the dispatcher that will go to the Burst implementation or the managed method (if in the editor and Burst is disabled)
    var managedFallbackMethod = new MethodDefinition($"{burstCompileMethod.Name}{PostfixManaged}", burstCompileMethod.Attributes, burstCompileMethod.ReturnType)
    {
        DeclaringType = declaringType,
        ImplAttributes = burstCompileMethod.ImplAttributes,
        MetadataToken = burstCompileMethod.MetadataToken,
    };

    // Ensure the CustomAttributes are the same
    managedFallbackMethod.CustomAttributes.Clear();
    foreach (var attr in burstCompileMethod.CustomAttributes)
    {
        managedFallbackMethod.CustomAttributes.Add(attr);
    }

    declaringType.Methods.Add(managedFallbackMethod);

    foreach (var parameter in burstCompileMethod.Parameters)
    {
        managedFallbackMethod.Parameters.Add(parameter);
    }

    // Copy the body from the original burst method to the managed fallback, we'll replace the burstCompileMethod body later.
    managedFallbackMethod.Body.InitLocals = burstCompileMethod.Body.InitLocals;
    managedFallbackMethod.Body.LocalVarToken = burstCompileMethod.Body.LocalVarToken;
    managedFallbackMethod.Body.MaxStackSize = burstCompileMethod.Body.MaxStackSize;

    foreach (var variable in burstCompileMethod.Body.Variables)
    {
        managedFallbackMethod.Body.Variables.Add(variable);
    }

    foreach (var instruction in burstCompileMethod.Body.Instructions)
    {
        managedFallbackMethod.Body.Instructions.Add(instruction);
    }

    foreach (var exceptionHandler in burstCompileMethod.Body.ExceptionHandlers)
    {
        managedFallbackMethod.Body.ExceptionHandlers.Add(exceptionHandler);
    }

    // Clear NoInlining (keeping every other implementation flag copied from the original method)
    // before requesting aggressive inlining. Masking with `NoInlining` itself, without the
    // complement, would keep exactly that flag and drop all the others.
    managedFallbackMethod.ImplAttributes &= ~MethodImplAttributes.NoInlining;
    // 0x0100 is AggressiveInlining
    managedFallbackMethod.ImplAttributes |= (MethodImplAttributes)0x0100;

    // The method needs to be public because we query for it in the ILPP code.
    managedFallbackMethod.Attributes &= ~MethodAttributes.Private;
    managedFallbackMethod.Attributes |= MethodAttributes.Public;

    // private static class (Name_RID.$Postfix)
    var cls = new TypeDefinition(declaringType.Namespace, $"{burstCompileMethod.Name}{uniqueSuffix}{PostfixBurstDirectCall}",
        TypeAttributes.NestedAssembly |
        TypeAttributes.AutoLayout |
        TypeAttributes.AnsiClass |
        TypeAttributes.Abstract |
        TypeAttributes.Sealed |
        TypeAttributes.BeforeFieldInit
    )
    {
        DeclaringType = declaringType,
        BaseType = _typeSystem.Object
    };

    declaringType.NestedTypes.Add(cls);

    // Create Field:
    //
    // private static IntPtr Pointer;
    var pointerField = new FieldDefinition("Pointer", FieldAttributes.Static | FieldAttributes.Private, _typeSystem.IntPtr)
    {
        DeclaringType = cls
    };
    cls.Fields.Add(pointerField);

    var getFunctionPointerDiscardMethod = CreateGetFunctionPointerDiscardMethod(
        cls, pointerField,
        // In the player the function pointer is looked up in a registry by name
        // so we can't request a `$BurstManaged` function (because it was never compiled, only the toplevel one)
        // But, it's safe *in the player* to request the toplevel function
        IsForEditor ? managedFallbackMethod : burstCompileMethod,
        injectedDelegate);
    var getFunctionPointerMethod = CreateGetFunctionPointerMethod(cls, getFunctionPointerDiscardMethod);

    // Create the Invoke method based on the original method (same signature)
    //
    // public static XXX Invoke(...args) {
    //    if (BurstCompiler.IsEnabled)
    //    {
    //        var funcPtr = GetFunctionPointer();
    //        if (funcPtr != null) return funcPtr(...args);
    //    }
    //    return OriginalMethod(...args);
    // }
    var invokeAttributes = managedFallbackMethod.Attributes;
    invokeAttributes &= ~MethodAttributes.Private;
    invokeAttributes |= MethodAttributes.Public;
    var invoke = new MethodDefinition(InvokeName, invokeAttributes, burstCompileMethod.ReturnType)
    {
        ImplAttributes = MethodImplAttributes.IL | MethodImplAttributes.Managed,
        DeclaringType = cls
    };

    // Signature used by the `calli` on the compiled function pointer.
    var signature = new CallSite(burstCompileMethod.ReturnType)
    {
        CallingConvention = MethodCallingConvention.C
    };

    foreach (var parameter in burstCompileMethod.Parameters)
    {
        invoke.Parameters.Add(parameter);
        signature.Parameters.Add(parameter);
    }

    invoke.Body.Variables.Add(new VariableDefinition(_typeSystem.IntPtr));
    invoke.Body.InitLocals = true;

    var processor = invoke.Body.GetILProcessor();
    processor.Emit(OpCodes.Call, _burstCompilerIsEnabledMethodDefinition);
    var branchPosition0 = processor.Body.Instructions[processor.Body.Instructions.Count - 1];

    processor.Emit(OpCodes.Call, getFunctionPointerMethod);
    processor.Emit(OpCodes.Stloc_0);
    processor.Emit(OpCodes.Ldloc_0);
    var branchPosition1 = processor.Body.Instructions[processor.Body.Instructions.Count - 1];

    // Burst path: return funcPtr(...args);
    EmitArguments(processor, invoke);
    processor.Emit(OpCodes.Ldloc_0);
    processor.Emit(OpCodes.Calli, signature);
    processor.Emit(OpCodes.Ret);
    var previousRet = processor.Body.Instructions[processor.Body.Instructions.Count - 1];

    // Managed path: return OriginalMethod$BurstManaged(...args);
    EmitArguments(processor, invoke);
    processor.Emit(OpCodes.Call, managedFallbackMethod);
    processor.Emit(OpCodes.Ret);

    // Insert the branches once we have emitted the instructions: both jump to the first
    // instruction of the managed path, which directly follows the `ret` of the Burst path.
    processor.InsertAfter(branchPosition0, Instruction.Create(OpCodes.Brfalse, previousRet.Next));
    processor.InsertAfter(branchPosition1, Instruction.Create(OpCodes.Brfalse, previousRet.Next));
    cls.Methods.Add(FixDebugInformation(invoke));

    // Final patching of the original method
    // public static XXX OriginalMethod(...args) {
    //      Name_RID.$Postfix.Invoke(...args);
    //      ret;
    // }
    burstCompileMethod.Body = new MethodBody(burstCompileMethod);
    processor = burstCompileMethod.Body.GetILProcessor();
    EmitArguments(processor, burstCompileMethod);
    processor.Emit(OpCodes.Call, invoke);
    processor.Emit(OpCodes.Ret);
    FixDebugInformation(burstCompileMethod);
}
828
/// <summary>
/// Sets the debug scope of <paramref name="method"/> to span from its first to its last instruction.
/// </summary>
/// <param name="method">The method to update; its body must contain at least one instruction.</param>
/// <returns>The same <paramref name="method"/> instance, to allow call chaining.</returns>
private static MethodDefinition FixDebugInformation(MethodDefinition method)
{
    var debugInformation = method.DebugInformation;
    var scopeStart = method.Body.Instructions.First();
    var scopeEnd = method.Body.Instructions.Last();

    debugInformation.Scope = new ScopeDebugInformation(scopeStart, scopeEnd);
    return method;
}
834
/// <summary>
/// Resolves the assembly named <paramref name="assemblyName"/> through <paramref name="loader"/>.
/// </summary>
/// <param name="loader">The resolver to query.</param>
/// <param name="assemblyName">Assembly name, parsed with <c>AssemblyNameReference.Parse</c>.</param>
/// <returns>The resolved assembly, or <c>null</c> when the resolver reports failure.</returns>
private AssemblyDefinition GetAsmDefinitionFromFile(AssemblyResolver loader, string assemblyName)
{
    var reference = AssemblyNameReference.Parse(assemblyName);
    return loader.TryResolve(reference, out var resolved) ? resolved : null;
}
843
844 private MethodReference _unityEngineInitializeOnLoadAttributeCtor;
845 private TypeReference _unityEngineRuntimeInitializeLoadType;
846 private FieldDefinition _unityEngineRuntimeInitializeLoadAfterAssemblies;
847 private MethodReference _unityEditorInitilizeOnLoadAttributeCtor;
848
/// <summary>
/// Caches <paramref name="burstAssembly"/> and imports into the processed module every Burst,
/// core library and Unity reference needed to generate direct call code.
/// </summary>
/// <param name="burstAssembly">The assembly declaring <c>Unity.Burst.BurstCompiler</c>.</param>
/// <exception cref="InvalidOperationException">
/// <c>UnityEngine.CoreModule</c>, or (for editor assemblies) both <c>UnityEditor.CoreModule</c> and
/// <c>UnityEditor</c>, cannot be resolved.
/// </exception>
private void InitializeBurstAssembly(AssemblyDefinition burstAssembly)
{
    _burstAssembly = burstAssembly;

    var module = _assemblyDefinition.MainModule;

    var burstCompilerTypeDefinition = burstAssembly.MainModule.GetType("Unity.Burst", "BurstCompiler");
    _burstCompilerIsEnabledMethodDefinition = module.ImportReference(burstCompilerTypeDefinition.Methods.FirstOrDefault(x => x.Name == "get_IsEnabled"));
    _burstCompilerCompileFunctionPointer = module.ImportReference(burstCompilerTypeDefinition.Methods.FirstOrDefault(x => x.Name == "CompileFunctionPointer"));
    _burstCompilerOptionsField = module.ImportReference(burstCompilerTypeDefinition.Fields.FirstOrDefault(x => x.Name == "Options"));
    _burstCompilerOptionsType = module.ImportReference(burstAssembly.MainModule.GetType("Unity.Burst", "BurstCompilerOptions"));

    var functionPointerTypeDefinition = burstAssembly.MainModule.GetType("Unity.Burst", "FunctionPointer`1");
    _functionPointerType = module.ImportReference(functionPointerTypeDefinition);
    _functionPointerGetValue = module.ImportReference(functionPointerTypeDefinition.Methods.FirstOrDefault(x => x.Name == "get_Value"));

    var corLibrary = Loader.Resolve((AssemblyNameReference)_typeSystem.CoreLibrary);
    _systemDelegateType = module.ImportReference(corLibrary.MainModule.GetType("System.MulticastDelegate"));
    _systemASyncCallbackType = module.ImportReference(corLibrary.MainModule.GetType("System.AsyncCallback"));
    _systemIASyncResultType = module.ImportReference(corLibrary.MainModule.GetType("System.IAsyncResult"));

    var asmDef = GetAsmDefinitionFromFile(Loader, "UnityEngine.CoreModule");
    if (asmDef == null)
    {
        // Fail with an actionable message instead of a NullReferenceException on the next line.
        throw new InvalidOperationException("Unable to resolve the assembly `UnityEngine.CoreModule` required to post-process Burst direct calls.");
    }

    var runtimeInitializeOnLoadMethodAttribute = asmDef.MainModule.GetType("UnityEngine", "RuntimeInitializeOnLoadMethodAttribute");
    var runtimeInitializeLoadType = asmDef.MainModule.GetType("UnityEngine", "RuntimeInitializeLoadType");

    // BurstDiscardAttribute is looked up in UnityEngine.CoreModule, not in the Burst assembly.
    var burstDiscardType = asmDef.MainModule.GetType("Unity.Burst", "BurstDiscardAttribute");
    _burstDiscardAttributeConstructor = module.ImportReference(burstDiscardType.Methods.First(method => method.Name == ".ctor"));

    _unityEngineInitializeOnLoadAttributeCtor = module.ImportReference(runtimeInitializeOnLoadMethodAttribute.Methods.FirstOrDefault(x => x.Name == ".ctor" && x.HasParameters));
    _unityEngineRuntimeInitializeLoadType = module.ImportReference(runtimeInitializeLoadType);
    _unityEngineRuntimeInitializeLoadAfterAssemblies = runtimeInitializeLoadType.Fields.FirstOrDefault(x => x.Name == "AfterAssembliesLoaded");

    if (IsForEditor && !_skipInitializeOnLoad)
    {
        // Try `UnityEditor.CoreModule` first, then fall back to `UnityEditor`.
        asmDef = GetAsmDefinitionFromFile(Loader, "UnityEditor.CoreModule");
        if (asmDef == null)
        {
            asmDef = GetAsmDefinitionFromFile(Loader, "UnityEditor");
        }

        if (asmDef == null)
        {
            throw new InvalidOperationException("Unable to resolve the assembly `UnityEditor.CoreModule` or `UnityEditor` required to post-process Burst direct calls.");
        }

        var initializeOnLoadMethodAttribute = asmDef.MainModule.GetType("UnityEditor", "InitializeOnLoadMethodAttribute");

        _unityEditorInitilizeOnLoadAttributeCtor = module.ImportReference(initializeOnLoadMethodAttribute.Methods.FirstOrDefault(x => x.Name == ".ctor" && !x.HasParameters));
    }
}
889
/// <summary>
/// Emits one argument load per parameter of <paramref name="method"/>, in declaration order,
/// using the most compact <c>ldarg</c> form available for each index.
/// </summary>
/// <param name="processor">The IL processor receiving the instructions.</param>
/// <param name="method">The method whose parameter count drives the emission.</param>
private static void EmitArguments(ILProcessor processor, MethodDefinition method)
{
    // Indices 0 to 3 have dedicated opcodes without operand.
    OpCode[] dedicatedLoads = { OpCodes.Ldarg_0, OpCodes.Ldarg_1, OpCodes.Ldarg_2, OpCodes.Ldarg_3 };

    for (var index = 0; index < method.Parameters.Count; index++)
    {
        if (index < dedicatedLoads.Length)
        {
            processor.Emit(dedicatedLoads[index]);
        }
        else if (index <= 255)
        {
            // Short form: the index fits in a single byte.
            processor.Emit(OpCodes.Ldarg_S, (byte)index);
        }
        else
        {
            processor.Emit(OpCodes.Ldarg, index);
        }
    }
}
921
/// <summary>
/// Looks for the first custom attribute of <paramref name="provider"/> whose constructor is
/// declared by a type named <c>BurstCompileAttribute</c>.
/// </summary>
/// <param name="provider">The method, type or assembly to inspect.</param>
/// <param name="customAttribute">The attribute found, or <c>null</c> when there is none.</param>
/// <returns><c>true</c> if such an attribute was found; otherwise <c>false</c>.</returns>
private static bool TryGetBurstCompileAttribute(ICustomAttributeProvider provider, out CustomAttribute customAttribute)
{
    customAttribute = null;

    if (!provider.HasCustomAttributes)
    {
        return false;
    }

    foreach (var candidate in provider.CustomAttributes)
    {
        // Only the simple type name is compared, not the namespace.
        if (candidate.Constructor.DeclaringType.Name != "BurstCompileAttribute")
        {
            continue;
        }

        customAttribute = candidate;
        return true;
    }

    return false;
}
938 }
939}