diff --git a/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs b/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs index 0adf4b89c7..daac82b9e0 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerationHelpers.cs @@ -109,7 +109,7 @@ public static string GenerateAttributeInstantiation(AttributeData attr, Immutabl var arg = attr.ConstructorArguments[i]; // Check if this is a params array parameter - if (i == attr.ConstructorArguments.Length - 1 && IsParamsArrayArgument(attr, i)) + if (i == attr.ConstructorArguments.Length - 1 && IsParamsArrayArgument(attr)) { if (arg.Kind == TypedConstantKind.Array) { @@ -282,16 +282,11 @@ public static string GenerateAttributeInstantiation(AttributeData attr, Immutabl /// /// Determines if an argument is for a params array parameter. /// - private static bool IsParamsArrayArgument(AttributeData attr, int argumentIndex) + private static bool IsParamsArrayArgument(AttributeData attr) { var typeName = attr.AttributeClass!.GloballyQualified(); - if (typeName is "global::TUnit.Core.ArgumentsAttribute" or "global::TUnit.Core.InlineDataAttribute") - { - return true; - } - - return false; + return typeName is "global::TUnit.Core.ArgumentsAttribute" or "global::TUnit.Core.InlineDataAttribute"; } diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs index 3109b51e3f..daff8f2a26 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/InstanceFactoryGenerator.cs @@ -81,7 +81,7 @@ public static void GenerateInstanceFactory(CodeWriter writer, ITypeSymbol typeSy var constructor = GetPrimaryConstructor(typeSymbol); if (constructor != null) { - GenerateTypedConstructorCall(writer, className, constructor, testMethod); + GenerateTypedConstructorCall(writer, className, constructor); } else { @@ -145,7 +145,7 @@ public static void GenerateInstanceFactory(CodeWriter writer, ITypeSymbol typeSy return publicConstructors.Length == 1 ? publicConstructors[0] : publicConstructors.FirstOrDefault(); } - private static void GenerateTypedConstructorCall(CodeWriter writer, string className, IMethodSymbol constructor, TestMethodMetadata? testMethod) + private static void GenerateTypedConstructorCall(CodeWriter writer, string className, IMethodSymbol constructor) { writer.AppendLine("InstanceFactory = (typeArgs, args) =>"); writer.AppendLine("{"); @@ -164,17 +164,17 @@ private static void GenerateTypedConstructorCall(CodeWriter writer, string class // Generate constructor arguments var parameterTypes = constructor.Parameters.Select(p => p.Type).ToList(); - + for (var i = 0; i < parameterTypes.Count; i++) { if (i > 0) { writer.Append(", "); } - + var parameterType = parameterTypes[i]; var argAccess = $"args[{i}]"; - + // Use CastHelper which now has AOT converter registry support writer.Append($"global::TUnit.Core.Helpers.CastHelper.Cast<{parameterType.GloballyQualified()}>({argAccess})"); } @@ -214,11 +214,11 @@ private static void GenerateGenericInstanceFactory(CodeWriter writer, INamedType // Get the open generic type writer.AppendLine($"var openGenericType = typeof({genericType.OriginalDefinition.GloballyQualified()});"); writer.AppendLine(); - + // Create the closed generic type writer.AppendLine("var closedGenericType = global::TUnit.Core.Helpers.GenericTypeHelper.MakeGenericTypeSafe(openGenericType, typeArgs);"); writer.AppendLine(); - + // Check for constructor parameters var constructor = GetPrimaryConstructor(genericType); if (constructor is { Parameters.Length: > 0 }) @@ -230,7 +230,7 @@ private static void GenerateGenericInstanceFactory(CodeWriter writer, INamedType { writer.AppendLine("// Create instance with parameterless constructor"); writer.AppendLine("var instance = global::System.Activator.CreateInstance(closedGenericType);"); - + // Check for required properties var requiredProperties = RequiredPropertyHelper.GetAllRequiredProperties(genericType); if (requiredProperties.Any()) @@ -243,7 +243,7 @@ private static void GenerateGenericInstanceFactory(CodeWriter writer, INamedType writer.AppendLine($"closedGenericType.GetProperty(\"{property.Name}\")?.SetValue(instance, {defaultValue});"); } } - + writer.AppendLine("return instance!;"); } diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/StaticPropertyInitializationGenerator.cs b/TUnit.Core.SourceGenerator/CodeGenerators/StaticPropertyInitializationGenerator.cs index ca4808fda7..05f435af76 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/StaticPropertyInitializationGenerator.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/StaticPropertyInitializationGenerator.cs @@ -72,7 +72,7 @@ private static void GenerateStaticPropertyInitialization(SourceProductionContext // Use a dictionary to deduplicate static properties by their declaring type and name // This prevents duplicate initialization when derived classes inherit static properties var uniqueStaticProperties = new Dictionary<(INamedTypeSymbol DeclaringType, string Name), PropertyWithDataSource>(SymbolEqualityComparer.Default.ToTupleComparer()); - + foreach (var testClass in testClasses) { var properties = GetStaticPropertyDataSources(testClass); @@ -87,7 +87,7 @@ private static void GenerateStaticPropertyInitialization(SourceProductionContext } } } - + var allStaticProperties = uniqueStaticProperties.Values.ToImmutableArray(); if (allStaticProperties.IsEmpty) @@ -109,14 +109,14 @@ private static string GenerateInitializationCode(ImmutableArray"); writer.AppendLine("/// Auto-generated static property initializer"); writer.AppendLine("/// "); writer.AppendLine("internal static class StaticPropertyInitializer"); writer.AppendLine("{"); writer.Indent(); - + writer.AppendLine("/// "); writer.AppendLine("/// Module initializer that registers static property metadata"); writer.AppendLine("/// "); @@ -179,7 +179,7 @@ private static void GenerateIndividualPropertyInitializer(CodeWriter writer, Pro writer.AppendLine($"private static async global::System.Threading.Tasks.Task {methodName}()"); writer.AppendLine("{"); writer.Indent(); - + // Create PropertyMetadata with containing type information writer.AppendLine($"// Create PropertyMetadata for {propertyName}"); writer.AppendLine("var containingTypeMetadata = new global::TUnit.Core.ClassMetadata"); @@ -196,7 +196,7 @@ private static void GenerateIndividualPropertyInitializer(CodeWriter writer, Pro writer.Unindent(); writer.AppendLine("};"); writer.AppendLine(); - + writer.AppendLine("var propertyMetadata = new global::TUnit.Core.PropertyMetadata"); writer.AppendLine("{"); writer.Indent(); @@ -210,7 +210,7 @@ private static void GenerateIndividualPropertyInitializer(CodeWriter writer, Pro writer.Unindent(); writer.AppendLine("};"); writer.AppendLine(); - + var attr = propertyData.DataSourceAttribute; var attributeClassName = attr.AttributeClass?.Name; @@ -230,7 +230,7 @@ private static void GenerateIndividualPropertyInitializer(CodeWriter writer, Pro else if (attr.AttributeClass?.IsOrInherits("global::TUnit.Core.AsyncDataSourceGeneratorAttribute") == true || attr.AttributeClass?.IsOrInherits("global::TUnit.Core.AsyncUntypedDataSourceGeneratorAttribute") == true) { - GenerateAsyncDataSourceGeneratorWithPropertyWithAssignment(writer, attr, propertyData.Property.ContainingType); + GenerateAsyncDataSourceGeneratorWithPropertyWithAssignment(writer, attr); } else { @@ -306,7 +306,7 @@ private static void GenerateMethodDataSourceWithAssignment(CodeWriter writer, At } - private static void GenerateAsyncDataSourceGeneratorWithPropertyWithAssignment(CodeWriter writer, AttributeData attr, INamedTypeSymbol containingType) + private static void GenerateAsyncDataSourceGeneratorWithPropertyWithAssignment(CodeWriter writer, AttributeData attr) { var generatorCode = CodeGenerationHelpers.GenerateAttributeInstantiation(attr); writer.AppendLine($"var generator = {generatorCode};"); @@ -375,4 +375,4 @@ private static ImmutableArray GetStaticPropertyDataSourc return properties.ToImmutableArray(); } -} \ No newline at end of file +} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs index d3e0aa4c44..ab2323c2a8 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs @@ -11,7 +11,7 @@ public class AttributeWriter(Compilation compilation, TUnit.Core.SourceGenerator private readonly Dictionary _attributeObjectInitializerCache = new(); public void WriteAttributes(ICodeWriter sourceCodeWriter, - ImmutableArray attributeDatas) + IEnumerable attributeDatas) { var attributesToWrite = new List(); diff --git a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs index 27d59e63e7..e6675284c4 100644 --- a/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/HookMetadataGenerator.cs @@ -159,7 +159,7 @@ private static void GenerateHookFile(SourceProductionContext context, (HookModel var isOpenGeneric = typeSymbol.IsGenericType && typeSymbol.TypeArguments.Any(t => t.TypeKind == TypeKind.TypeParameter); // Pre-generate method info expression - var methodInfoExpression = GenerateMethodInfoExpression(context.SemanticModel.Compilation, typeSymbol, methodSymbol); + var methodInfoExpression = GenerateMethodInfoExpression(typeSymbol, methodSymbol); // Extract parameters using the existing static method var parameters = ParameterModel.ExtractAll(methodSymbol); @@ -200,7 +200,7 @@ private static void GenerateHookFile(SourceProductionContext context, (HookModel }; } - private static string GenerateMethodInfoExpression(Compilation compilation, INamedTypeSymbol typeSymbol, IMethodSymbol methodSymbol) + private static string GenerateMethodInfoExpression(INamedTypeSymbol typeSymbol, IMethodSymbol methodSymbol) { // Generate the MethodMetadata expression as a string - no header since this is inline using var writer = new CodeWriter(includeHeader: false); @@ -632,11 +632,11 @@ private static void GenerateHookDelegate(CodeWriter writer, HookModel hook) { if (hook.ClassIsOpenGeneric) { - GenerateReflectionBasedInvocation(writer, hook, true); + GenerateReflectionBasedInvocation(writer, hook); } else { - GenerateDirectInvocation(writer, hook, true); + GenerateDirectInvocation(writer, hook); } } } @@ -644,7 +644,7 @@ private static void GenerateHookDelegate(CodeWriter writer, HookModel hook) { using (writer.BeginBlock($"private static async ValueTask {delegateKey}_Body({contextType} context, CancellationToken cancellationToken)")) { - if (hook.ClassIsOpenGeneric && hook.IsStatic) + if (hook is { ClassIsOpenGeneric: true, IsStatic: true }) { GenerateOpenGenericStaticInvocation(writer, hook); } @@ -658,7 +658,7 @@ private static void GenerateHookDelegate(CodeWriter writer, HookModel hook) writer.AppendLine(); } - private static void GenerateReflectionBasedInvocation(CodeWriter writer, HookModel hook, bool isInstance) + private static void GenerateReflectionBasedInvocation(CodeWriter writer, HookModel hook) { writer.AppendLine("var instanceType = instance.GetType();"); writer.AppendLine($"var method = instanceType.GetMethod(\"{hook.MethodName}\", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance{(hook.IsStatic ? " | global::System.Reflection.BindingFlags.Static" : "")});"); @@ -700,7 +700,7 @@ private static void GenerateReflectionBasedInvocation(CodeWriter writer, HookMod writer.AppendLine("}"); } - private static void GenerateDirectInvocation(CodeWriter writer, HookModel hook, bool isInstance) + private static void GenerateDirectInvocation(CodeWriter writer, HookModel hook) { var className = hook.FullyQualifiedTypeName; writer.AppendLine($"var typedInstance = ({className})instance;"); diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index b835a60e99..f8ecf9d395 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -300,7 +300,7 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t writer.AppendLine("{"); writer.Indent(); - GenerateReflectionFieldAccessors(writer, testMethod.TypeSymbol, className); + GenerateReflectionFieldAccessors(writer, testMethod.TypeSymbol); writer.AppendLine("public async global::System.Collections.Generic.IAsyncEnumerable GetTestsAsync(string testSessionId, [global::System.Runtime.CompilerServices.EnumeratorCancellation] global::System.Threading.CancellationToken cancellationToken = default)"); writer.AppendLine("{"); @@ -346,12 +346,12 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t { // For generic classes with no way to resolve type arguments, this will generate // GenericTestMetadata that the engine will fail with a clear error message - GenerateTestMetadataInstance(writer, testMethod, className, uniqueClassName); + GenerateTestMetadataInstance(writer, testMethod, className); } } else { - GenerateTestMetadataInstance(writer, testMethod, className, uniqueClassName); + GenerateTestMetadataInstance(writer, testMethod, className); } writer.AppendLine("yield break;"); @@ -361,7 +361,7 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t writer.AppendLine(); // Generate EnumerateTestDescriptors method for fast filtering - GenerateEnumerateTestDescriptors(writer, testMethod, className); + GenerateEnumerateTestDescriptors(writer, testMethod); writer.Unindent(); writer.AppendLine("}"); @@ -369,7 +369,7 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t GenerateModuleInitializer(writer, testMethod, uniqueClassName); } - private static void GenerateTestMetadataInstance(CodeWriter writer, TestMethodMetadata testMethod, string className, string combinationGuid) + private static void GenerateTestMetadataInstance(CodeWriter writer, TestMethodMetadata testMethod, string className) { var methodName = testMethod.MethodSymbol.Name; @@ -436,8 +436,7 @@ private static void GenerateMetadata(CodeWriter writer, TestMethodMetadata testM var attributes = methodSymbol.GetAttributes() .Where(a => !DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)) .Concat(testMethod.TypeSymbol.GetAttributesIncludingBaseTypes()) - .Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes()) - .ToImmutableArray(); + .Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes()); testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes); @@ -482,8 +481,7 @@ private static void GenerateMetadataForConcreteInstantiation(CodeWriter writer, var attributes = methodSymbol.GetAttributes() .Where(a => !DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)) .Concat(testMethod.TypeSymbol.GetAttributesIncludingBaseTypes()) - .Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes()) - .ToImmutableArray(); + .Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes()); testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes); @@ -518,22 +516,21 @@ private static void GenerateMetadataForConcreteInstantiation(CodeWriter writer, private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata testMethod) { - var compilation = testMethod.Context!.Value.SemanticModel.Compilation; var methodSymbol = testMethod.MethodSymbol; var typeSymbol = testMethod.TypeSymbol; // Extract data source attributes from method var methodDataSources = methodSymbol.GetAttributes() .Where(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)) - .ToList(); + .ToArray(); // Extract data source attributes from class var classDataSources = typeSymbol.GetAttributesIncludingBaseTypes() .Where(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)) - .ToList(); + .ToArray(); // Generate method data sources - if (methodDataSources.Count == 0) + if (methodDataSources.Length == 0) { writer.AppendLine("DataSources = global::System.Array.Empty(),"); } @@ -553,7 +550,7 @@ private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata te } // Generate class data sources - if (classDataSources.Count == 0) + if (classDataSources.Length == 0) { writer.AppendLine("ClassDataSources = global::System.Array.Empty(),"); } @@ -588,7 +585,7 @@ private static void GenerateDataSourceAttribute(CodeWriter writer, CompilationCo if (attrName == "global::TUnit.Core.MethodDataSourceAttribute") { - GenerateMethodDataSourceAttribute(writer, attr, methodSymbol, typeSymbol); + GenerateMethodDataSourceAttribute(writer, attr, typeSymbol); } else if (attrName == "global::TUnit.Core.ArgumentsAttribute") { @@ -669,9 +666,9 @@ private static void GenerateArgumentsAttributeWithParameterTypes(CodeWriter writ writer.Append($"new {attrTypeName}("); // Only process positional arguments (exclude named arguments) - var positionalArgs = argumentList.Arguments.Where(a => a.NameEquals == null).ToList(); + var positionalArgs = argumentList.Arguments.Where(a => a.NameEquals == null).ToArray(); - for (var i = 0; i < positionalArgs.Count; i++) + for (var i = 0; i < positionalArgs.Length; i++) { var argumentSyntax = positionalArgs[i]; var expression = argumentSyntax.Expression; @@ -697,7 +694,7 @@ private static void GenerateArgumentsAttributeWithParameterTypes(CodeWriter writ writer.Append(fullyQualifiedExpression.ToFullString()); } - if (i < positionalArgs.Count - 1) + if (i < positionalArgs.Length - 1) { writer.Append(", "); } @@ -706,21 +703,21 @@ private static void GenerateArgumentsAttributeWithParameterTypes(CodeWriter writ writer.Append(")"); // Handle named arguments (like Skip property) - var namedArgs = argumentList.Arguments.Where(a => a.NameEquals != null).ToList(); - if (namedArgs.Count > 0) + var namedArgs = argumentList.Arguments.Where(a => a.NameEquals != null).ToArray(); + if (namedArgs.Length > 0) { writer.AppendLine(); writer.AppendLine("{"); writer.Indent(); - for (var i = 0; i < namedArgs.Count; i++) + for (var i = 0; i < namedArgs.Length; i++) { var namedArg = namedArgs[i]; var propertyName = namedArg.NameEquals!.Name.ToString(); var fullyQualifiedExpression = namedArg.Expression.Accept(new FullyQualifiedWithGlobalPrefixRewriter(semanticModel))!; writer.Append($"{propertyName} = {fullyQualifiedExpression.ToFullString()}"); - if (i < namedArgs.Count - 1) + if (i < namedArgs.Length - 1) { writer.AppendLine(","); } @@ -736,7 +733,7 @@ private static void GenerateArgumentsAttributeWithParameterTypes(CodeWriter writ } } - private static void GenerateMethodDataSourceAttribute(CodeWriter writer, AttributeData attr, IMethodSymbol methodSymbol, INamedTypeSymbol typeSymbol) + private static void GenerateMethodDataSourceAttribute(CodeWriter writer, AttributeData attr, INamedTypeSymbol typeSymbol) { // Extract method name and target type string? methodName = null; @@ -843,11 +840,11 @@ private static void GenerateMethodDataSourceAttribute(CodeWriter writer, Attribu // Generate the factory implementation if (dataSourceMethod != null) { - GenerateMethodDataSourceFactory(writer, dataSourceMethod, targetType, methodSymbol, attr, hasArguments); + GenerateMethodDataSourceFactory(writer, dataSourceMethod, targetType, hasArguments); } else if (dataSourceProperty != null) { - GeneratePropertyDataSourceFactory(writer, dataSourceProperty, targetType, methodSymbol, attr); + GeneratePropertyDataSourceFactory(writer, dataSourceProperty, targetType); } writer.Unindent(); @@ -857,7 +854,7 @@ private static void GenerateMethodDataSourceAttribute(CodeWriter writer, Attribu writer.AppendLine("},"); } - private static void GenerateMethodDataSourceFactory(CodeWriter writer, IMethodSymbol dataSourceMethod, ITypeSymbol targetType, IMethodSymbol testMethod, AttributeData attr, bool hasArguments) + private static void GenerateMethodDataSourceFactory(CodeWriter writer, IMethodSymbol dataSourceMethod, ITypeSymbol targetType, bool hasArguments) { var isStatic = dataSourceMethod.IsStatic; var returnType = dataSourceMethod.ReturnType; @@ -1041,7 +1038,7 @@ private static void GenerateMethodDataSourceFactory(CodeWriter writer, IMethodSy writer.AppendLine("return Factory();"); } - private static void GeneratePropertyDataSourceFactory(CodeWriter writer, IPropertySymbol dataSourceProperty, ITypeSymbol targetType, IMethodSymbol testMethod, AttributeData attr) + private static void GeneratePropertyDataSourceFactory(CodeWriter writer, IPropertySymbol dataSourceProperty, ITypeSymbol targetType) { var isStatic = dataSourceProperty.IsStatic; var returnType = dataSourceProperty.Type; @@ -1538,7 +1535,7 @@ private static void GenerateNestedPropertyInjections(CodeWriter writer, ITypeSym writer.AppendLine("NestedPropertyInjections = new global::TUnit.Core.PropertyInjectionData[]"); writer.AppendLine("{"); writer.Indent(); - GeneratePropertyInjectionsForType(writer, propertyType, processedProperties, isNested: true); + GeneratePropertyInjectionsForType(writer, propertyType, processedProperties); writer.Unindent(); writer.AppendLine("},"); } @@ -1595,7 +1592,7 @@ private static bool ShouldGenerateNestedInjections(ITypeSymbol type) .Any(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass))); } - private static void GeneratePropertyInjectionsForType(CodeWriter writer, ITypeSymbol typeSymbol, HashSet processedProperties, bool isNested) + private static void GeneratePropertyInjectionsForType(CodeWriter writer, ITypeSymbol typeSymbol, HashSet processedProperties) { var currentType = typeSymbol; var nestedProcessedProperties = new HashSet(processedProperties); @@ -1768,12 +1765,11 @@ private static void GenerateTypedInvokers(CodeWriter writer, TestMethodMetadata var returnPattern = GetReturnPattern(testMethod.MethodSymbol); if (testMethod is { IsGenericType: false, IsGenericMethod: false }) { - GenerateConcreteTestInvoker(writer, testMethod, className, methodName, returnPattern, hasCancellationToken, parametersFromArgs); + GenerateConcreteTestInvoker(writer, methodName, returnPattern, hasCancellationToken, parametersFromArgs); } } - - private static void GenerateConcreteTestInvoker(CodeWriter writer, TestMethodMetadata testMethod, string className, string methodName, TestReturnPattern returnPattern, bool hasCancellationToken, IParameterSymbol[] parametersFromArgs) + private static void GenerateConcreteTestInvoker(CodeWriter writer, string methodName, TestReturnPattern returnPattern, bool hasCancellationToken, IParameterSymbol[] parametersFromArgs) { // Generate InvokeTypedTest which is required by CreateExecutableTestFactory writer.AppendLine("InvokeTypedTest = static (instance, args, cancellationToken) =>"); @@ -1807,7 +1803,7 @@ private static void GenerateConcreteTestInvoker(CodeWriter writer, TestMethodMet // Build tuple reconstruction with proper casting var tupleElements = singleTupleParam.TupleElements.Select((elem, i) => - $"global::TUnit.Core.Helpers.CastHelper.Cast<{elem.Type.GloballyQualified()}>(args[{i}])").ToList(); + $"global::TUnit.Core.Helpers.CastHelper.Cast<{elem.Type.GloballyQualified()}>(args[{i}])"); var tupleConstruction = $"({string.Join(", ", tupleElements)})"; var methodCallReconstructed = hasCancellationToken @@ -1911,7 +1907,7 @@ private static void GenerateConcreteTestInvoker(CodeWriter writer, TestMethodMet writer.AppendLine("},"); } - private static void GenerateEnumerateTestDescriptors(CodeWriter writer, TestMethodMetadata testMethod, string className) + private static void GenerateEnumerateTestDescriptors(CodeWriter writer, TestMethodMetadata testMethod) { var methodName = testMethod.MethodSymbol.Name; var namespaceName = testMethod.TypeSymbol.ContainingNamespace?.ToDisplayString() ?? ""; @@ -1922,7 +1918,7 @@ private static void GenerateEnumerateTestDescriptors(CodeWriter writer, TestMeth // Extract categories from CategoryAttribute at compile time var categories = ExtractCategories(testMethod); - var categoriesArray = categories.Length == 0 + var categoriesArray = categories.Count == 0 ? "global::System.Array.Empty()" : $"new string[] {{ {string.Join(", ", categories.Select(c => $"\"{EscapeString(c)}\""))} }}"; @@ -1972,9 +1968,9 @@ private static void GenerateEnumerateTestDescriptors(CodeWriter writer, TestMeth writer.AppendLine("}"); } - private static string[] ExtractCategories(TestMethodMetadata testMethod) + private static HashSet ExtractCategories(TestMethodMetadata testMethod) { - var categories = new List(); + var categories = new HashSet(); // Check method attributes foreach (var attr in testMethod.MethodAttributes) @@ -2009,7 +2005,7 @@ private static string[] ExtractCategories(TestMethodMetadata testMethod) } } - return categories.Distinct().ToArray(); + return categories; } private static string[] ExtractProperties(TestMethodMetadata testMethod) @@ -2020,8 +2016,7 @@ private static string[] ExtractProperties(TestMethodMetadata testMethod) foreach (var attr in testMethod.MethodAttributes) { if (attr.AttributeClass?.Name == "PropertyAttribute" && - attr.ConstructorArguments.Length >= 2 && - attr.ConstructorArguments[0].Value is string key && + attr.ConstructorArguments is [{ Value: string key } _, _, ..] && attr.ConstructorArguments[1].Value is string value) { properties.Add($"{key}={value}"); @@ -2301,7 +2296,7 @@ private static void GenerateDependencies(CodeWriter writer, IMethodSymbol method .Concat(methodSymbol.ContainingType.GetAttributes()) .Where(attr => attr.AttributeClass?.Name == "DependsOnAttribute" && attr.AttributeClass.ContainingNamespace?.ToDisplayString() == "TUnit.Core") - .ToList(); + .ToArray(); if (!dependsOnAttributes.Any()) { @@ -2313,12 +2308,12 @@ private static void GenerateDependencies(CodeWriter writer, IMethodSymbol method writer.AppendLine("{"); writer.Indent(); - for (var i = 0; i < dependsOnAttributes.Count; i++) + for (var i = 0; i < dependsOnAttributes.Length; i++) { var attr = dependsOnAttributes[i]; GenerateTestDependency(writer, attr); - if (i < dependsOnAttributes.Count - 1) + if (i < dependsOnAttributes.Length - 1) { writer.AppendLine(","); } @@ -2549,12 +2544,12 @@ private static List CollectInheritedTestMethods(INamedTypeSymbol var allTestMethods = derivedClass.GetMembersIncludingBase() .OfType() .Where(m => m.GetAttributes().Any(attr => attr.IsTestAttribute())) - .ToList(); + .ToArray(); // Find methods declared directly on the derived class var derivedClassMethods = allTestMethods .Where(m => SymbolEqualityComparer.Default.Equals(m.ContainingType.OriginalDefinition, derivedClass.OriginalDefinition)) - .ToList(); + .ToArray(); // Filter out base methods that are hidden by derived class methods or declared directly on derived class var result = new List(); @@ -2706,7 +2701,7 @@ private static (string filePath, int lineNumber) GetTestMethodSourceLocation( return (derivedFilePath, derivedLineNumber); } - private static void GenerateReflectionFieldAccessors(CodeWriter writer, INamedTypeSymbol typeSymbol, string className) + private static void GenerateReflectionFieldAccessors(CodeWriter writer, INamedTypeSymbol typeSymbol) { // Find all init-only properties with data source attributes var initOnlyPropertiesWithDataSources = new List(); @@ -2734,7 +2729,7 @@ private static void GenerateReflectionFieldAccessors(CodeWriter writer, INamedTy // Skip for generic types since UnsafeAccessor doesn't work with open generic types var nonGenericProperties = initOnlyPropertiesWithDataSources .Where(p => !p.ContainingType.IsGenericType) - .ToList(); + .ToArray(); if (nonGenericProperties.Any()) { @@ -2898,7 +2893,7 @@ private static void GenerateGenericParameterConstraints(CodeWriter writer, IType } // Generate interface constraints - var interfaceConstraints = typeParam.ConstraintTypes.Where(c => c.TypeKind == TypeKind.Interface).ToArray(); + var interfaceConstraints = typeParam.ConstraintTypes.Where(c => c.TypeKind == TypeKind.Interface); writer.AppendLine("InterfaceConstraints = new global::System.Type[]"); writer.AppendLine("{"); writer.Indent(); @@ -2914,26 +2909,6 @@ private static void GenerateGenericParameterConstraints(CodeWriter writer, IType writer.AppendLine("},"); } - private static bool ContainsGenericTypeParameter(ITypeSymbol type) - { - if (type.TypeKind == TypeKind.TypeParameter) - { - return true; - } - - if (type is INamedTypeSymbol namedType) - { - return namedType.TypeArguments.Any(ContainsGenericTypeParameter); - } - - if (type is IArrayTypeSymbol arrayType) - { - return ContainsGenericTypeParameter(arrayType.ElementType); - } - - return false; - } - private static void GenerateGenericTestWithConcreteTypes( CodeWriter writer, TestMethodMetadata testMethod, @@ -3000,21 +2975,20 @@ private static void GenerateGenericTestWithConcreteTypes( var methodArgumentsAttributes = testMethod.MethodAttributes .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); + .ToArray(); - var classArgumentsAttributes = new List(); + var classArgumentsAttributes = Array.Empty(); // For generic classes, collect class-level Arguments attributes separately if (testMethod.IsGenericType) { classArgumentsAttributes = testMethod.TypeSymbol.GetAttributes() .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); + .ToArray(); } var processedTypeCombinations = new HashSet(); - // Handle the combination of class and method Arguments attributes if (testMethod is { IsGenericType: true, IsGenericMethod: true } && classArgumentsAttributes.Any() && methodArgumentsAttributes.Any()) { @@ -3131,7 +3105,7 @@ private static void GenerateGenericTestWithConcreteTypes( // Handle generic classes with non-generic methods that have method-level Arguments // These were skipped in the main loop and need special processing - if (testMethod is { IsGenericType: true, IsGenericMethod: false } && methodArgumentsAttributes.Count > 0) + if (testMethod is { IsGenericType: true, IsGenericMethod: false } && methodArgumentsAttributes.Length > 0) { foreach (var methodArgAttr in methodArgumentsAttributes) { @@ -3164,8 +3138,7 @@ private static void GenerateGenericTestWithConcreteTypes( // Process typed data source attributes var dataSourceAttributes = testMethod.MethodAttributes - .Where(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)) - .ToList(); + .Where(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)); foreach (var dataSourceAttr in dataSourceAttributes) { @@ -3232,8 +3205,7 @@ private static void GenerateGenericTestWithConcreteTypes( if (testMethod is { IsGenericType: true, IsGenericMethod: false }) { var methodDataSourceAttributes = testMethod.MethodAttributes - .Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute") - .ToList(); + .Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute"); foreach (var mdsAttr in methodDataSourceAttributes) { @@ -3286,8 +3258,7 @@ private static void GenerateGenericTestWithConcreteTypes( if (testMethod.IsGenericMethod) { var methodDataSourceAttributes = testMethod.MethodAttributes - .Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute") - .ToList(); + .Where(a => a.AttributeClass?.Name == "MethodDataSourceAttribute"); foreach (var mdsAttr in methodDataSourceAttributes) { @@ -3317,9 +3288,7 @@ private static void GenerateGenericTestWithConcreteTypes( if (testMethod.IsGenericType) { var argumentsAttributes = testMethod.TypeSymbol.GetAttributes() - .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); - + .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute"); foreach (var argAttr in argumentsAttributes) { @@ -3332,8 +3301,7 @@ private static void GenerateGenericTestWithConcreteTypes( { // Get method-level Arguments attributes to infer method type parameters var methodLevelArgumentsAttributes = testMethod.MethodAttributes - .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); + .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute"); foreach (var methodArgAttr in methodLevelArgumentsAttributes) { @@ -3387,17 +3355,17 @@ private static void GenerateGenericTestWithConcreteTypes( { var nonGenericClassArguments = testMethod.TypeSymbol.GetAttributes() .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); + .ToArray(); var nonGenericMethodArguments = testMethod.MethodAttributes .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); + .ToArray(); // Also get class-level data source generators for non-generic classes var nonGenericClassDataSourceGenerators = testMethod.TypeSymbol.GetAttributes() .Where(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass) && a.AttributeClass?.Name != "ArgumentsAttribute") - .ToList(); + .ToArray(); if (nonGenericClassArguments.Any() && nonGenericMethodArguments.Any()) { @@ -3450,14 +3418,12 @@ private static void GenerateGenericTestWithConcreteTypes( // Process GenerateGenericTest attributes from both class and method levels // When both class and method are generic, we need to generate cartesian products var methodGenerateGenericTestAttributes = testMethod.MethodAttributes - .Where(a => a.AttributeClass?.Name == "GenerateGenericTestAttribute") - .ToList(); + .Where(a => a.AttributeClass?.Name == "GenerateGenericTestAttribute"); var classLevelAttributes = testMethod.IsGenericType ? testMethod.TypeSymbol.GetAttributes() .Where(a => a.AttributeClass?.Name == "GenerateGenericTestAttribute") - .ToList() - : new List(); + : Array.Empty(); // Extract type arguments from class-level attributes var classTypeArgSets = ExtractTypeArgumentSets(classLevelAttributes); @@ -3536,7 +3502,7 @@ private static void GenerateGenericTestWithConcreteTypes( writer.AppendLine("yield return genericMetadata;"); } - private static List ExtractTypeArgumentSets(List attributes) + private static List ExtractTypeArgumentSets(IEnumerable attributes) { var result = new List(); @@ -3769,7 +3735,7 @@ private static bool TypeImplementsInterface(ITypeSymbol type, ITypeSymbol interf // Extract the concrete type and map it to the type parameter if (argValue.Type != null) { - MapGenericTypeArguments(methodParam.Type, argValue.Type, classSymbol, inferredTypes); + MapGenericTypeArguments(methodParam.Type, argValue.Type, inferredTypes); } } } @@ -3814,7 +3780,7 @@ private static bool ContainsClassTypeParameter(ITypeSymbol type, INamedTypeSymbo return false; } - private static void MapGenericTypeArguments(ITypeSymbol paramType, ITypeSymbol argType, INamedTypeSymbol classSymbol, Dictionary inferredTypes) + private static void MapGenericTypeArguments(ITypeSymbol paramType, ITypeSymbol argType, Dictionary inferredTypes) { if (paramType is ITypeParameterSymbol { DeclaringMethod: null } typeParam) { @@ -3830,64 +3796,27 @@ private static void MapGenericTypeArguments(ITypeSymbol paramType, ITypeSymbol a // Map type arguments recursively for (var i = 0; i < paramNamedType.TypeArguments.Length && i < argNamedType.TypeArguments.Length; i++) { - MapGenericTypeArguments(paramNamedType.TypeArguments[i], argNamedType.TypeArguments[i], classSymbol, inferredTypes); + MapGenericTypeArguments(paramNamedType.TypeArguments[i], argNamedType.TypeArguments[i], inferredTypes); } } } private static ITypeSymbol? InferTypeFromValue(object value, Compilation compilation) { - if (value is int) - { - return compilation?.GetSpecialType(SpecialType.System_Int32); - } - - if (value is string) - { - return compilation?.GetSpecialType(SpecialType.System_String); - } - - if (value is bool) - { - return compilation?.GetSpecialType(SpecialType.System_Boolean); - } - - if (value is double) - { - return compilation?.GetSpecialType(SpecialType.System_Double); - } - - if (value is float) - { - return compilation?.GetSpecialType(SpecialType.System_Single); - } - - if (value is long) - { - return compilation?.GetSpecialType(SpecialType.System_Int64); - } - - if (value is byte) - { - return compilation?.GetSpecialType(SpecialType.System_Byte); - } - - if (value is char) - { - return compilation?.GetSpecialType(SpecialType.System_Char); - } - - if (value is decimal) - { - return compilation?.GetSpecialType(SpecialType.System_Decimal); - } - - if (value is ITypeSymbol) - { - return compilation?.GetTypeByMetadataName("System.Type"); - } - - return null; + return value switch + { + int => compilation?.GetSpecialType(SpecialType.System_Int32), + string => compilation?.GetSpecialType(SpecialType.System_String), + bool => compilation?.GetSpecialType(SpecialType.System_Boolean), + double => compilation?.GetSpecialType(SpecialType.System_Double), + float => compilation?.GetSpecialType(SpecialType.System_Single), + long => compilation?.GetSpecialType(SpecialType.System_Int64), + byte => compilation?.GetSpecialType(SpecialType.System_Byte), + char => compilation?.GetSpecialType(SpecialType.System_Char), + decimal => compilation?.GetSpecialType(SpecialType.System_Decimal), + ITypeSymbol => compilation?.GetTypeByMetadataName("System.Type"), + _ => null + }; } private static ITypeSymbol[]? InferTypesFromClassArgumentsAttribute(INamedTypeSymbol classSymbol, AttributeData argAttr, Compilation compilation) @@ -3943,42 +3872,19 @@ private static void MapGenericTypeArguments(ITypeSymbol paramType, ITypeSymbol a // For literal values, infer type from the value var value = argValue.Value; - if (value is int) - { - argType = compilation?.GetSpecialType(SpecialType.System_Int32); - } - else if (value is string) - { - argType = compilation?.GetSpecialType(SpecialType.System_String); - } - else if (value is bool) - { - argType = compilation?.GetSpecialType(SpecialType.System_Boolean); - } - else if (value is double) - { - argType = compilation?.GetSpecialType(SpecialType.System_Double); - } - else if (value is float) - { - argType = compilation?.GetSpecialType(SpecialType.System_Single); - } - else if (value is long) - { - argType = compilation?.GetSpecialType(SpecialType.System_Int64); - } - else if (value is char) - { - argType = compilation?.GetSpecialType(SpecialType.System_Char); - } - else if (value is byte) - { - argType = compilation?.GetSpecialType(SpecialType.System_Byte); - } - else if (value is decimal) + argType = value switch { - argType = compilation?.GetSpecialType(SpecialType.System_Decimal); - } + int => compilation?.GetSpecialType(SpecialType.System_Int32), + string => compilation?.GetSpecialType(SpecialType.System_String), + bool => compilation?.GetSpecialType(SpecialType.System_Boolean), + double => compilation?.GetSpecialType(SpecialType.System_Double), + float => compilation?.GetSpecialType(SpecialType.System_Single), + long => compilation?.GetSpecialType(SpecialType.System_Int64), + char => compilation?.GetSpecialType(SpecialType.System_Char), + byte => compilation?.GetSpecialType(SpecialType.System_Byte), + decimal => compilation?.GetSpecialType(SpecialType.System_Decimal), + _ => argType + }; } if (argType != null) @@ -4056,42 +3962,19 @@ private static void MapGenericTypeArguments(ITypeSymbol paramType, ITypeSymbol a // For literal values, infer type from the value var value = argValue.Value; - if (value is int) - { - argType = compilation?.GetSpecialType(SpecialType.System_Int32); - } - else if (value is string) - { - argType = compilation?.GetSpecialType(SpecialType.System_String); - } - else if (value is bool) - { - argType = compilation?.GetSpecialType(SpecialType.System_Boolean); - } - else if (value is double) - { - argType = compilation?.GetSpecialType(SpecialType.System_Double); - } - else if (value is float) - { - argType = compilation?.GetSpecialType(SpecialType.System_Single); - } - else if (value is long) - { - argType = compilation?.GetSpecialType(SpecialType.System_Int64); - } - else if (value is char) + argType = value switch { - argType = compilation?.GetSpecialType(SpecialType.System_Char); - } - else if (value is byte) - { - argType = compilation?.GetSpecialType(SpecialType.System_Byte); - } - else if (value is decimal) - { - argType = compilation?.GetSpecialType(SpecialType.System_Decimal); - } + int => compilation?.GetSpecialType(SpecialType.System_Int32), + string => compilation?.GetSpecialType(SpecialType.System_String), + bool => compilation?.GetSpecialType(SpecialType.System_Boolean), + double => compilation?.GetSpecialType(SpecialType.System_Double), + float => compilation?.GetSpecialType(SpecialType.System_Single), + long => compilation?.GetSpecialType(SpecialType.System_Int64), + char => compilation?.GetSpecialType(SpecialType.System_Char), + byte => compilation?.GetSpecialType(SpecialType.System_Byte), + decimal => compilation?.GetSpecialType(SpecialType.System_Decimal), + _ => argType + }; } if (argType != null) @@ -4445,14 +4328,11 @@ private static bool ValidateTypeConstraints(IMethodSymbol method, ITypeSymbol[] return ValidateTypeParameterConstraints(methodTypeParams, typeArguments); } - private static bool ValidateTypeParameterConstraints(IEnumerable typeParams, ITypeSymbol[] typeArguments) + private static bool ValidateTypeParameterConstraints(ImmutableArray typeParams, ITypeSymbol[] typeArguments) { - var typeParamsList = typeParams.ToList(); - var typeParamsArray = typeParamsList.ToImmutableArray(); - - for (var i = 0; i < typeParamsList.Count; i++) + for (var i = 0; i < typeParams.Length; i++) { - var typeParam = typeParamsList[i]; + var typeParam = typeParams[i]; var typeArg = typeArguments[i]; // Check struct constraint @@ -4477,7 +4357,7 @@ private static bool ValidateTypeParameterConstraints(IEnumerable"); writer.AppendLine("["); writer.Indent(); - testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, filteredAttributes.ToImmutableArray()); + testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, filteredAttributes); writer.Unindent(); writer.AppendLine("],"); @@ -4799,7 +4679,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( // Filter data sources based on the specific attribute List methodDataSources; - List classDataSources; + AttributeData[] classDataSources; if (specificArgumentsAttribute != null) { @@ -4813,14 +4693,14 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( if (testMethod is { IsGenericType: true, IsGenericMethod: true }) { var additionalMethodDataSources = methodSymbol.GetAttributes() - .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute" && !AreSameAttribute(a, specificArgumentsAttribute)) - .ToList(); + .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute" && + !AreSameAttribute(a, specificArgumentsAttribute)); methodDataSources.AddRange(additionalMethodDataSources); } classDataSources = typeSymbol.GetAttributesIncludingBaseTypes() .Where(a => AreSameAttribute(a, specificArgumentsAttribute)) - .ToList(); + .ToArray(); } else { @@ -4831,7 +4711,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( classDataSources = typeSymbol.GetAttributesIncludingBaseTypes() .Where(a => DataSourceAttributeHelper.IsDataSourceAttribute(a.AttributeClass)) - .ToList(); + .ToArray(); } // Generate method data sources @@ -4855,7 +4735,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( } // Generate class data sources - if (classDataSources.Count == 0) + if (classDataSources.Length == 0) { writer.AppendLine("ClassDataSources = global::System.Array.Empty(),"); } @@ -5265,4 +5145,3 @@ public class InheritsTestsClassMetadata public GeneratorAttributeSyntaxContext Context { get; init; } public required CompilationContext CompilationContext { get; init; } } - diff --git a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs index e5332c94c2..fd2cc9e652 100644 --- a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs +++ b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs @@ -108,7 +108,7 @@ private static void WriteClassMetadataGetOrAdd(ICodeWriter writer, INamedTypeSym if (constructor != null && constructorParams.Length > 0) { writer.Append("Parameters = "); - WriteParameterMetadataArrayForConstructor(writer, constructor, typeSymbol); + WriteParameterMetadataArrayForConstructor(writer, constructor); writer.AppendLine(","); } else @@ -177,7 +177,7 @@ public static string GenerateClassMetadataGetOrAdd(INamedTypeSymbol typeSymbol, var constructorParams = constructor?.Parameters ?? ImmutableArray.Empty; if (constructor != null && constructorParams.Length > 0) { - writer.AppendLine($"Parameters = {GenerateParameterMetadataArrayForConstructor(constructor, typeSymbol, writer.IndentLevel)},"); + writer.AppendLine($"Parameters = {GenerateParameterMetadataArrayForConstructor(constructor, writer.IndentLevel)},"); } else { @@ -453,7 +453,7 @@ private static string GenerateParameterMetadataArrayForMethod(IMethodSymbol meth /// /// Writes an array of ParameterMetadata objects for constructor parameters with proper reflection info /// - private static void WriteParameterMetadataArrayForConstructor(ICodeWriter writer, IMethodSymbol constructor, INamedTypeSymbol containingType) + private static void WriteParameterMetadataArrayForConstructor(ICodeWriter writer, IMethodSymbol constructor) { if (constructor.Parameters.Length == 0) { @@ -488,7 +488,7 @@ private static void WriteParameterMetadataArrayForConstructor(ICodeWriter writer /// /// Generates an array of ParameterMetadata objects for constructor parameters with proper reflection info /// - private static string GenerateParameterMetadataArrayForConstructor(IMethodSymbol constructor, INamedTypeSymbol containingType, int currentIndentLevel = 0) + private static string GenerateParameterMetadataArrayForConstructor(IMethodSymbol constructor, int currentIndentLevel = 0) { if (constructor.Parameters.Length == 0) {