diff --git a/TUnit.Core.SourceGenerator/Generators/AotConverterGenerator.cs b/TUnit.Core.SourceGenerator/Generators/AotConverterGenerator.cs index 864ffd7f15..20971ca516 100644 --- a/TUnit.Core.SourceGenerator/Generators/AotConverterGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/AotConverterGenerator.cs @@ -1,15 +1,17 @@ -using System; using System.Collections.Immutable; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using TUnit.Core.SourceGenerator.Extensions; +using TUnit.Core.SourceGenerator.Models; namespace TUnit.Core.SourceGenerator.Generators; [Generator] public class AotConverterGenerator : IIncrementalGenerator { + public static string ParseAotConverter = "ParseCompilationMetadata"; + public void Initialize(IncrementalGeneratorInitializationContext context) { var enabledProvider = context.AnalyzerConfigOptionsProvider @@ -26,7 +28,39 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var conversionInfos = new List(); ScanTestParameters(compilation, conversionInfos, ct); - return conversionInfos.ToImmutableArray(); + + // Deduplicate conversions based on source and target types + var seenConversions = new HashSet<(ITypeSymbol Source, ITypeSymbol Target)>( + new TypePairEqualityComparer()); + var uniqueConversions = new List(); + + foreach (var conversion in conversionInfos) + { + if (conversion == null) + { + continue; + } + + var key = (conversion.SourceType, conversion.TargetType); + if (seenConversions.Add(key)) + { + uniqueConversions.Add(conversion); + } + } + + return uniqueConversions.Select(c => + { + var sourceType = ToTypeMetadata(c.SourceType); + var targetType = ToTypeMetadata(c.TargetType); + + return new ConversionMetadata() + { + SourceType = sourceType, + TargetType = targetType, + TypesAreDifferent = !SymbolEqualityComparer.Default.Equals(c.SourceType, c.TargetType), + IsImplicit = c.IsImplicit, + }; + }).ToEquatableArray(); } catch (NullReferenceException ex) { @@ -34,7 +68,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) throw new InvalidOperationException($"NullReferenceException in ScanTestParameters: {ex.Message}\nStack: {stackTrace}", ex); } }) - .Combine(enabledProvider); + .Combine(enabledProvider) + .WithTrackingName(ParseAotConverter); context.RegisterSourceOutput(allTypes, (spc, data) => { @@ -74,60 +109,56 @@ private void ScanTestParameters(Compilation compilation, List co var semanticModel = compilation.GetSemanticModel(tree); var root = tree.GetRoot(); - var methods = root.DescendantNodes() - .OfType(); - - foreach (var method in methods) + foreach(var nodes in root.DescendantNodes()) { - var methodSymbol = semanticModel.GetDeclaredSymbol(method); - if (methodSymbol == null) - { - continue; - } - - if (!IsTestMethod(methodSymbol)) + if(nodes is MethodDeclarationSyntax method) { - continue; - } - - foreach (var parameter in methodSymbol.Parameters) - { - typesToScan.Add(parameter.Type); - ScanAttributesForTypes(parameter.GetAttributes(), typesToScan); - } + var methodSymbol = semanticModel.GetDeclaredSymbol(method); + if (methodSymbol == null) + { + continue; + } - ScanAttributesForTypes(methodSymbol.GetAttributes(), typesToScan); - } + if (!IsTestMethod(methodSymbol)) + { + continue; + } - var classes = root.DescendantNodes() - .OfType(); + foreach (var parameter in methodSymbol.Parameters) + { + typesToScan.Add(parameter.Type); + ScanAttributesForTypes(parameter.GetAttributes(), typesToScan); + } - foreach (var classDecl in classes) - { - var classSymbol = semanticModel.GetDeclaredSymbol(classDecl); - if (classSymbol == null) - { - continue; + ScanAttributesForTypes(methodSymbol.GetAttributes(), typesToScan); } - - if (!IsTestClass(classSymbol)) + else if (nodes is ClassDeclarationSyntax classDecl) { - continue; - } - - ScanAttributesForTypes(classSymbol.GetAttributes(), typesToScan); + var classSymbol = semanticModel.GetDeclaredSymbol(classDecl); + if (classSymbol == null) + { + continue; + } - foreach (var constructor in classSymbol.Constructors) - { - if (constructor.IsImplicitlyDeclared) + if (!IsTestClass(classSymbol)) { continue; } - foreach (var parameter in constructor.Parameters) + ScanAttributesForTypes(classSymbol.GetAttributes(), typesToScan); + + foreach (var constructor in classSymbol.Constructors) { - typesToScan.Add(parameter.Type); - ScanAttributesForTypes(parameter.GetAttributes(), typesToScan); + if (constructor.IsImplicitlyDeclared) + { + continue; + } + + foreach (var parameter in constructor.Parameters) + { + typesToScan.Add(parameter.Type); + ScanAttributesForTypes(parameter.GetAttributes(), typesToScan); + } } } } @@ -140,7 +171,7 @@ private void ScanTestParameters(Compilation compilation, List co } } - private bool IsTestMethod(IMethodSymbol method) + private static bool IsTestMethod(IMethodSymbol method) { return method.GetAttributes().Any(attr => { @@ -445,11 +476,9 @@ private bool IsAccessibleType(ITypeSymbol type, Compilation compilation) return new ConversionInfo { - ContainingType = containingType, SourceType = sourceType, TargetType = targetType, IsImplicit = isImplicit, - MethodSymbol = methodSymbol }; } @@ -484,12 +513,12 @@ private bool TypeContainsGenericTypeParameters(ITypeSymbol type) return false; } - private void GenerateConverters(SourceProductionContext context, ImmutableArray conversions) + private void GenerateConverters(SourceProductionContext context, EquatableArray conversions) { var writer = new CodeWriter(); writer.AppendLine("#nullable enable"); - if (conversions.IsEmpty) + if (conversions.Length == 0) { writer.AppendLine(); writer.AppendLine("// No conversion operators found"); @@ -497,25 +526,6 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray< return; } - // Deduplicate conversions based on source and target types - var seenConversions = new HashSet<(ITypeSymbol Source, ITypeSymbol Target)>( - new TypePairEqualityComparer()); - var uniqueConversions = new List(); - - foreach (var conversion in conversions) - { - if (conversion == null) - { - continue; - } - - var key = (conversion.SourceType, conversion.TargetType); - if (seenConversions.Add(key)) - { - uniqueConversions.Add(conversion); - } - } - writer.AppendLine(); writer.AppendLine("using System;"); writer.AppendLine("using TUnit.Core.Converters;"); @@ -526,7 +536,7 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray< var converterIndex = 0; var registrations = new List(); - foreach (var conversion in uniqueConversions) + foreach (var conversion in conversions) { try { @@ -541,8 +551,8 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray< defaultSeverity: DiagnosticSeverity.Warning, isEnabledByDefault: true), Location.None, - conversion.SourceType?.ToDisplayString() ?? "null", - conversion.TargetType?.ToDisplayString() ?? "null")); + conversion.SourceType?.DisplayString ?? "null", + conversion.TargetType?.DisplayString ?? "null")); continue; } } @@ -562,8 +572,8 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray< } var converterClassName = $"AotConverter_{converterIndex++}"; - var sourceTypeName = conversion.SourceType.GloballyQualified(); - var targetTypeName = conversion.TargetType.GloballyQualified(); + var sourceTypeName = conversion.SourceType.GloballyQualified; + var targetTypeName = conversion.TargetType.GloballyQualified; writer.AppendLine($"internal sealed class {converterClassName} : IAotConverter"); writer.AppendLine("{"); @@ -583,14 +593,7 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray< var sourceType = conversion.SourceType; var targetType = conversion.TargetType; - ITypeSymbol typeForTargetPattern = targetType; - if (targetType is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T, TypeArguments.Length: > 0 } nullableTargetType) - { - typeForTargetPattern = nullableTargetType.TypeArguments[0]; - } - var targetPatternTypeName = typeForTargetPattern.GloballyQualified(); - - writer.AppendLine($"if (value is {targetPatternTypeName} targetTypedValue)"); + writer.AppendLine($"if (value is {targetType.PatternTypeName} targetTypedValue)"); writer.AppendLine("{"); writer.Indent(); writer.AppendLine("return targetTypedValue;"); @@ -599,19 +602,10 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray< // 2. If types are different, generate the fallback check for the source type. // This handles cases that require an implicit conversion. - if (!SymbolEqualityComparer.Default.Equals(sourceType, targetType)) + if (conversion.TypesAreDifferent) { - // For pattern matching, we must unwrap nullable types (C# language requirement - CS8116) - ITypeSymbol typeForSourcePattern = sourceType; - if (sourceType is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T, TypeArguments.Length: > 0 } nullableSourceType) - { - typeForSourcePattern = nullableSourceType.TypeArguments[0]; - } - - var sourcePatternTypeName = typeForSourcePattern.GloballyQualified(); - writer.AppendLine(); - writer.AppendLine($"if (value is {sourcePatternTypeName} sourceTypedValue)"); + writer.AppendLine($"if (value is {sourceType.PatternTypeName} sourceTypedValue)"); writer.AppendLine("{"); writer.Indent(); // For explicit conversions, we need to use an explicit cast @@ -666,13 +660,34 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray< context.AddSource("AotConverters.g.cs", writer.ToString()); } + private static TypeMetadata ToTypeMetadata(ITypeSymbol type) + { + var globallyQualified = type.GloballyQualified(); + + // For pattern matching, we must unwrap nullable types (C# language requirement - CS8116) + string patternTypeName = globallyQualified; + if (type is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T, TypeArguments.Length: > 0 } nullableSourceType) + { + patternTypeName = nullableSourceType.TypeArguments[0].GloballyQualified(); + } + return new TypeMetadata(globallyQualified, type.ToDisplayString(), patternTypeName); + } + + public record TypeMetadata(string GloballyQualified, string DisplayString, string PatternTypeName); + + public record ConversionMetadata + { + public required TypeMetadata SourceType { get; init; } + public required TypeMetadata TargetType { get; init; } + public required bool TypesAreDifferent { get; init; } + public required bool IsImplicit { get; init; } + } + private class ConversionInfo { - public required INamedTypeSymbol ContainingType { get; init; } public required ITypeSymbol SourceType { get; init; } public required ITypeSymbol TargetType { get; init; } public required bool IsImplicit { get; init; } - public required IMethodSymbol MethodSymbol { get; init; } } private class TypePairEqualityComparer : IEqualityComparer<(ITypeSymbol Source, ITypeSymbol Target)> diff --git a/TUnit.SourceGenerator.IncrementalTests/AotConverterGeneratorIncrementalTests.cs b/TUnit.SourceGenerator.IncrementalTests/AotConverterGeneratorIncrementalTests.cs new file mode 100644 index 0000000000..d70cf5c403 --- /dev/null +++ b/TUnit.SourceGenerator.IncrementalTests/AotConverterGeneratorIncrementalTests.cs @@ -0,0 +1,118 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using TUnit.Core.SourceGenerator.Generators; +using TUnit.Core.SourceGenerator.Models; + +namespace TUnit.Assertions.SourceGenerator.IncrementalTests; + +public class AotConverterGeneratorIncrementalTests +{ + private const string DefaultConverter = + """ + using global::TUnit.Core; + + #nullable enabled + public record Foo + { + public static implicit operator Foo((int Value1, int Value2) tuple) => new(); + } + + public class Tests + { + [Test] + [MethodDataSource(nameof(Data))] + public void Test1(Foo data) + { + } + + public static IEnumerable Data() => [new()]; + } + """; + + private const string SecondConverter = + """ + using global::TUnit.Core; + + #nullable enabled + public record FooBar + { + public static implicit operator FooBar((int Value1, int Value2) tuple) => new(); + } + + public class Tests1 + { + [Test] + [MethodDataSource(nameof(Data))] + public void Test1(FooBar data) + { + } + + public static IEnumerable Data() => [new()]; + } + """; + + [Fact] + public void AddUnrelatedType_MethodShouldNotRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultConverter, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New, 1); + + var compilation2 = compilation1.AddSyntaxTrees(CSharpSyntaxTree.ParseText("struct MyValue {}")); + var driver2 = driver1.RunGenerators(compilation2); + AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Cached, 1); + } + + [Fact] + public void AddNewConverterShouldRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultConverter, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New, 1); + + var compilation2 = compilation1.AddSyntaxTrees(CSharpSyntaxTree.ParseText(SecondConverter)); + var driver2 = driver1.RunGenerators(compilation2); + AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Modified, 2); + } + + [Fact] + public void ModifyOperatorShouldRegenerate() + { + var syntaxTree = CSharpSyntaxTree.ParseText(DefaultConverter, CSharpParseOptions.Default); + var compilation1 = Fixture.CreateLibrary(syntaxTree); + + var driver1 = TestHelper.GenerateTracked(compilation1); + AssertRunReasons(driver1, IncrementalGeneratorRunReasons.New, 1); + + var compilation2 = TestHelper.ReplaceTypeDeclaration(compilation1, "Foo", + """ + public record Foo + { + public static explicit operator Foo((int Value1, int Value2) tuple) => new(); + } + """ + ); + + var driver2 = driver1.RunGenerators(compilation2); + AssertRunReasons(driver2, IncrementalGeneratorRunReasons.Modified, 1); + } + + private static void AssertRunReasons( + GeneratorDriver driver, + IncrementalGeneratorRunReasons reasons, + int conversionMetadataLength, + int outputIndex = 0 + ) + { + var runResult = driver.GetRunResult().Results[0]; + var runValue = runResult.TrackedSteps[AotConverterGenerator.ParseAotConverter][0].Outputs[0].Value; + var runState = (ValueTuple, bool>)runValue; + Assert.That(runState.Item1.Length == conversionMetadataLength); + + TestHelper.AssertRunReason(runResult, AotConverterGenerator.ParseAotConverter, reasons.BuildStep, outputIndex); + } +} diff --git a/TUnit.SourceGenerator.IncrementalTests/Fixture.cs b/TUnit.SourceGenerator.IncrementalTests/Fixture.cs index 8dad9a415e..563750461f 100644 --- a/TUnit.SourceGenerator.IncrementalTests/Fixture.cs +++ b/TUnit.SourceGenerator.IncrementalTests/Fixture.cs @@ -3,6 +3,7 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using TUnit.Assertions.Attributes; +using TUnit.Core; namespace TUnit.Assertions.SourceGenerator.IncrementalTests; @@ -15,6 +16,7 @@ public static class Fixture typeof(GenerateAssertionAttribute).Assembly, typeof(MulticastDelegate).Assembly, typeof(IServiceProvider).Assembly, + typeof(TestAttribute).Assembly, }; public static Assembly[] AssemblyReferencesForCodegen => diff --git a/TUnit.SourceGenerator.IncrementalTests/TUnit.SourceGenerator.IncrementalTests.csproj b/TUnit.SourceGenerator.IncrementalTests/TUnit.SourceGenerator.IncrementalTests.csproj index 41b69e16eb..05e2791fbb 100644 --- a/TUnit.SourceGenerator.IncrementalTests/TUnit.SourceGenerator.IncrementalTests.csproj +++ b/TUnit.SourceGenerator.IncrementalTests/TUnit.SourceGenerator.IncrementalTests.csproj @@ -22,7 +22,9 @@ - - + + + + diff --git a/TUnit.SourceGenerator.IncrementalTests/TestHelper.cs b/TUnit.SourceGenerator.IncrementalTests/TestHelper.cs index 7ea01a1997..4490b79a1b 100644 --- a/TUnit.SourceGenerator.IncrementalTests/TestHelper.cs +++ b/TUnit.SourceGenerator.IncrementalTests/TestHelper.cs @@ -23,9 +23,9 @@ [ generator.AsSourceGenerator() ], return driver.RunGenerators(compilation); } - internal static CSharpCompilation ReplaceMemberDeclaration( + internal static CSharpCompilation ReplaceTypeDeclaration( CSharpCompilation compilation, - string memberName, + string typeName, string newMember ) { @@ -34,7 +34,7 @@ string newMember .GetCompilationUnitRoot() .DescendantNodes() .OfType() - .Single(x => x.Identifier.Text == memberName); + .Single(x => x.Identifier.Text == typeName); var updatedMemberDeclaration = SyntaxFactory.ParseMemberDeclaration(newMember)!; var newRoot = syntaxTree.GetCompilationUnitRoot().ReplaceNode(memberDeclaration, updatedMemberDeclaration);