Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 109 additions & 94 deletions TUnit.Core.SourceGenerator/Generators/AotConverterGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
using System;
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using TUnit.Core.SourceGenerator.Extensions;
using TUnit.Core.SourceGenerator.Models;

namespace TUnit.Core.SourceGenerator.Generators;

[Generator]
public class AotConverterGenerator : IIncrementalGenerator
{
public static string ParseAotConverter = "ParseCompilationMetadata";

public void Initialize(IncrementalGeneratorInitializationContext context)
{
var enabledProvider = context.AnalyzerConfigOptionsProvider
Expand All @@ -26,15 +28,48 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
var conversionInfos = new List<ConversionInfo>();
ScanTestParameters(compilation, conversionInfos, ct);
return conversionInfos.ToImmutableArray();

// Deduplicate conversions based on source and target types
var seenConversions = new HashSet<(ITypeSymbol Source, ITypeSymbol Target)>(
new TypePairEqualityComparer());
var uniqueConversions = new List<ConversionInfo>();

foreach (var conversion in conversionInfos)
{
if (conversion == null)
{
continue;
}

var key = (conversion.SourceType, conversion.TargetType);
if (seenConversions.Add(key))
{
uniqueConversions.Add(conversion);
}
}

return uniqueConversions.Select(c =>
{
var sourceType = ToTypeMetadata(c.SourceType);
var targetType = ToTypeMetadata(c.TargetType);

return new ConversionMetadata()
{
SourceType = sourceType,
TargetType = targetType,
TypesAreDifferent = !SymbolEqualityComparer.Default.Equals(c.SourceType, c.TargetType),
IsImplicit = c.IsImplicit,
};
}).ToEquatableArray();
}
catch (NullReferenceException ex)
{
var stackTrace = ex.StackTrace ?? "No stack trace";
throw new InvalidOperationException($"NullReferenceException in ScanTestParameters: {ex.Message}\nStack: {stackTrace}", ex);
}
})
.Combine(enabledProvider);
.Combine(enabledProvider)
.WithTrackingName(ParseAotConverter);

context.RegisterSourceOutput(allTypes, (spc, data) =>
{
Expand Down Expand Up @@ -74,60 +109,56 @@ private void ScanTestParameters(Compilation compilation, List<ConversionInfo> co
var semanticModel = compilation.GetSemanticModel(tree);
var root = tree.GetRoot();

var methods = root.DescendantNodes()
.OfType<MethodDeclarationSyntax>();

foreach (var method in methods)
foreach(var nodes in root.DescendantNodes())
{
var methodSymbol = semanticModel.GetDeclaredSymbol(method);
if (methodSymbol == null)
{
continue;
}

if (!IsTestMethod(methodSymbol))
if(nodes is MethodDeclarationSyntax method)
{
continue;
}

foreach (var parameter in methodSymbol.Parameters)
{
typesToScan.Add(parameter.Type);
ScanAttributesForTypes(parameter.GetAttributes(), typesToScan);
}
var methodSymbol = semanticModel.GetDeclaredSymbol(method);
if (methodSymbol == null)
{
continue;
}

ScanAttributesForTypes(methodSymbol.GetAttributes(), typesToScan);
}
if (!IsTestMethod(methodSymbol))
{
continue;
}

var classes = root.DescendantNodes()
.OfType<ClassDeclarationSyntax>();
foreach (var parameter in methodSymbol.Parameters)
{
typesToScan.Add(parameter.Type);
ScanAttributesForTypes(parameter.GetAttributes(), typesToScan);
}

foreach (var classDecl in classes)
{
var classSymbol = semanticModel.GetDeclaredSymbol(classDecl);
if (classSymbol == null)
{
continue;
ScanAttributesForTypes(methodSymbol.GetAttributes(), typesToScan);
}

if (!IsTestClass(classSymbol))
else if (nodes is ClassDeclarationSyntax classDecl)
{
continue;
}

ScanAttributesForTypes(classSymbol.GetAttributes(), typesToScan);
var classSymbol = semanticModel.GetDeclaredSymbol(classDecl);
if (classSymbol == null)
{
continue;
}

foreach (var constructor in classSymbol.Constructors)
{
if (constructor.IsImplicitlyDeclared)
if (!IsTestClass(classSymbol))
{
continue;
}

foreach (var parameter in constructor.Parameters)
ScanAttributesForTypes(classSymbol.GetAttributes(), typesToScan);

foreach (var constructor in classSymbol.Constructors)
{
typesToScan.Add(parameter.Type);
ScanAttributesForTypes(parameter.GetAttributes(), typesToScan);
if (constructor.IsImplicitlyDeclared)
{
continue;
}

foreach (var parameter in constructor.Parameters)
{
typesToScan.Add(parameter.Type);
ScanAttributesForTypes(parameter.GetAttributes(), typesToScan);
}
}
}
}
Expand All @@ -140,7 +171,7 @@ private void ScanTestParameters(Compilation compilation, List<ConversionInfo> co
}
}

private bool IsTestMethod(IMethodSymbol method)
private static bool IsTestMethod(IMethodSymbol method)
{
return method.GetAttributes().Any(attr =>
{
Expand Down Expand Up @@ -445,11 +476,9 @@ private bool IsAccessibleType(ITypeSymbol type, Compilation compilation)

return new ConversionInfo
{
ContainingType = containingType,
SourceType = sourceType,
TargetType = targetType,
IsImplicit = isImplicit,
MethodSymbol = methodSymbol
};
}

Expand Down Expand Up @@ -484,38 +513,19 @@ private bool TypeContainsGenericTypeParameters(ITypeSymbol type)
return false;
}

private void GenerateConverters(SourceProductionContext context, ImmutableArray<ConversionInfo> conversions)
private void GenerateConverters(SourceProductionContext context, EquatableArray<ConversionMetadata> conversions)
{
var writer = new CodeWriter();
writer.AppendLine("#nullable enable");

if (conversions.IsEmpty)
if (conversions.Length == 0)
{
writer.AppendLine();
writer.AppendLine("// No conversion operators found");
context.AddSource("AotConverters.g.cs", writer.ToString());
return;
}

// Deduplicate conversions based on source and target types
var seenConversions = new HashSet<(ITypeSymbol Source, ITypeSymbol Target)>(
new TypePairEqualityComparer());
var uniqueConversions = new List<ConversionInfo>();

foreach (var conversion in conversions)
{
if (conversion == null)
{
continue;
}

var key = (conversion.SourceType, conversion.TargetType);
if (seenConversions.Add(key))
{
uniqueConversions.Add(conversion);
}
}

writer.AppendLine();
writer.AppendLine("using System;");
writer.AppendLine("using TUnit.Core.Converters;");
Expand All @@ -526,7 +536,7 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray<
var converterIndex = 0;
var registrations = new List<string>();

foreach (var conversion in uniqueConversions)
foreach (var conversion in conversions)
{
try
{
Expand All @@ -541,8 +551,8 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray<
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true),
Location.None,
conversion.SourceType?.ToDisplayString() ?? "null",
conversion.TargetType?.ToDisplayString() ?? "null"));
conversion.SourceType?.DisplayString ?? "null",
conversion.TargetType?.DisplayString ?? "null"));
continue;
}
}
Expand All @@ -562,8 +572,8 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray<
}

var converterClassName = $"AotConverter_{converterIndex++}";
var sourceTypeName = conversion.SourceType.GloballyQualified();
var targetTypeName = conversion.TargetType.GloballyQualified();
var sourceTypeName = conversion.SourceType.GloballyQualified;
var targetTypeName = conversion.TargetType.GloballyQualified;

writer.AppendLine($"internal sealed class {converterClassName} : IAotConverter");
writer.AppendLine("{");
Expand All @@ -583,14 +593,7 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray<
var sourceType = conversion.SourceType;
var targetType = conversion.TargetType;

ITypeSymbol typeForTargetPattern = targetType;
if (targetType is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T, TypeArguments.Length: > 0 } nullableTargetType)
{
typeForTargetPattern = nullableTargetType.TypeArguments[0];
}
var targetPatternTypeName = typeForTargetPattern.GloballyQualified();

writer.AppendLine($"if (value is {targetPatternTypeName} targetTypedValue)");
writer.AppendLine($"if (value is {targetType.PatternTypeName} targetTypedValue)");
writer.AppendLine("{");
writer.Indent();
writer.AppendLine("return targetTypedValue;");
Expand All @@ -599,19 +602,10 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray<

// 2. If types are different, generate the fallback check for the source type.
// This handles cases that require an implicit conversion.
if (!SymbolEqualityComparer.Default.Equals(sourceType, targetType))
if (conversion.TypesAreDifferent)
{
// For pattern matching, we must unwrap nullable types (C# language requirement - CS8116)
ITypeSymbol typeForSourcePattern = sourceType;
if (sourceType is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T, TypeArguments.Length: > 0 } nullableSourceType)
{
typeForSourcePattern = nullableSourceType.TypeArguments[0];
}

var sourcePatternTypeName = typeForSourcePattern.GloballyQualified();

writer.AppendLine();
writer.AppendLine($"if (value is {sourcePatternTypeName} sourceTypedValue)");
writer.AppendLine($"if (value is {sourceType.PatternTypeName} sourceTypedValue)");
writer.AppendLine("{");
writer.Indent();
// For explicit conversions, we need to use an explicit cast
Expand Down Expand Up @@ -666,13 +660,34 @@ private void GenerateConverters(SourceProductionContext context, ImmutableArray<
context.AddSource("AotConverters.g.cs", writer.ToString());
}

private static TypeMetadata ToTypeMetadata(ITypeSymbol type)
{
var globallyQualified = type.GloballyQualified();

// For pattern matching, we must unwrap nullable types (C# language requirement - CS8116)
string patternTypeName = globallyQualified;
if (type is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T, TypeArguments.Length: > 0 } nullableSourceType)
{
patternTypeName = nullableSourceType.TypeArguments[0].GloballyQualified();
}
return new TypeMetadata(globallyQualified, type.ToDisplayString(), patternTypeName);
}

public record TypeMetadata(string GloballyQualified, string DisplayString, string PatternTypeName);

public record ConversionMetadata
{
public required TypeMetadata SourceType { get; init; }
public required TypeMetadata TargetType { get; init; }
public required bool TypesAreDifferent { get; init; }
public required bool IsImplicit { get; init; }
}

private class ConversionInfo
{
public required INamedTypeSymbol ContainingType { get; init; }
public required ITypeSymbol SourceType { get; init; }
public required ITypeSymbol TargetType { get; init; }
public required bool IsImplicit { get; init; }
public required IMethodSymbol MethodSymbol { get; init; }
}

private class TypePairEqualityComparer : IEqualityComparer<(ITypeSymbol Source, ITypeSymbol Target)>
Expand Down
Loading
Loading