Skip to content

Commit

Permalink
feat: use private static fields to store constant typeParameters
Browse files Browse the repository at this point in the history
…where possible (#1606)
  • Loading branch information
TimothyMakkison authored Dec 2, 2023
1 parent 52151a2 commit 4055e7a
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 66 deletions.
109 changes: 91 additions & 18 deletions InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public class InterfaceStubGeneratorV2 : IIncrementalGenerator
public class InterfaceStubGenerator : ISourceGenerator
#endif
{
private const string TypeParameterVariableName = "_typeParameters";

#pragma warning disable RS2008 // Enable analyzer release tracking
static readonly DiagnosticDescriptor InvalidRefitMember =
new(
Expand Down Expand Up @@ -396,15 +398,17 @@ partial class {ns}{classDeclaration}
.Cast<IMethodSymbol>()
.ToList();

var memberNames = new HashSet<string>(interfaceSymbol.GetMembers().Select(x => x.Name));

// Handle Refit Methods
foreach (var method in refitMethods)
{
ProcessRefitMethod(source, method, true);
ProcessRefitMethod(source, method, true, memberNames);
}

foreach (var method in refitMethods.Concat(derivedRefitMethods))
{
ProcessRefitMethod(source, method, false);
ProcessRefitMethod(source, method, false, memberNames);
}

// Handle non-refit Methods that aren't static or properties or have a method body
Expand Down Expand Up @@ -445,8 +449,20 @@ partial class {ns}{classDeclaration}
/// <param name="source"></param>
/// <param name="methodSymbol"></param>
/// <param name="isTopLevel">True if directly from the type we're generating for, false for methods found on base interfaces</param>
void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool isTopLevel)
/// <param name="memberNames">Contains the unique member names in the interface scope.</param>
void ProcessRefitMethod(
StringBuilder source,
IMethodSymbol methodSymbol,
bool isTopLevel,
HashSet<string> memberNames
)
{
var parameterTypesExpression = GenerateTypeParameterExpression(
source,
methodSymbol,
memberNames
);

var returnType = methodSymbol.ReturnType.ToDisplayString(
SymbolDisplayFormat.FullyQualifiedFormat
);
Expand All @@ -466,15 +482,6 @@ void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool i
argList.Add($"@{param.MetadataName}");
}

// List of types.
var typeList = new List<string>();
foreach (var param in methodSymbol.Parameters)
{
typeList.Add(
$"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"
);
}

// List of generic arguments
var genericList = new List<string>();
foreach (var typeParam in methodSymbol.TypeParameters)
Expand All @@ -489,11 +496,6 @@ void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool i
? "global::System.Array.Empty<object>()"
: $"new object[] {{ {string.Join(", ", argList)} }}";

var parameterTypesArrayString =
typeList.Count == 0
? "global::System.Array.Empty<global::System.Type>()"
: $"new global::System.Type[] {{ {string.Join(", ", typeList)} }}";

var genericString =
genericList.Count > 0
? $", new global::System.Type[] {{ {string.Join(", ", genericList)} }}"
Expand All @@ -502,7 +504,7 @@ void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool i
source.Append(
@$"
var ______arguments = {argumentsArrayString};
var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesArrayString}{genericString} );
var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesExpression}{genericString} );
try
{{
{@return}({returnType})______func(this.Client, ______arguments){configureAwait};
Expand Down Expand Up @@ -628,6 +630,63 @@ IMethodSymbol methodSymbol
}
}

static string GenerateTypeParameterExpression(
StringBuilder source,
IMethodSymbol methodSymbol,
HashSet<string> memberNames
)
{
// use Array.Empty if method has no parameters.
if (methodSymbol.Parameters.Length == 0)
return "global::System.Array.Empty<global::System.Type>()";

// if one of the parameters is/contains a type parameter then it cannot be cached as it will change type between calls.
if (methodSymbol.Parameters.Any(x => ContainsTypeParameter(x.Type)))
{
var typeEnumerable = methodSymbol.Parameters.Select(
param =>
$"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"
);
return $"new global::System.Type[] {{ {string.Join(", ", typeEnumerable)} }}";
}

// find a name and generate field declaration.
var typeParameterFieldName = UniqueName(TypeParameterVariableName, memberNames);
var types = string.Join(
", ",
methodSymbol.Parameters.Select(
x =>
$"typeof({x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"
)
);
source.Append(
$$"""
private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} };
"""
);

return typeParameterFieldName;

static bool ContainsTypeParameter(ITypeSymbol symbol)
{
if (symbol is ITypeParameterSymbol)
return true;

if (symbol is not INamedTypeSymbol { TypeParameters.Length: > 0 } namedType)
return false;

foreach (var typeArg in namedType.TypeArguments)
{
if (ContainsTypeParameter(typeArg))
return true;
}

return false;
}
}

void WriteMethodOpening(
StringBuilder source,
IMethodSymbol methodSymbol,
Expand Down Expand Up @@ -680,6 +739,20 @@ void WriteMethodOpening(

void WriteMethodClosing(StringBuilder source) => source.Append(@" }");

static string UniqueName(string name, HashSet<string> methodNames)
{
var candidateName = name;
var counter = 0;
while (methodNames.Contains(candidateName))
{
candidateName = $"{name}{counter}";
counter++;
}

methodNames.Add(candidateName);
return candidateName;
}

bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttibute)
{
return methodSymbol
Expand Down
Loading

0 comments on commit 4055e7a

Please sign in to comment.