A game about forced loneliness, made by TACStudios
1#if UNITY_EDITOR
2using System;
3using System.Collections.Generic;
4using System.Diagnostics;
5using System.Globalization;
6using System.Linq;
7using System.Reflection;
8using System.Runtime.CompilerServices;
9using Unity.Jobs.LowLevel.Unsafe;
10using UnityEditor;
11using UnityEditor.Compilation;
12using Debug = UnityEngine.Debug;
13
14[assembly: InternalsVisibleTo("Unity.Burst.Editor.Tests")]
15namespace Unity.Burst.Editor
16{
17 using static BurstCompilerOptions;
18
/// <summary>
/// Editor-side discovery of Burst compile targets: job structs whose interfaces carry a
/// <see cref="JobProducerTypeAttribute"/>, and static methods on types marked <c>[BurstCompile]</c>.
/// </summary>
internal static class BurstReflection
{
    /// <summary>
    /// Finds all Burst compile targets: producer <c>Execute</c> methods for job structs, and static
    /// <c>[BurstCompile]</c> (and, unless excluded, <c>[TestCompiler]</c>) methods.
    /// </summary>
    /// <param name="assemblyList">
    /// Assemblies scanned for closed generic type instances. The <see cref="TypeCache"/>-based discovery
    /// of non-generic jobs and static methods is not restricted to this list.
    /// </param>
    /// <param name="options">
    /// With <see cref="BurstReflectionAssemblyOptions.ExcludeTestAssemblies"/>, targets whose job type's
    /// assembly references NUnit are dropped and <c>[TestCompiler]</c> methods are not collected.
    /// </param>
    /// <returns>The collected compile targets together with any warnings/exceptions encountered.</returns>
    public static FindExecuteMethodsResult FindExecuteMethods(List<System.Reflection.Assembly> assemblyList, BurstReflectionAssemblyOptions options)
    {
        var methodsToCompile = new List<BurstCompileTarget>();
        var methodsToCompileSet = new HashSet<MethodInfo>();
        var logMessages = new List<LogMessage>();
        var interfaceToProducer = new Dictionary<Type, Type>();

        var assemblySet = new HashSet<System.Reflection.Assembly>(assemblyList);

        var excludeTestAssemblies = options.HasFlag(BurstReflectionAssemblyOptions.ExcludeTestAssemblies);

        // Many targets share a declaring assembly: remember the NUnit-reference check per assembly
        // instead of re-reading the assembly's reference table for every single target.
        var assemblyReferencesNUnit = new Dictionary<System.Reflection.Assembly, bool>();

        bool IsInTestAssembly(BurstCompileTarget target)
        {
            var assembly = target.JobType.Assembly;
            if (!assemblyReferencesNUnit.TryGetValue(assembly, out var referencesNUnit))
            {
                referencesNUnit = ReferencesNUnit(assembly.GetReferencedAssemblies());
                assemblyReferencesNUnit.Add(assembly, referencesNUnit);
            }
            return referencesNUnit;
        }

        void AddTarget(BurstCompileTarget target)
        {
            // Methods whose name ends in "$BurstManaged" are never recorded as targets.
            if (target.Method.Name.EndsWith("$BurstManaged", StringComparison.Ordinal)) return;

            // We will not try to record a method more than once in the methods to compile.
            // This can happen if a job interface inherits from another job interface and both
            // end up using the same job producer type.
            if (!target.IsStaticMethod && !methodsToCompileSet.Add(target.Method))
            {
                return;
            }

            if (excludeTestAssemblies && IsInTestAssembly(target))
            {
                return;
            }

            methodsToCompile.Add(target);
        }

        var staticMethodTypes = new HashSet<Type>();

        // -------------------------------------------
        // Find job structs using TypeCache.
        // -------------------------------------------

        var jobProducerImplementations = TypeCache.GetTypesWithAttribute<JobProducerTypeAttribute>();

        // First pass: map every [JobProducerType]-annotated type to its producer type. This must be
        // complete before scanning, because ScanJobType looks up each of a job's interfaces here.
        foreach (var jobProducerImplementation in jobProducerImplementations)
        {
            var attrs = jobProducerImplementation.GetCustomAttributes(typeof(JobProducerTypeAttribute), false);
            if (attrs.Length == 0)
            {
                continue;
            }

            staticMethodTypes.Add(jobProducerImplementation);

            var attr = (JobProducerTypeAttribute)attrs[0];
            interfaceToProducer.Add(jobProducerImplementation, attr.ProducerType);
        }

        // Second pass: scan the non-generic structs that implement each job interface.
        foreach (var jobProducerImplementation in jobProducerImplementations)
        {
            if (!jobProducerImplementation.IsInterface)
            {
                continue;
            }

            var jobTypes = TypeCache.GetTypesDerivedFrom(jobProducerImplementation);

            foreach (var jobType in jobTypes)
            {
                // Generic job structs are handled by the generic-instance scan below.
                if (jobType.IsGenericType || !jobType.IsValueType)
                {
                    continue;
                }

                ScanJobType(jobType, interfaceToProducer, logMessages, AddTarget);
            }
        }

        // -------------------------------------------
        // Find static methods using TypeCache.
        // -------------------------------------------

        void AddStaticMethods(TypeCache.MethodCollection methods)
        {
            foreach (var method in methods)
            {
                if (!HasBurstCompileAttribute(method.DeclaringType))
                {
                    continue;
                }

                staticMethodTypes.Add(method.DeclaringType);

                // NOTE: Make sure that we don't use a generic type definition (e.g `class Outer<T> { struct Inner { } }`).
                // We are only working on plain types or generic type instances!
                if (!method.DeclaringType.IsGenericTypeDefinition &&
                    method.IsStatic &&
                    !method.ContainsGenericParameters)
                {
                    AddTarget(new BurstCompileTarget(method, method.DeclaringType, null, true));
                }
            }
        }

        // Add [BurstCompile] static methods.
        AddStaticMethods(TypeCache.GetMethodsWithAttribute<BurstCompileAttribute>());

        // Add [TestCompiler] static methods (only present when the Burst unit-test assembly is loaded).
        if (!excludeTestAssemblies)
        {
            var testCompilerAttributeType = Type.GetType("Burst.Compiler.IL.Tests.TestCompilerAttribute, Unity.Burst.Tests.UnitTests, Version=0.0.0.0, Culture=neutral, PublicKeyToken=null");
            if (testCompilerAttributeType != null)
            {
                AddStaticMethods(TypeCache.GetMethodsWithAttribute(testCompilerAttributeType));
            }
        }

        // -------------------------------------------
        // Find job types and static methods based on
        // generic instance types. These will not be
        // found by the TypeCache scanning above.
        // -------------------------------------------
        FindExecuteMethodsForGenericInstances(
            assemblySet,
            staticMethodTypes,
            interfaceToProducer,
            AddTarget,
            logMessages);

        return new FindExecuteMethodsResult(methodsToCompile, logMessages);
    }

    /// <summary>
    /// For every interface of <paramref name="jobType"/> whose (generic-less) definition has a producer
    /// in <paramref name="interfaceToProducer"/>, closes the producer over <c>jobType</c> followed by the
    /// interface's generic arguments and registers the producer's static <c>Execute</c> method.
    /// </summary>
    /// <param name="jobType">The concrete job struct.</param>
    /// <param name="interfaceToProducer">Job interface (generic definition if generic) to producer type.</param>
    /// <param name="logMessages">Receives one entry per producer that could not be instantiated.</param>
    /// <param name="addTarget">Callback receiving each discovered compile target.</param>
    private static void ScanJobType(
        Type jobType,
        Dictionary<Type, Type> interfaceToProducer,
        List<LogMessage> logMessages,
        Action<BurstCompileTarget> addTarget)
    {
        foreach (var interfaceType in jobType.GetInterfaces())
        {
            var genericLessInterface = interfaceType.IsGenericType
                ? interfaceType.GetGenericTypeDefinition()
                : interfaceType;

            if (!interfaceToProducer.TryGetValue(genericLessInterface, out var foundProducer))
            {
                continue;
            }

            var genericParams = new List<Type> { jobType };
            if (interfaceType.IsGenericType)
            {
                genericParams.AddRange(interfaceType.GenericTypeArguments);
            }

            try
            {
                var executeType = foundProducer.MakeGenericType(genericParams.ToArray());
                var executeMethod = executeType.GetMethod("Execute", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
                if (executeMethod == null)
                {
                    throw new InvalidOperationException($"Burst reflection error. The type `{executeType}` does not contain an `Execute` method");
                }

                addTarget(new BurstCompileTarget(executeMethod, jobType, interfaceType, false));
            }
            catch (Exception ex)
            {
                // Best effort: record the failure and keep scanning the remaining interfaces.
                logMessages.Add(new LogMessage(ex));
            }
        }
    }

    /// <summary>
    /// Discovers compile targets that only exist as closed generic instances (generic job structs and
    /// static methods on generic types), which the <see cref="TypeCache"/> scan cannot see.
    /// </summary>
    /// <param name="assemblySet">Assemblies to scan; only instances whose type lives in one of them are kept.</param>
    /// <param name="staticMethodTypes">Type definitions whose static <c>[BurstCompile]</c> methods should be collected.</param>
    /// <param name="interfaceToProducer">Job interface to producer type map (see <see cref="ScanJobType"/>).</param>
    /// <param name="addTarget">Callback receiving each discovered compile target.</param>
    /// <param name="logMessages">Receives warnings/exceptions; scanning continues after a failure.</param>
    private static void FindExecuteMethodsForGenericInstances(
        HashSet<System.Reflection.Assembly> assemblySet,
        HashSet<Type> staticMethodTypes,
        Dictionary<Type, Type> interfaceToProducer,
        Action<BurstCompileTarget> addTarget,
        List<LogMessage> logMessages)
    {
        var valueTypes = new List<TypeToVisit>();

        var typesVisited = new HashSet<string>();
        var typesToVisit = new HashSet<string>();
        var allTypesAssembliesCollected = new HashSet<Type>();
        foreach (var assembly in assemblySet)
        {
            var types = new List<Type>();
            try
            {
                // Collect all generic type instances (excluding indirect instances)
                CollectGenericTypeInstances(
                    assembly,
                    x => assemblySet.Contains(x.Assembly),
                    types,
                    allTypesAssembliesCollected);
            }
            catch (Exception ex)
            {
                logMessages.Add(new LogMessage(LogType.Warning, "Unexpected exception while collecting types in assembly `" + assembly.FullName + "` Exception: " + ex));
            }

            // `types` grows while this loop runs (nested instances are appended), so iterate by index
            // to also expand the nested types of newly appended instances.
            for (var i = 0; i < types.Count; i++)
            {
                var t = types[i];
                if (typesToVisit.Add(t.AssemblyQualifiedName))
                {
                    AddNestedTypeInstances(t, types, logMessages);
                }
            }

            foreach (var t in types)
            {
                // Skip types already visited (possibly from another assembly) and open generic
                // definitions other than interfaces. Add() must run first so every type is marked.
                if (!typesVisited.Add(t.AssemblyQualifiedName) || (t.IsGenericTypeDefinition && !t.IsInterface))
                {
                    continue;
                }

                try
                {
                    // Static [BurstCompile] methods are keyed by the generic type definition.
                    var staticMethodDeclaringType = t.IsGenericType ? t.GetGenericTypeDefinition() : t;
                    var visitStaticMethods = staticMethodTypes.Contains(staticMethodDeclaringType);

                    // Only plain structs or closed generic struct instances can be scanned as jobs.
                    var isValueType = t.IsValueType && !t.IsGenericTypeDefinition;

                    if (isValueType || visitStaticMethods)
                    {
                        valueTypes.Add(new TypeToVisit(t, visitStaticMethods));
                    }
                }
                catch (Exception ex)
                {
                    logMessages.Add(new LogMessage(LogType.Warning,
                        "Unexpected exception while inspecting type `" + t +
                        "` IsConstructedGenericType: " + t.IsConstructedGenericType +
                        " IsGenericTypeDef: " + t.IsGenericTypeDefinition +
                        " IsGenericParam: " + t.IsGenericParameter +
                        " Exception: " + ex));
                }
            }
        }

        // Revisit all types to find things that are compilable using the above producers.
        foreach (var typePair in valueTypes)
        {
            var type = typePair.Type;

            // collect static [BurstCompile] methods
            if (typePair.CollectStaticMethods)
            {
                try
                {
                    var methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
                    foreach (var method in methods)
                    {
                        if (HasBurstCompileAttribute(method))
                        {
                            addTarget(new BurstCompileTarget(method, type, null, true));
                        }
                    }
                }
                catch (Exception ex)
                {
                    logMessages.Add(new LogMessage(ex));
                }
            }

            // If the type is not a value type, we don't need to proceed with struct Jobs
            if (!type.IsValueType)
            {
                continue;
            }

            ScanJobType(type, interfaceToProducer, logMessages, addTarget);
        }
    }

    /// <summary>
    /// Appends to <paramref name="types"/> the nested types of <paramref name="type"/>, instantiated over
    /// the parent's generic arguments when <paramref name="type"/> is a closed generic type.
    /// </summary>
    /// <remarks>
    /// CollectGenericTypeInstances does not report nested types of generic classes that are not used
    /// explicitly. For example, given
    /// <c>class MyClass&lt;T&gt; { class MyNestedClass { } }</c> and <c>class MyDerived : MyClass&lt;int&gt; { }</c>,
    /// only <c>MyClass&lt;int&gt;</c> is found, not <c>MyClass&lt;int&gt;.MyNestedClass</c>. This fills that gap.
    /// Nested types that declare extra generic parameters of their own
    /// (e.g. <c>class MClass&lt;T&gt; { class MyNestedGeneric&lt;T1&gt; { } }</c>) stay open in the context of the
    /// parent instance and are skipped.
    /// </remarks>
    private static void AddNestedTypeInstances(Type type, List<Type> types, List<LogMessage> logMessages)
    {
        var nestedTypes = type.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic);
        if (nestedTypes.Length == 0)
        {
            return;
        }

        if (!type.IsGenericType || type.IsGenericTypeDefinition)
        {
            types.AddRange(nestedTypes);
            return;
        }

        var parentGenericTypeArguments = type.GetGenericArguments();
        foreach (var nestedType in nestedTypes)
        {
            // Only create nested types that become closed (full generic instance types).
            if (nestedType.GetGenericArguments().Length != parentGenericTypeArguments.Length)
            {
                continue;
            }

            try
            {
                types.Add(nestedType.MakeGenericType(parentGenericTypeArguments));
            }
            catch (Exception ex)
            {
                var error = $"Unexpected Burst Inspector error. Invalid generic type instance. Trying to instantiate the generic type {nestedType.FullName} with the generic arguments <{string.Join(", ", parentGenericTypeArguments.Select(x => x.FullName))}> is not supported: {ex}";
                logMessages.Add(new LogMessage(LogType.Warning, error));
            }
        }
    }

    /// <summary>
    /// Result of <see cref="FindExecuteMethods"/>: the compile targets plus diagnostics.
    /// </summary>
    public sealed class FindExecuteMethodsResult
    {
        public readonly List<BurstCompileTarget> CompileTargets;
        public readonly List<LogMessage> LogMessages;

        public FindExecuteMethodsResult(List<BurstCompileTarget> compileTargets, List<LogMessage> logMessages)
        {
            CompileTargets = compileTargets;
            LogMessages = logMessages;
        }
    }

    /// <summary>
    /// A diagnostic produced while scanning: either a warning with <see cref="Message"/> set, or an
    /// exception with <see cref="Exception"/> set (and <see cref="Message"/> left null).
    /// </summary>
    public sealed class LogMessage
    {
        public readonly LogType LogType;
        public readonly string Message;
        public readonly Exception Exception;

        public LogMessage(LogType logType, string message)
        {
            LogType = logType;
            Message = message;
        }

        public LogMessage(Exception exception)
        {
            LogType = LogType.Exception;
            Exception = exception;
        }
    }

    /// <summary>Severity of a <see cref="LogMessage"/>.</summary>
    public enum LogType
    {
        Warning,
        Exception,
    }

    /// <summary>
    /// This method exists solely to ensure that the static constructor has been called.
    /// </summary>
    public static void EnsureInitialized() { }

    /// <summary>Loaded editor assemblies that are, or reference, one of <see cref="ScanMarkerAssemblies"/>.</summary>
    public static readonly List<System.Reflection.Assembly> EditorAssembliesThatCanPossiblyContainJobs;

    /// <summary>Subset of <see cref="EditorAssembliesThatCanPossiblyContainJobs"/> that do not reference NUnit.</summary>
    public static readonly List<System.Reflection.Assembly> EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies;

    /// <summary>
    /// Collects (and caches) all editor assemblies - transitively. Assembly names come from
    /// <see cref="CompilationPipeline"/> (following assembly references), and the loaded AppDomain
    /// assemblies with those names are walked with <see cref="CollectAssembly"/>.
    /// </summary>
    static BurstReflection()
    {
        EditorAssembliesThatCanPossiblyContainJobs = new List<System.Reflection.Assembly>();
        EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies = new List<System.Reflection.Assembly>();

        // TODO: Not sure there is a better way to match assemblies returned by CompilationPipeline.GetAssemblies
        // with runtime assemblies contained in the AppDomain.CurrentDomain.GetAssemblies()

        // Filter the assemblies
        var assemblyList = CompilationPipeline.GetAssemblies(AssembliesType.Editor);

        var assemblyNames = new HashSet<string>();
        foreach (var assembly in assemblyList)
        {
            CollectAssemblyNames(assembly, assemblyNames);
        }

        var visitedAssemblies = new HashSet<System.Reflection.Assembly>();
        foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
        {
            if (!assemblyNames.Contains(assembly.GetName().Name))
            {
                continue;
            }
            CollectAssembly(assembly, visitedAssemblies);
        }
    }

    // For an assembly to contain something "interesting" when we're scanning for things to compile,
    // it needs to either:
    // (a) be one of these assemblies, or
    // (b) reference one of these assemblies
    private static readonly string[] ScanMarkerAssemblies = new[]
    {
        // Contains [BurstCompile] attribute
        "Unity.Burst",

        // Contains [JobProducerType] attribute
        "UnityEngine.CoreModule"
    };

    /// <summary>
    /// Records <paramref name="assembly"/> in the job-assembly lists when it is, or references, a marker
    /// assembly, and then recurses into its references. Assemblies that are not relevant are not recursed into.
    /// </summary>
    /// <param name="assembly">Assembly to inspect.</param>
    /// <param name="collect">Visited set; an assembly already in it is ignored.</param>
    private static void CollectAssembly(System.Reflection.Assembly assembly, HashSet<System.Reflection.Assembly> collect)
    {
        if (!collect.Add(assembly))
        {
            return;
        }

        var referencedAssemblies = assembly.GetReferencedAssemblies();

        var name = assembly.GetName().Name;
        var canContainJobs = ScanMarkerAssemblies.Contains(name) ||
                             referencedAssemblies.Any(x => ScanMarkerAssemblies.Contains(x.Name));
        if (!canContainJobs)
        {
            return;
        }

        EditorAssembliesThatCanPossiblyContainJobs.Add(assembly);

        if (!ReferencesNUnit(referencedAssemblies))
        {
            EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies.Add(assembly);
        }

        foreach (var assemblyName in referencedAssemblies)
        {
            try
            {
                CollectAssembly(System.Reflection.Assembly.Load(assemblyName), collect);
            }
            catch (Exception)
            {
                // Best effort: a reference that cannot be loaded is skipped.
                if (BurstLoader.IsDebugging)
                {
                    Debug.LogWarning("Could not load assembly " + assemblyName);
                }
            }
        }
    }

    /// <summary>Returns true when any of <paramref name="references"/> names the NUnit framework assembly.</summary>
    private static bool ReferencesNUnit(IEnumerable<AssemblyName> references)
    {
        return references.Any(x => IsNUnitDll(x.Name));
    }

    /// <summary>
    /// Returns true when <paramref name="value"/> contains "nunit.framework". Assembly simple names are
    /// not case-sensitive, so the match ignores case; a null name never matches.
    /// </summary>
    private static bool IsNUnitDll(string value)
    {
        return value != null && value.IndexOf("nunit.framework", StringComparison.OrdinalIgnoreCase) >= 0;
    }

    /// <summary>
    /// Adds the name of <paramref name="assembly"/> and, recursively, of all its assembly references to
    /// <paramref name="collect"/>. Null assemblies or names are ignored.
    /// </summary>
    private static void CollectAssemblyNames(UnityEditor.Compilation.Assembly assembly, HashSet<string> collect)
    {
        if (assembly == null || assembly.name == null) return;

        if (!collect.Add(assembly.name))
        {
            return;
        }

        foreach (var assemblyRef in assembly.assemblyReferences)
        {
            CollectAssemblyNames(assemblyRef, collect);
        }
    }

    /// <summary>
    /// Collects the concrete (closed) generic type instances used in an assembly.
    /// See remarks.
    /// </summary>
    /// <param name="assembly">The assembly to scan.</param>
    /// <param name="typeFilter">Only instances accepted by this predicate are appended to <paramref name="types"/>.</param>
    /// <param name="types">Receives the generic type instances found.</param>
    /// <param name="visited">Types already inspected; shared across calls to avoid re-walking them.</param>
    /// <exception cref="ArgumentNullException"><paramref name="assembly"/> is null.</exception>
    /// <remarks>
    /// Walks the TypeSpec, MethodSpec (generic method arguments) and Field metadata tokens of every module,
    /// plus <see cref="Type"/>-valued constructor arguments of assembly-level attributes.
    /// Note that this method fetches only direct type instances; it
    /// cannot fetch transitive generic type instances.
    /// </remarks>
    private static void CollectGenericTypeInstances(
        System.Reflection.Assembly assembly,
        Func<Type, bool> typeFilter,
        List<Type> types,
        HashSet<Type> visited)
    {
        // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        // WARNING: THIS CODE HAS TO BE MAINTAINED IN SYNC WITH BclApp.cs
        // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

        // From: https://gist.github.com/xoofx/710aaf86e0e8c81649d1261b1ef9590e
        if (assembly == null) throw new ArgumentNullException(nameof(assembly));

        // Metadata token layout: table id in the top byte, 1-based row index in the lower 24 bits.
        const int mdMaxCount = 1 << 24;
        const int mdTypeSpec = 0x1B000000;   // Token base id for TypeSpec
        const int mdMethodSpec = 0x2B000000; // Token base id for MethodSpec
        const int mdField = 0x04000000;      // Token base id for Field

        foreach (var module in assembly.Modules)
        {
            // Each loop enumerates rows until the token falls outside the table (ArgumentOutOfRangeException).
            for (int i = 1; i < mdMaxCount; i++)
            {
                try
                {
                    var type = module.ResolveType(mdTypeSpec | i);
                    CollectGenericTypeInstances(type, types, visited, typeFilter);
                }
                catch (ArgumentOutOfRangeException)
                {
                    break;
                }
                catch (ArgumentException)
                {
                    // Can happen on ResolveType on certain generic types, so we continue
                }
            }

            for (int i = 1; i < mdMaxCount; i++)
            {
                try
                {
                    var method = module.ResolveMethod(mdMethodSpec | i);
                    foreach (var genArgType in method.GetGenericArguments())
                    {
                        CollectGenericTypeInstances(genArgType, types, visited, typeFilter);
                    }
                }
                catch (ArgumentOutOfRangeException)
                {
                    break;
                }
                catch (ArgumentException)
                {
                    // Can happen on ResolveMethod on certain generic methods, so we continue
                }
            }

            for (int i = 1; i < mdMaxCount; i++)
            {
                try
                {
                    var field = module.ResolveField(mdField | i);
                    CollectGenericTypeInstances(field.FieldType, types, visited, typeFilter);
                }
                catch (ArgumentOutOfRangeException)
                {
                    break;
                }
                catch (ArgumentException)
                {
                    // Can happen on ResolveField on fields of certain generic types, so we continue
                }
            }
        }

        // Scan for types used in constructor arguments to assembly-level attributes,
        // such as [RegisterGenericJobType(typeof(...))]. A Type-typed argument may be null.
        foreach (var customAttribute in assembly.CustomAttributes)
        {
            foreach (var argument in customAttribute.ConstructorArguments)
            {
                if (argument.ArgumentType == typeof(Type) && argument.Value is Type argumentType)
                {
                    CollectGenericTypeInstances(argumentType, types, visited, typeFilter);
                }
            }
        }
    }

    /// <summary>
    /// Appends <paramref name="type"/> to <paramref name="types"/> when it is a closed constructed generic
    /// type accepted by <paramref name="typeFilter"/>, then recurses into its generic type arguments.
    /// Null and primitive types are ignored; each other type is inspected at most once per <paramref name="visited"/>.
    /// </summary>
    private static void CollectGenericTypeInstances(
        Type type,
        List<Type> types,
        HashSet<Type> visited,
        Func<Type, bool> typeFilter)
    {
        if (type == null || type.IsPrimitive) return;
        if (!visited.Add(type)) return;

        // Add only concrete types
        if (type.IsConstructedGenericType && !type.ContainsGenericParameters && typeFilter(type))
        {
            types.Add(type);
        }

        // Collect recursively generic type arguments (primitives are rejected by the guard above)
        foreach (var genericTypeArgument in type.GenericTypeArguments)
        {
            CollectGenericTypeInstances(genericTypeArgument, types, visited, typeFilter);
        }
    }

    /// <summary>A type to revisit, and whether its static [BurstCompile] methods should be collected.</summary>
    [DebuggerDisplay("{Type} (static methods: {CollectStaticMethods})")]
    private struct TypeToVisit
    {
        public TypeToVisit(Type type, bool collectStaticMethods)
        {
            Type = type;
            CollectStaticMethods = collectStaticMethods;
        }

        public readonly Type Type;

        public readonly bool CollectStaticMethods;
    }
}
636
/// <summary>
/// Bit flags that control which targets BurstReflection.FindExecuteMethods reports.
/// </summary>
[Flags]
internal enum BurstReflectionAssemblyOptions
{
    /// <summary>No filtering.</summary>
    None = 0,

    /// <summary>
    /// Drop targets whose job type's assembly references NUnit, and skip [TestCompiler] methods.
    /// </summary>
    ExcludeTestAssemblies = 1 << 0,
}
643}
644#endif