A game about forced loneliness, made by TACStudios
at master 644 lines 27 kB view raw
#if UNITY_EDITOR
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using Unity.Jobs.LowLevel.Unsafe;
using UnityEditor;
using UnityEditor.Compilation;
using Debug = UnityEngine.Debug;

[assembly: InternalsVisibleTo("Unity.Burst.Editor.Tests")]
namespace Unity.Burst.Editor
{
    using static BurstCompilerOptions;

    /// <summary>
    /// Reflection helpers used by the Burst editor tooling to discover what can be Burst-compiled:
    /// job <c>Execute</c> methods reached through <see cref="JobProducerTypeAttribute"/> producers,
    /// and static methods declared on types marked with <c>[BurstCompile]</c>.
    /// </summary>
    internal static class BurstReflection
    {
        /// <summary>
        /// Finds every Burst compile target: job <c>Execute</c> methods and <c>[BurstCompile]</c> static methods.
        /// </summary>
        /// <remarks>
        /// Non-generic job structs and <c>[BurstCompile]</c> static methods are discovered with the
        /// <see cref="TypeCache"/> API. Generic type instances are not visible to the TypeCache, so they are
        /// discovered separately by scanning the metadata of <paramref name="assemblyList"/>
        /// (see <see cref="FindExecuteMethodsForGenericInstances"/>).
        /// </remarks>
        /// <param name="assemblyList">Assemblies whose metadata is scanned for generic type instances.</param>
        /// <param name="options">
        /// When <see cref="BurstReflectionAssemblyOptions.ExcludeTestAssemblies"/> is set, targets whose job type lives in an
        /// assembly referencing nunit.framework are dropped, and [TestCompiler] methods are not collected.
        /// </param>
        /// <returns>The compile targets found, together with the warnings/exceptions recorded while scanning.</returns>
        public static FindExecuteMethodsResult FindExecuteMethods(List<System.Reflection.Assembly> assemblyList, BurstReflectionAssemblyOptions options)
        {
            var methodsToCompile = new List<BurstCompileTarget>();
            var methodsToCompileSet = new HashSet<MethodInfo>();
            var logMessages = new List<LogMessage>();
            var interfaceToProducer = new Dictionary<Type, Type>();

            var assemblySet = new HashSet<System.Reflection.Assembly>(assemblyList);

            // Remembers, per assembly, whether it references NUnit, so that GetReferencedAssemblies()
            // is not re-queried for every target coming from the same assembly.
            var isTestAssemblyCache = new Dictionary<System.Reflection.Assembly, bool>();

            bool IsTestAssembly(System.Reflection.Assembly assembly)
            {
                if (!isTestAssemblyCache.TryGetValue(assembly, out var isTest))
                {
                    isTest = assembly.GetReferencedAssemblies().Any(x => IsNUnitDll(x.Name));
                    isTestAssemblyCache.Add(assembly, isTest);
                }

                return isTest;
            }

            void AddTarget(BurstCompileTarget target)
            {
                // Methods named `...$BurstManaged` are never recorded as targets
                // (NOTE(review): presumably generated managed copies — confirm).
                if (target.Method.Name.EndsWith("$BurstManaged", StringComparison.Ordinal)) return;

                // We will not try to record more than once a method in the methods to compile
                // This can happen if a job interface is inheriting from another job interface which are using in the end the same
                // job producer type
                if (!target.IsStaticMethod && !methodsToCompileSet.Add(target.Method))
                {
                    return;
                }

                if (options.HasFlag(BurstReflectionAssemblyOptions.ExcludeTestAssemblies) &&
                    IsTestAssembly(target.JobType.Assembly))
                {
                    return;
                }

                methodsToCompile.Add(target);
            }

            // Types whose static methods must be inspected when a generic instance of them is found later.
            var staticMethodTypes = new HashSet<Type>();

            // -------------------------------------------
            // Find job structs using TypeCache.
            // -------------------------------------------

            var jobProducerImplementations = TypeCache.GetTypesWithAttribute<JobProducerTypeAttribute>();
            foreach (var jobProducerImplementation in jobProducerImplementations)
            {
                var attrs = jobProducerImplementation.GetCustomAttributes(typeof(JobProducerTypeAttribute), false);
                if (attrs.Length == 0)
                {
                    continue;
                }

                staticMethodTypes.Add(jobProducerImplementation);

                var attr = (JobProducerTypeAttribute)attrs[0];
                interfaceToProducer.Add(jobProducerImplementation, attr.ProducerType);
            }

            foreach (var jobProducerImplementation in jobProducerImplementations)
            {
                if (!jobProducerImplementation.IsInterface)
                {
                    continue;
                }

                var jobTypes = TypeCache.GetTypesDerivedFrom(jobProducerImplementation);

                foreach (var jobType in jobTypes)
                {
                    // Generic job structs are handled by FindExecuteMethodsForGenericInstances.
                    if (jobType.IsGenericType || !jobType.IsValueType)
                    {
                        continue;
                    }

                    ScanJobType(jobType, interfaceToProducer, logMessages, AddTarget);
                }
            }

            // -------------------------------------------
            // Find static methods using TypeCache.
            // -------------------------------------------

            void AddStaticMethods(TypeCache.MethodCollection methods)
            {
                foreach (var method in methods)
                {
                    if (HasBurstCompileAttribute(method.DeclaringType))
                    {
                        staticMethodTypes.Add(method.DeclaringType);

                        // Only plain types or closed generic instances can be compiled directly. Methods declared on a
                        // generic type definition are picked up later, from closed instances of that type, because the
                        // definition was recorded in staticMethodTypes above.
                        if (!method.DeclaringType.IsGenericTypeDefinition &&
                            method.IsStatic &&
                            !method.ContainsGenericParameters)
                        {
                            AddTarget(new BurstCompileTarget(method, method.DeclaringType, null, true));
                        }
                    }
                }
            }

            // Add [BurstCompile] static methods.
            AddStaticMethods(TypeCache.GetMethodsWithAttribute<BurstCompileAttribute>());

            // Add [TestCompiler] static methods.
            if (!options.HasFlag(BurstReflectionAssemblyOptions.ExcludeTestAssemblies))
            {
                var testCompilerAttributeType = Type.GetType("Burst.Compiler.IL.Tests.TestCompilerAttribute, Unity.Burst.Tests.UnitTests, Version=0.0.0.0, Culture=neutral, PublicKeyToken=null");
                if (testCompilerAttributeType != null)
                {
                    AddStaticMethods(TypeCache.GetMethodsWithAttribute(testCompilerAttributeType));
                }
            }

            // -------------------------------------------
            // Find job types and static methods based on
            // generic instances types. These will not be
            // found by the TypeCache scanning above.
            // -------------------------------------------
            FindExecuteMethodsForGenericInstances(
                assemblySet,
                staticMethodTypes,
                interfaceToProducer,
                AddTarget,
                logMessages);

            return new FindExecuteMethodsResult(methodsToCompile, logMessages);
        }

        /// <summary>
        /// For every interface of <paramref name="jobType"/> that maps to a job producer in
        /// <paramref name="interfaceToProducer"/>, closes the producer over the job type (followed by the
        /// interface's own generic arguments, if any) and reports its static <c>Execute</c> method as a target.
        /// </summary>
        /// <remarks>Any failure (generic instantiation, missing <c>Execute</c>, target registration) is recorded in <paramref name="logMessages"/>.</remarks>
        private static void ScanJobType(
            Type jobType,
            Dictionary<Type, Type> interfaceToProducer,
            List<LogMessage> logMessages,
            Action<BurstCompileTarget> addTarget)
        {
            foreach (var interfaceType in jobType.GetInterfaces())
            {
                // Producers are registered against the open interface definition.
                var genericLessInterface = interfaceType;
                if (interfaceType.IsGenericType)
                {
                    genericLessInterface = interfaceType.GetGenericTypeDefinition();
                }

                if (interfaceToProducer.TryGetValue(genericLessInterface, out var foundProducer))
                {
                    var genericParams = new List<Type> { jobType };
                    if (interfaceType.IsGenericType)
                    {
                        genericParams.AddRange(interfaceType.GenericTypeArguments);
                    }

                    try
                    {
                        var executeType = foundProducer.MakeGenericType(genericParams.ToArray());
                        var executeMethod = executeType.GetMethod("Execute", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
                        if (executeMethod == null)
                        {
                            throw new InvalidOperationException($"Burst reflection error. The type `{executeType}` does not contain an `Execute` method");
                        }

                        addTarget(new BurstCompileTarget(executeMethod, jobType, interfaceType, false));
                    }
                    catch (Exception ex)
                    {
                        logMessages.Add(new LogMessage(ex));
                    }
                }
            }
        }

        /// <summary>
        /// Finds job <c>Execute</c> methods and <c>[BurstCompile]</c> static methods on generic type instances used
        /// by the given assemblies. These instances are not reported by the TypeCache, so they are recovered by
        /// scanning assembly metadata (see <see cref="CollectGenericTypeInstances(System.Reflection.Assembly, Func{Type, bool}, List{Type}, HashSet{Type})"/>).
        /// </summary>
        /// <param name="assemblyList">Assemblies to scan; collected instances are also filtered to types declared in this set.</param>
        /// <param name="staticMethodTypes">Types (or generic definitions) whose static methods should be inspected.</param>
        /// <param name="interfaceToProducer">Job interface (open definition) to job producer type map.</param>
        /// <param name="addTarget">Callback receiving every compile target found.</param>
        /// <param name="logMessages">Receives warnings and exceptions; scanning continues after each failure.</param>
        private static void FindExecuteMethodsForGenericInstances(
            HashSet<System.Reflection.Assembly> assemblyList,
            HashSet<Type> staticMethodTypes,
            Dictionary<Type, Type> interfaceToProducer,
            Action<BurstCompileTarget> addTarget,
            List<LogMessage> logMessages)
        {
            var valueTypes = new List<TypeToVisit>();

            // Find all ways to execute job types (via producer attributes)
            var typesVisited = new HashSet<string>();
            var typesToVisit = new HashSet<string>();
            var allTypesAssembliesCollected = new HashSet<Type>();
            foreach (var assembly in assemblyList)
            {
                var types = new List<Type>();
                try
                {
                    // Collect all generic type instances (excluding indirect instances)
                    CollectGenericTypeInstances(
                        assembly,
                        x => assemblyList.Contains(x.Assembly),
                        types,
                        allTypesAssembliesCollected);
                }
                catch (Exception ex)
                {
                    logMessages.Add(new LogMessage(LogType.Warning, "Unexpected exception while collecting types in assembly `" + assembly.FullName + "` Exception: " + ex));
                }

                // NOTE: `types` grows while this loop runs (nested instances are appended), so the bound
                // is intentionally re-read on every iteration; typesToVisit prevents re-expanding a type.
                for (var i = 0; i < types.Count; i++)
                {
                    var t = types[i];
                    if (typesToVisit.Add(t.AssemblyQualifiedName))
                    {
                        // Because the list of types returned by CollectGenericTypeInstances does not detect nested generic classes that are not
                        // used explicitly, we need to create them if a declaring type is actually used
                        // so for example if we have:
                        // class MyClass<T> { class MyNestedClass { } }
                        // class MyDerived : MyClass<int> { }
                        // The CollectGenericTypeInstances will return typically the type MyClass<int>, but will not list MyClass<int>.MyNestedClass
                        // So the following code is correcting this in order to fully query the full graph of generic instance types, including indirect types
                        var nestedTypes = t.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic);
                        foreach (var nestedType in nestedTypes)
                        {
                            if (t.IsGenericType && !t.IsGenericTypeDefinition)
                            {
                                var parentGenericTypeArguments = t.GetGenericArguments();
                                // Only create nested types that are closed generic types (full generic instance types)
                                // It happens if for example the parent class is `class MClass<T> { class MyNestedGeneric<T1> {} }`
                                // In that case, MyNestedGeneric<T1> is opened in the context of MClass<int>, so we don't process them
                                if (nestedType.GetGenericArguments().Length == parentGenericTypeArguments.Length)
                                {
                                    try
                                    {
                                        var instanceNestedType = nestedType.MakeGenericType(parentGenericTypeArguments);
                                        types.Add(instanceNestedType);
                                    }
                                    catch (Exception ex)
                                    {
                                        var error = $"Unexpected Burst Inspector error. Invalid generic type instance. Trying to instantiate the generic type {nestedType.FullName} with the generic arguments <{string.Join(", ", parentGenericTypeArguments.Select(x => x.FullName))}> is not supported: {ex}";
                                        logMessages.Add(new LogMessage(LogType.Warning, error));
                                    }
                                }
                            }
                            else
                            {
                                types.Add(nestedType);
                            }
                        }
                    }
                }

                foreach (var t in types)
                {
                    // If the type has been already visited, don't try to visit it
                    if (!typesVisited.Add(t.AssemblyQualifiedName) || (t.IsGenericTypeDefinition && !t.IsInterface))
                    {
                        continue;
                    }

                    try
                    {
                        // collect methods with types having a [BurstCompile] attribute
                        var staticMethodDeclaringType = t;
                        if (t.IsGenericType)
                        {
                            staticMethodDeclaringType = t.GetGenericTypeDefinition();
                        }

                        bool visitStaticMethods = staticMethodTypes.Contains(staticMethodDeclaringType);

                        // Only plain value types or closed generic value type instances can be jobs;
                        // a value type generic definition (e.g `class Outer<T> { struct Inner { } }`) is never used.
                        bool isValueType = t.IsValueType && !t.IsGenericTypeDefinition;

                        if (isValueType || visitStaticMethods)
                        {
                            valueTypes.Add(new TypeToVisit(t, visitStaticMethods));
                        }
                    }
                    catch (Exception ex)
                    {
                        logMessages.Add(new LogMessage(LogType.Warning,
                            "Unexpected exception while inspecting type `" + t +
                            "` IsConstructedGenericType: " + t.IsConstructedGenericType +
                            " IsGenericTypeDef: " + t.IsGenericTypeDefinition +
                            " IsGenericParam: " + t.IsGenericParameter +
                            " Exception: " + ex));
                    }
                }
            }

            // Revisit all types to find things that are compilable using the above producers.
            foreach (var typePair in valueTypes)
            {
                var type = typePair.Type;

                // collect static [BurstCompile] methods
                if (typePair.CollectStaticMethods)
                {
                    try
                    {
                        var methods = type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic);
                        foreach (var method in methods)
                        {
                            if (HasBurstCompileAttribute(method))
                            {
                                addTarget(new BurstCompileTarget(method, type, null, true));
                            }
                        }
                    }
                    catch (Exception ex)
                    {
                        logMessages.Add(new LogMessage(ex));
                    }
                }

                // If the type is not a value type, we don't need to proceed with struct Jobs
                if (!type.IsValueType)
                {
                    continue;
                }

                ScanJobType(type, interfaceToProducer, logMessages, addTarget);
            }
        }

        /// <summary>Result of <see cref="FindExecuteMethods"/>: the compile targets plus any diagnostics collected.</summary>
        public sealed class FindExecuteMethodsResult
        {
            public readonly List<BurstCompileTarget> CompileTargets;
            public readonly List<LogMessage> LogMessages;

            public FindExecuteMethodsResult(List<BurstCompileTarget> compileTargets, List<LogMessage> logMessages)
            {
                CompileTargets = compileTargets;
                LogMessages = logMessages;
            }
        }

        /// <summary>
        /// A diagnostic recorded while scanning. Warnings carry <see cref="Message"/>;
        /// exceptions carry <see cref="Exception"/> and leave <see cref="Message"/> null.
        /// </summary>
        public sealed class LogMessage
        {
            public readonly LogType LogType;
            public readonly string Message;
            public readonly Exception Exception;

            public LogMessage(LogType logType, string message)
            {
                LogType = logType;
                Message = message;
            }

            public LogMessage(Exception exception)
            {
                LogType = LogType.Exception;
                Exception = exception;
            }
        }

        /// <summary>Kind of a <see cref="LogMessage"/>.</summary>
        public enum LogType
        {
            Warning,
            Exception,
        }

        /// <summary>
        /// This method exists solely to ensure that the static constructor has been called.
        /// </summary>
        public static void EnsureInitialized() { }

        public static readonly List<System.Reflection.Assembly> EditorAssembliesThatCanPossiblyContainJobs;
        public static readonly List<System.Reflection.Assembly> EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies;

        /// <summary>
        /// Populates <see cref="EditorAssembliesThatCanPossiblyContainJobs"/> and
        /// <see cref="EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies"/>: the editor assemblies known to
        /// <see cref="CompilationPipeline"/> (transitively through their references) that are loaded in the current
        /// AppDomain are walked, and those that can contain jobs are recorded.
        /// </summary>
        static BurstReflection()
        {
            EditorAssembliesThatCanPossiblyContainJobs = new List<System.Reflection.Assembly>();
            EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies = new List<System.Reflection.Assembly>();

            // TODO: Not sure there is a better way to match assemblies returned by CompilationPipeline.GetAssemblies
            // with runtime assemblies contained in the AppDomain.CurrentDomain.GetAssemblies()

            // Filter the assemblies
            var assemblyList = CompilationPipeline.GetAssemblies(AssembliesType.Editor);

            var assemblyNames = new HashSet<string>();
            foreach (var assembly in assemblyList)
            {
                CollectAssemblyNames(assembly, assemblyNames);
            }

            // Only used as a "visited" set by CollectAssembly.
            var allAssemblies = new HashSet<System.Reflection.Assembly>();
            foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
            {
                if (!assemblyNames.Contains(assembly.GetName().Name))
                {
                    continue;
                }
                CollectAssembly(assembly, allAssemblies);
            }
        }

        // For an assembly to contain something "interesting" when we're scanning for things to compile,
        // it needs to either:
        // (a) be one of these assemblies, or
        // (b) reference one of these assemblies
        private static readonly string[] ScanMarkerAssemblies = new[]
        {
            // Contains [BurstCompile] attribute
            "Unity.Burst",

            // Contains [JobProducerType] attribute
            "UnityEngine.CoreModule"
        };

        /// <summary>
        /// Records <paramref name="assembly"/> in the job-assembly lists when it is, or references, one of
        /// <see cref="ScanMarkerAssemblies"/>, and in that case recurses into its references.
        /// Assemblies already present in <paramref name="collect"/> are skipped.
        /// </summary>
        private static void CollectAssembly(System.Reflection.Assembly assembly, HashSet<System.Reflection.Assembly> collect)
        {
            if (!collect.Add(assembly))
            {
                return;
            }

            // Queried once and reused for the marker check, the NUnit check and the recursion.
            var referencedAssemblies = assembly.GetReferencedAssemblies();

            var name = assembly.GetName().Name;
            var canContainJobs = ScanMarkerAssemblies.Contains(name) ||
                                 referencedAssemblies.Any(x => ScanMarkerAssemblies.Contains(x.Name));
            if (!canContainJobs)
            {
                return;
            }

            EditorAssembliesThatCanPossiblyContainJobs.Add(assembly);

            if (!referencedAssemblies.Any(x => IsNUnitDll(x.Name)))
            {
                EditorAssembliesThatCanPossiblyContainJobsExcludingTestAssemblies.Add(assembly);
            }

            foreach (var assemblyName in referencedAssemblies)
            {
                try
                {
                    CollectAssembly(System.Reflection.Assembly.Load(assemblyName), collect);
                }
                catch (Exception)
                {
                    // Best effort: a reference that cannot be loaded is skipped (only reported while debugging).
                    if (BurstLoader.IsDebugging)
                    {
                        Debug.LogWarning("Could not load assembly " + assemblyName);
                    }
                }
            }
        }

        /// <summary>
        /// Returns true when the assembly name contains "nunit.framework" (ordinal, case-sensitive match).
        /// A null name is treated as "not NUnit".
        /// </summary>
        private static bool IsNUnitDll(string value)
        {
            return value != null && value.IndexOf("nunit.framework", StringComparison.Ordinal) >= 0;
        }

        /// <summary>Adds the name of <paramref name="assembly"/> and, recursively, of all its assembly references to <paramref name="collect"/>.</summary>
        private static void CollectAssemblyNames(UnityEditor.Compilation.Assembly assembly, HashSet<string> collect)
        {
            if (assembly == null || assembly.name == null) return;

            if (!collect.Add(assembly.name))
            {
                return;
            }

            foreach (var assemblyRef in assembly.assemblyReferences)
            {
                CollectAssemblyNames(assemblyRef, collect);
            }
        }

        /// <summary>
        /// Gets the concrete generic type instances used in an assembly.
        /// See remarks.
        /// </summary>
        /// <param name="assembly">The assembly whose metadata is scanned.</param>
        /// <param name="typeFilter">Only closed generic instances accepted by this predicate are added to <paramref name="types"/>.</param>
        /// <param name="types">Receives the generic type instances found.</param>
        /// <param name="visited">Types already inspected; shared across calls so each type is inspected once.</param>
        /// <remarks>
        /// Note that this method fetches only direct type instances but
        /// cannot fetch transitive generic type instances.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="assembly"/> is null.</exception>
        private static void CollectGenericTypeInstances(
            System.Reflection.Assembly assembly,
            Func<Type, bool> typeFilter,
            List<Type> types,
            HashSet<Type> visited)
        {
            // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            // WARNING: THIS CODE HAS TO BE MAINTAINED IN SYNC WITH BclApp.cs
            // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

            // From: https://gist.github.com/xoofx/710aaf86e0e8c81649d1261b1ef9590e
            if (assembly == null) throw new ArgumentNullException(nameof(assembly));

            // A metadata token is the table id in the top byte plus a 24-bit row index, so no table
            // can have more rows than this; each loop stops at the first out-of-range row instead.
            const int mdMaxCount = 1 << 24;
            foreach (var module in assembly.Modules)
            {
                for (int i = 1; i < mdMaxCount; i++)
                {
                    try
                    {
                        // Token base id for TypeSpec
                        const int mdTypeSpec = 0x1B000000;
                        var type = module.ResolveType(mdTypeSpec | i);
                        CollectGenericTypeInstances(type, types, visited, typeFilter);
                    }
                    catch (ArgumentOutOfRangeException)
                    {
                        break;
                    }
                    catch (ArgumentException)
                    {
                        // Can happen on ResolveType on certain generic types, so we continue
                    }
                }

                for (int i = 1; i < mdMaxCount; i++)
                {
                    try
                    {
                        // Token base id for MethodSpec
                        const int mdMethodSpec = 0x2B000000;
                        var method = module.ResolveMethod(mdMethodSpec | i);
                        var genericArgs = method.GetGenericArguments();
                        foreach (var genArgType in genericArgs)
                        {
                            CollectGenericTypeInstances(genArgType, types, visited, typeFilter);
                        }
                    }
                    catch (ArgumentOutOfRangeException)
                    {
                        break;
                    }
                    catch (ArgumentException)
                    {
                        // Can happen on ResolveMethod on certain generic methods, so we continue
                    }
                }

                for (int i = 1; i < mdMaxCount; i++)
                {
                    try
                    {
                        // Token base id for Field
                        const int mdField = 0x04000000;
                        var field = module.ResolveField(mdField | i);
                        CollectGenericTypeInstances(field.FieldType, types, visited, typeFilter);
                    }
                    catch (ArgumentOutOfRangeException)
                    {
                        break;
                    }
                    catch (ArgumentException)
                    {
                        // Can happen on ResolveField on fields of certain generic types, so we continue
                    }
                }
            }

            // Scan for types used in constructor arguments to assembly-level attributes,
            // such as [RegisterGenericJobType(typeof(...))].
            foreach (var customAttribute in assembly.CustomAttributes)
            {
                foreach (var argument in customAttribute.ConstructorArguments)
                {
                    // A Type-typed argument may legitimately be null (e.g. `[assembly: X(null)]`); skip it
                    // instead of letting the recursive collector throw and abort the remaining attributes.
                    // NOTE(review): mirror this null check in BclApp.cs (see WARNING above).
                    if (argument.ArgumentType == typeof(Type) && argument.Value is Type attributeType)
                    {
                        CollectGenericTypeInstances(attributeType, types, visited, typeFilter);
                    }
                }
            }
        }

        /// <summary>
        /// Adds <paramref name="type"/> to <paramref name="types"/> if it is a closed generic instance accepted by
        /// <paramref name="typeFilter"/>, then recurses into its generic type arguments. Primitive types and types
        /// already in <paramref name="visited"/> are skipped.
        /// </summary>
        private static void CollectGenericTypeInstances(
            Type type,
            List<Type> types,
            HashSet<Type> visited,
            Func<Type, bool> typeFilter)
        {
            if (type.IsPrimitive) return;
            if (!visited.Add(type)) return;

            // Add only concrete types
            if (type.IsConstructedGenericType && !type.ContainsGenericParameters && typeFilter(type))
            {
                types.Add(type);
            }

            // Collect recursively generic type arguments
            var genericTypeArguments = type.GenericTypeArguments;
            foreach (var genericTypeArgument in genericTypeArguments)
            {
                if (!genericTypeArgument.IsPrimitive)
                {
                    CollectGenericTypeInstances(genericTypeArgument, types, visited, typeFilter);
                }
            }
        }

        /// <summary>A type queued for the final pass, and whether its static [BurstCompile] methods should be collected.</summary>
        [DebuggerDisplay("{Type} (static methods: {CollectStaticMethods})")]
        private struct TypeToVisit
        {
            public TypeToVisit(Type type, bool collectStaticMethods)
            {
                Type = type;
                CollectStaticMethods = collectStaticMethods;
            }

            public readonly Type Type;

            public readonly bool CollectStaticMethods;
        }
    }

    /// <summary>Options for <see cref="BurstReflection.FindExecuteMethods"/>.</summary>
    [Flags]
    internal enum BurstReflectionAssemblyOptions
    {
        None = 0,
        ExcludeTestAssemblies = 1,
    }
}
#endif