From 91dcbf5b57f45419f1f12b6c0263fee54f389b95 Mon Sep 17 00:00:00 2001
From: Simon Cropp <simon.cropp@gmail.com>
Date: Sun, 6 Apr 2025 13:08:51 +1000
Subject: [PATCH 1/3] fix nullability of AssemblyEnumerator.GetTypes

---
 .../Discovery/AssemblyEnumerator.cs                 | 10 +++++-----
 .../MSTest.TestAdapter/Execution/TypeCache.cs       | 13 +++++++++----
 .../Discovery/AssemblyEnumeratorTests.cs            |  6 +++---
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs b/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs
index 7405c0eb7b..dba480a4b2 100644
--- a/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs
+++ b/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs
@@ -75,7 +75,7 @@ internal ICollection<UnitTestElement> EnumerateAssembly(
 
         Assembly assembly = PlatformServiceProvider.Instance.FileOperations.LoadAssembly(assemblyFileName, isReflectionOnly: false);
 
-        Type[] types = GetTypes(assembly, assemblyFileName, warnings);
+        Type?[] types = GetTypes(assembly, assemblyFileName, warnings);
         bool discoverInternals = ReflectHelper.GetDiscoverInternalsAttribute(assembly) != null;
         TestIdGenerationStrategy testIdGenerationStrategy = ReflectHelper.GetTestIdGenerationStrategy(assembly);
 
@@ -105,7 +105,7 @@ internal ICollection<UnitTestElement> EnumerateAssembly(
             },
         };
 
-        foreach (Type type in types)
+        foreach (Type? type in types)
         {
             if (type == null)
             {
@@ -127,7 +127,7 @@ internal ICollection<UnitTestElement> EnumerateAssembly(
     /// <param name="assemblyFileName">The file name of the assembly.</param>
     /// <param name="warningMessages">Contains warnings if any, that need to be passed back to the caller.</param>
     /// <returns>Gets the types defined in the provided assembly.</returns>
-    internal static Type[] GetTypes(Assembly assembly, string assemblyFileName, ICollection<string>? warningMessages)
+    internal static Type?[] GetTypes(Assembly assembly, string assemblyFileName, ICollection<string>? warningMessages)
     {
         try
         {
@@ -138,7 +138,7 @@ internal static Type[] GetTypes(Assembly assembly, string assemblyFileName, ICol
             PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning($"MSTestExecutor.TryGetTests: {Resource.TestAssembly_AssemblyDiscoveryFailure}", assemblyFileName, ex);
             PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.ExceptionsThrown);
 
-            if (ex.LoaderExceptions != null)
+            if (ex.LoaderExceptions.Any())
             {
                 // If not able to load all type, log a warning and continue with loaded types.
                 string message = string.Format(CultureInfo.CurrentCulture, Resource.TypeLoadFailed, assemblyFileName, GetLoadExceptionDetails(ex));
@@ -151,7 +151,7 @@ internal static Type[] GetTypes(Assembly assembly, string assemblyFileName, ICol
                 }
             }
 
-            return ex.Types!;
+            return ex.Types;
         }
     }
 
diff --git a/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs b/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs
index 55f1cc63c6..38036a20b0 100644
--- a/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs
+++ b/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs
@@ -384,14 +384,19 @@ private TestAssemblyInfo GetAssemblyInfo(Assembly assembly)
             {
                 var assemblyInfo = new TestAssemblyInfo(assembly);
 
-                Type[] types = AssemblyEnumerator.GetTypes(assembly, assembly.FullName!, null);
+                Type?[] types = AssemblyEnumerator.GetTypes(assembly, assembly.FullName!, null);
 
-                foreach (Type t in types)
+                foreach (Type? type in types)
                 {
+                    if (type == null)
+                    {
+                        continue;
+                    }
+
                     try
                     {
                         // Only examine classes which are TestClass or derives from TestClass attribute
-                        if (!@this._reflectionHelper.IsDerivedAttributeDefined<TestClassAttribute>(t, inherit: false))
+                        if (!@this._reflectionHelper.IsDerivedAttributeDefined<TestClassAttribute>(type, inherit: false))
                         {
                             continue;
                         }
@@ -401,7 +406,7 @@ private TestAssemblyInfo GetAssemblyInfo(Assembly assembly)
                         // If we fail to discover type from an assembly, then do not abort. Pick the next type.
                         PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(
                             "TypeCache: Exception occurred while checking whether type {0} is a test class or not. {1}",
-                            t.FullName,
+                            type.FullName,
                             ex);
 
                         continue;
diff --git a/test/UnitTests/MSTestAdapter.UnitTests/Discovery/AssemblyEnumeratorTests.cs b/test/UnitTests/MSTestAdapter.UnitTests/Discovery/AssemblyEnumeratorTests.cs
index 88d9038478..750177ccf3 100644
--- a/test/UnitTests/MSTestAdapter.UnitTests/Discovery/AssemblyEnumeratorTests.cs
+++ b/test/UnitTests/MSTestAdapter.UnitTests/Discovery/AssemblyEnumeratorTests.cs
@@ -97,7 +97,7 @@ public void GetTypesShouldReturnSetOfDefinedTypes()
         // Setup mocks
         mockAssembly.Setup(a => a.GetTypes()).Returns(expectedTypes);
 
-        IReadOnlyList<Type> types = AssemblyEnumerator.GetTypes(mockAssembly.Object, string.Empty, _warnings);
+        IReadOnlyList<Type?> types = AssemblyEnumerator.GetTypes(mockAssembly.Object, string.Empty, _warnings);
         Verify(expectedTypes.SequenceEqual(types));
     }
 
@@ -119,7 +119,7 @@ public void GetTypesShouldReturnReflectionTypeLoadExceptionTypesOnException()
         // Setup mocks
         mockAssembly.Setup(a => a.GetTypes()).Throws(new ReflectionTypeLoadException(reflectedTypes, null));
 
-        IReadOnlyList<Type> types = AssemblyEnumerator.GetTypes(mockAssembly.Object, string.Empty, _warnings);
+        IReadOnlyList<Type?> types = AssemblyEnumerator.GetTypes(mockAssembly.Object, string.Empty, _warnings);
 
         Verify(types is not null);
         Verify(reflectedTypes.Equals(types));
@@ -134,7 +134,7 @@ public void GetTypesShouldLogWarningsWhenReflectionFailsWithLoaderExceptions()
         mockAssembly.Setup(a => a.GetTypes()).Throws(new ReflectionTypeLoadException(null, exceptions));
         mockAssembly.Setup(a => a.GetTypes()).Throws(new ReflectionTypeLoadException(null, exceptions));
 
-        IReadOnlyList<Type> types = AssemblyEnumerator.GetTypes(mockAssembly.Object, "DummyAssembly", _warnings);
+        IReadOnlyList<Type?> types = AssemblyEnumerator.GetTypes(mockAssembly.Object, "DummyAssembly", _warnings);
 
         Verify(_warnings.Count == 1);
         Verify(_warnings.ToList().Contains(

From 2ebfc48dfbca9d32ab87fa0a7ae02237dd5f6eb4 Mon Sep 17 00:00:00 2001
From: Simon Cropp <simon.cropp@gmail.com>
Date: Thu, 10 Apr 2025 21:58:59 +1000
Subject: [PATCH 2/3] Update TypeCache.cs

---
 src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs b/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs
index 38036a20b0..c9b19a99c6 100644
--- a/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs
+++ b/src/Adapter/MSTest.TestAdapter/Execution/TypeCache.cs
@@ -413,7 +413,7 @@ private TestAssemblyInfo GetAssemblyInfo(Assembly assembly)
                     }
 
                     // Enumerate through all methods and identify the Assembly Init and cleanup methods.
-                    foreach (MethodInfo methodInfo in PlatformServiceProvider.Instance.ReflectionOperations.GetDeclaredMethods(t))
+                    foreach (MethodInfo methodInfo in PlatformServiceProvider.Instance.ReflectionOperations.GetDeclaredMethods(type))
                     {
                         if (@this.IsAssemblyOrClassInitializeMethod<AssemblyInitializeAttribute>(methodInfo))
                         {

From b6e363791afa03351ebaf05a8a30b4ef30cf7f6e Mon Sep 17 00:00:00 2001
From: Simon Cropp <simon.cropp@gmail.com>
Date: Thu, 10 Apr 2025 22:00:10 +1000
Subject: [PATCH 3/3] Update AssemblyEnumerator.cs

---
 src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs b/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs
index dba480a4b2..5edf0d4cb6 100644
--- a/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs
+++ b/src/Adapter/MSTest.TestAdapter/Discovery/AssemblyEnumerator.cs
@@ -151,7 +151,7 @@ internal ICollection<UnitTestElement> EnumerateAssembly(
                 }
             }
 
-            return ex.Types;
+            return ex.Types!;
         }
     }