diff --git a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs index ddb4bf6af..535a1d6a9 100644 --- a/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs +++ b/InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs @@ -24,7 +24,6 @@ public class InterfaceStubGenerator : ISourceGenerator #endif { private const string TypeParameterVariableName = "_typeParameters"; - private const string GenericTypeVariableName = "_genericTypes"; #pragma warning disable RS2008 // Enable analyzer release tracking static readonly DiagnosticDescriptor InvalidRefitMember = new( @@ -388,8 +387,7 @@ partial class {ns}{classDeclaration} /// Contains the unique member names in the interface scope. void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool isTopLevel, HashSet memberNames) { - var (typeParameterName, nullInit, genericTypeName) = - WriteMethodStaticFields(source, methodSymbol, memberNames); + var parameterTypesExpression = GenerateTypeParameterExpression(source, methodSymbol, memberNames); var returnType = methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var (isAsync, @return, configureAwait) = methodSymbol.ReturnType.MetadataName switch @@ -408,35 +406,22 @@ void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool i argList.Add($"@{param.MetadataName}"); } + // List of generic arguments + var genericList = new List(); + foreach (var typeParam in methodSymbol.TypeParameters) + { + genericList.Add($"typeof({typeParam.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"); + } + var argumentsArrayString = argList.Count == 0 ? "global::System.Array.Empty()" : $"new object[] {{ {string.Join(", ", argList)} }}"; - var parameterTypesVariable = typeParameterName ?? "global::System.Array.Empty()"; - var genericTypeVariable = genericTypeName != null ? $", {genericTypeName}" : string.Empty; - - // use noll coalescing assignment to initialize the type parameter array - if (nullInit) - { - // List of types. - var typeList = methodSymbol.Parameters.Select(param => $"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})").ToList(); - - source.Append(@$" - {typeParameterName} ??= new global::System.Type[] {{ {string.Join(", ", typeList)} }};"); - } - // use noll coalescing assignment to initialize the generic argument type array - if (genericTypeName != null) - { - // List of generic arguments - var genericList = methodSymbol.TypeParameters.Select(typeParam => $"typeof({typeParam.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})").ToList(); - - source.Append(@$" - {genericTypeName} ??= new global::System.Type[] {{ {string.Join(", ", genericList)} }};"); - } + var genericString = genericList.Count > 0 ? $", new global::System.Type[] {{ {string.Join(", ", genericList)} }}" : string.Empty; source.Append(@$" var ______arguments = {argumentsArrayString}; - var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesVariable}{genericTypeVariable} ); + var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesExpression}{genericString} ); try {{ {@return}({returnType})______func(this.Client, ______arguments){configureAwait}; @@ -533,55 +518,47 @@ void ProcessNonRefitMethod(TContext context, Action memberNames) + static string GenerateTypeParameterExpression(StringBuilder source, IMethodSymbol methodSymbol, HashSet memberNames) { - string? typeParamsName; - var nullInit = false; - string? genericTypeName = null; + // use Array.Empty if method has no parameters. + if (methodSymbol.Parameters.Length == 0) + return "global::System.Array.Empty()"; - if (symbol.Parameters.Length == 0) + // if one of the parameters is/contains a type parameter then it cannot be cached as it will change type between calls. + if (methodSymbol.Parameters.Any(x => ContainsTypeParameter(x.Type))) { - typeParamsName = null; + var typeEnumerable = methodSymbol.Parameters.Select(param => $"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"); + return $"new global::System.Type[] {{ {string.Join(", ", typeEnumerable)} }}"; } - else - { - typeParamsName = UniqueName(TypeParameterVariableName, memberNames); - nullInit = symbol.Parameters.Any(x => x.Type.Kind == SymbolKind.TypeParameter || x.Type is INamedTypeSymbol { IsGenericType: true }); - if (nullInit) - { - source.Append( - $""" + // find a name and generate field declaration. + var typeParameterFieldName = UniqueName(TypeParameterVariableName, memberNames); + var types = string.Join(", ", methodSymbol.Parameters.Select(x => $"typeof({x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})")); + source.Append( + $$""" - private static global::System.Type[]? {typeParamsName}; - """); - } - else - { - var types = string.Join(", ", symbol.Parameters.Select(x => $"typeof({x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})")); - source.Append( - $$""" + private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} }; + """); + return typeParameterFieldName; - private static readonly global::System.Type[] {{typeParamsName}} = new global::System.Type[] {{{types}} }; - """); - } - } - - if (symbol.TypeParameters.Length > 0) + static bool ContainsTypeParameter(ITypeSymbol symbol) { - genericTypeName = UniqueName(GenericTypeVariableName, memberNames); + if (symbol is ITypeParameterSymbol) + return true; - source.Append( - $""" + if (symbol is not INamedTypeSymbol { TypeParameters.Length: > 0 } namedType) + return false; + foreach (var typeArg in namedType.TypeArguments) + { + if (ContainsTypeParameter(typeArg)) + return true; + } - private static global::System.Type[]? {genericTypeName}; - """); + return false; } - - return (typeParamsName, nullInit, genericTypeName); } void WriteMethodOpening(StringBuilder source, IMethodSymbol methodSymbol, bool isExplicitInterface, bool isAsync = false)