From e1f48e368de177cab977799d03da105e68dc8b8a Mon Sep 17 00:00:00 2001 From: Timothy Makkison Date: Thu, 29 Jan 2026 15:56:36 +0000 Subject: [PATCH] perf: cache `GetAttributeObjectInitializer`, make `AttributeWriter` stateful --- .../CodeGenerators/Writers/AttributeWriter.cs | 29 +++++++--- .../Generators/TestMetadataGenerator.cs | 56 ++++++++++++------- .../Models/TestMethodMetadata.cs | 7 ++- 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs index 99e73e8e65..052642e926 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/AttributeWriter.cs @@ -6,9 +6,11 @@ namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers; -public class AttributeWriter +public class AttributeWriter(Compilation compilation) { - public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation compilation, + private readonly Dictionary _attributeObjectInitializerCache = new(); + + public void WriteAttributes(ICodeWriter sourceCodeWriter, ImmutableArray attributeDatas) { var attributesToWrite = new List(); @@ -49,7 +51,7 @@ public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation com { var attributeData = attributesToWrite[index]; - WriteAttribute(sourceCodeWriter, compilation, attributeData); + WriteAttribute(sourceCodeWriter, attributeData); if (index != attributesToWrite.Count - 1) { @@ -58,8 +60,7 @@ public static void WriteAttributes(ICodeWriter sourceCodeWriter, Compilation com } } - public static void WriteAttribute(ICodeWriter sourceCodeWriter, Compilation compilation, - AttributeData attributeData) + public void WriteAttribute(ICodeWriter sourceCodeWriter, AttributeData attributeData) { if (attributeData.ApplicationSyntaxReference is null) { @@ -70,12 +71,23 @@ public static void WriteAttribute(ICodeWriter sourceCodeWriter, Compilation comp else { // For attributes from the current compilation, use the syntax-based approach - sourceCodeWriter.Append(GetAttributeObjectInitializer(compilation, attributeData)); + sourceCodeWriter.Append(GetAttributeObjectInitializer(attributeData)); + } + } + + public string GetAttributeObjectInitializer(AttributeData attributeData) + { + if (_attributeObjectInitializerCache.TryGetValue(attributeData, out var initializer)) + { + return initializer; } + + initializer = GetAttributeObjectInitializerInner(compilation, attributeData); + _attributeObjectInitializerCache.Add(attributeData, initializer); + return initializer; } - public static string GetAttributeObjectInitializer(Compilation compilation, - AttributeData attributeData) + private static string GetAttributeObjectInitializerInner(Compilation compilation, AttributeData attributeData) { var sourceCodeWriter = new CodeWriter("", includeHeader: false); @@ -123,7 +135,6 @@ public static string GetAttributeObjectInitializer(Compilation compilation, return sourceCodeWriter.ToString(); } - private static string FormatConstructorArgument(Compilation compilation, AttributeArgumentSyntax attributeArgumentSyntax) { if (attributeArgumentSyntax.NameColon is not null) diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index 41ac1a6098..3e99cce971 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -28,11 +28,21 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return !string.Equals(value, "false", StringComparison.OrdinalIgnoreCase); }); + var compilationContext = context + .CompilationProvider + .Select(static (c, _) => + new CompilationContext( + (CSharpCompilation)c, + new AttributeWriter(c) + )); + var testMethodsProvider = context.SyntaxProvider .ForAttributeWithMetadataName( "TUnit.Core.TestAttribute", predicate: static (node, _) => node is MethodDeclarationSyntax, - transform: static (ctx, _) => GetTestMethodMetadata(ctx)) + transform: static (ctx, _) => ctx) + .Combine(compilationContext) + .Select(static (ctx, _) => GetTestMethodMetadata(ctx.Left, ctx.Right)) .Where(static m => m is not null) .Combine(enabledProvider); @@ -40,7 +50,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .ForAttributeWithMetadataName( "TUnit.Core.InheritsTestsAttribute", predicate: static (node, _) => node is ClassDeclarationSyntax, - transform: static (ctx, _) => GetInheritsTestsClassMetadata(ctx)) + transform: static (ctx, _) => ctx) + .Combine(compilationContext) + .Select(static (ctx, _) => GetInheritsTestsClassMetadata(ctx.Left, ctx.Right)) .Where(static m => m is not null) .Combine(enabledProvider); @@ -67,7 +79,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); } - private static InheritsTestsClassMetadata? GetInheritsTestsClassMetadata(GeneratorAttributeSyntaxContext context) + private static InheritsTestsClassMetadata? GetInheritsTestsClassMetadata(GeneratorAttributeSyntaxContext context, CompilationContext compilationContext) { var classSyntax = (ClassDeclarationSyntax)context.TargetNode; @@ -85,11 +97,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { TypeSymbol = classSymbol, ClassSyntax = classSyntax, - Context = context + Context = context, + CompilationContext = compilationContext }; } - private static TestMethodMetadata? GetTestMethodMetadata(GeneratorAttributeSyntaxContext context) + private static TestMethodMetadata? GetTestMethodMetadata(GeneratorAttributeSyntaxContext context, CompilationContext compilationContext) { var methodSyntax = (MethodDeclarationSyntax)context.TargetNode; var methodSymbol = context.TargetSymbol as IMethodSymbol; @@ -121,6 +134,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) LineNumber = lineNumber, TestAttribute = context.Attributes.First(), Context = context, + CompilationContext = compilationContext, MethodSyntax = methodSyntax, IsGenericType = isGenericType, IsGenericMethod = isGenericMethod, @@ -186,6 +200,7 @@ private static void GenerateInheritedTestSources(SourceProductionContext context LineNumber = lineNumber, TestAttribute = testAttribute, Context = classInfo.Context, // Use class context to access Compilation + CompilationContext = classInfo.CompilationContext, MethodSyntax = null, // No syntax for inherited methods IsGenericType = typeForMetadata.IsGenericType, IsGenericMethod = (concreteMethod ?? method).IsGenericMethod, @@ -458,7 +473,7 @@ private static void GenerateMetadata(CodeWriter writer, TestMethodMetadata testM .Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes()) .ToImmutableArray(); - AttributeWriter.WriteAttributes(writer, compilation, attributes); + testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes); writer.Unindent(); writer.AppendLine("],"); @@ -504,7 +519,7 @@ private static void GenerateMetadataForConcreteInstantiation(CodeWriter writer, .Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes()) .ToImmutableArray(); - AttributeWriter.WriteAttributes(writer, compilation, attributes); + testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes); writer.Unindent(); writer.AppendLine("],"); @@ -564,7 +579,7 @@ private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata te foreach (var attr in methodDataSources) { - GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, typeSymbol); + GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, typeSymbol); } writer.Unindent(); @@ -584,7 +599,7 @@ private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata te foreach (var attr in classDataSources) { - GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, typeSymbol); + GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, typeSymbol); } writer.Unindent(); @@ -595,7 +610,7 @@ private static void GenerateDataSources(CodeWriter writer, TestMethodMetadata te GeneratePropertyDataSources(writer, testMethod); } - private static void GenerateDataSourceAttribute(CodeWriter writer, Compilation compilation, AttributeData attr, IMethodSymbol methodSymbol, INamedTypeSymbol typeSymbol) + private static void GenerateDataSourceAttribute(CodeWriter writer, CompilationContext compilationContext, AttributeData attr, IMethodSymbol methodSymbol, INamedTypeSymbol typeSymbol) { var attrClass = attr.AttributeClass; if (attrClass == null) @@ -613,18 +628,18 @@ private static void GenerateDataSourceAttribute(CodeWriter writer, Compilation c { try { - GenerateArgumentsAttributeWithParameterTypes(writer, compilation, attr, methodSymbol); + GenerateArgumentsAttributeWithParameterTypes(writer, compilationContext.Compilation, attr, methodSymbol); } catch { // Fall back to default behavior if parameter type matching fails - AttributeWriter.WriteAttribute(writer, compilation, attr); + compilationContext.AttributeWriter.WriteAttribute(writer, attr); writer.AppendLine(","); } } else { - AttributeWriter.WriteAttribute(writer, compilation, attr); + compilationContext.AttributeWriter.WriteAttribute(writer, attr); writer.AppendLine(","); } } @@ -1535,7 +1550,7 @@ private static void GeneratePropertyDataSources(CodeWriter writer, TestMethodMet writer.AppendLine($"PropertyName = \"{property.Name}\","); writer.AppendLine($"PropertyType = typeof({property.Type.GloballyQualified()}),"); writer.Append("DataSource = "); - GenerateDataSourceAttribute(writer, compilation, dataSourceAttr, testMethod.MethodSymbol, typeSymbol); + GenerateDataSourceAttribute(writer, testMethod.CompilationContext, dataSourceAttr, testMethod.MethodSymbol, typeSymbol); writer.Unindent(); writer.AppendLine("},"); } @@ -4805,7 +4820,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( writer.AppendLine("AttributeFactory = static () =>"); writer.AppendLine("["); writer.Indent(); - AttributeWriter.WriteAttributes(writer, compilation, filteredAttributes.ToImmutableArray()); + testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, filteredAttributes.ToImmutableArray()); writer.Unindent(); writer.AppendLine("],"); @@ -4866,7 +4881,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( foreach (var attr in methodDataSources) { - GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, concreteTypeSymbol); + GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, concreteTypeSymbol); } writer.Unindent(); @@ -4886,7 +4901,7 @@ private static void GenerateConcreteMetadataWithFilteredDataSources( foreach (var attr in classDataSources) { - GenerateDataSourceAttribute(writer, compilation, attr, methodSymbol, concreteTypeSymbol); + GenerateDataSourceAttribute(writer, testMethod.CompilationContext, attr, methodSymbol, concreteTypeSymbol); } writer.Unindent(); @@ -5133,7 +5148,7 @@ private static void GenerateConcreteTestMetadataForNonGeneric( .Concat(testMethod.TypeSymbol.ContainingAssembly.GetAttributes()) .ToImmutableArray(); - AttributeWriter.WriteAttributes(writer, compilation, attributes); + testMethod.CompilationContext.AttributeWriter.WriteAttributes(writer, attributes); writer.Unindent(); writer.AppendLine("],"); @@ -5154,7 +5169,7 @@ private static void GenerateConcreteTestMetadataForNonGeneric( writer.AppendLine("DataSources = new global::TUnit.Core.IDataSourceAttribute[]"); writer.AppendLine("{"); writer.Indent(); - GenerateDataSourceAttribute(writer, compilation, methodDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol); + GenerateDataSourceAttribute(writer, testMethod.CompilationContext, methodDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol); writer.Unindent(); writer.AppendLine("},"); } @@ -5169,7 +5184,7 @@ private static void GenerateConcreteTestMetadataForNonGeneric( writer.AppendLine("ClassDataSources = new global::TUnit.Core.IDataSourceAttribute[]"); writer.AppendLine("{"); writer.Indent(); - GenerateDataSourceAttribute(writer, compilation, classDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol); + GenerateDataSourceAttribute(writer, testMethod.CompilationContext, classDataSourceAttribute, testMethod.MethodSymbol, testMethod.TypeSymbol); writer.Unindent(); writer.AppendLine("},"); } @@ -5282,5 +5297,6 @@ public class InheritsTestsClassMetadata public required INamedTypeSymbol TypeSymbol { get; init; } public required ClassDeclarationSyntax ClassSyntax { get; init; } public GeneratorAttributeSyntaxContext Context { get; init; } + public required CompilationContext CompilationContext { get; init; } } diff --git a/TUnit.Core.SourceGenerator/Models/TestMethodMetadata.cs b/TUnit.Core.SourceGenerator/Models/TestMethodMetadata.cs index e97d4fb74c..0c0cc6bfbc 100644 --- a/TUnit.Core.SourceGenerator/Models/TestMethodMetadata.cs +++ b/TUnit.Core.SourceGenerator/Models/TestMethodMetadata.cs @@ -1,9 +1,13 @@ using System.Collections.Immutable; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using TUnit.Core.SourceGenerator.CodeGenerators.Writers; namespace TUnit.Core.SourceGenerator.Models; +public record CompilationContext(CSharpCompilation Compilation, AttributeWriter AttributeWriter); + /// /// Contains all the metadata about a test method discovered by the source generator. /// @@ -15,6 +19,7 @@ public class TestMethodMetadata : IEquatable public required int LineNumber { get; init; } public required AttributeData TestAttribute { get; init; } public GeneratorAttributeSyntaxContext? Context { get; init; } + public required CompilationContext CompilationContext { get; init; } public required MethodDeclarationSyntax? MethodSyntax { get; init; } public bool IsGenericType { get; init; } public bool IsGenericMethod { get; init; } @@ -23,7 +28,7 @@ public class TestMethodMetadata : IEquatable /// All attributes on the method, stored for later use during data combination generation /// public ImmutableArray MethodAttributes { get; init; } = ImmutableArray.Empty; - + /// /// The inheritance depth of this test method. /// 0 = method is declared directly in the test class