Skip to content

Commit

Permalink
Check for HttpMethodAttribute as a base type so custom attributes can…
Browse files Browse the repository at this point in the history
… work
  • Loading branch information
clairernovotny committed Jan 23, 2021
1 parent 47a28be commit 88a8573
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 22 deletions.
45 changes: 45 additions & 0 deletions InterfaceStubGenerator.Core/ITypeSymbolExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

using Microsoft.CodeAnalysis;

namespace Refit.Generator
{
static class ITypeSymbolExtensions
{
public static IEnumerable<ITypeSymbol> GetBaseTypesAndThis(this ITypeSymbol? type)
{
var current = type;
while (current != null)
{
yield return current;
current = current.BaseType;
}
}

// Determine if "type" inherits from "baseType", ignoring constructed types, optionally including interfaces,
// dealing only with original types.
public static bool InheritsFromOrEquals(
this ITypeSymbol type, ITypeSymbol baseType, bool includeInterfaces)
{
if (!includeInterfaces)
{
return InheritsFromOrEquals(type, baseType);
}

return type.GetBaseTypesAndThis().Concat(type.AllInterfaces).Any(t => t.Equals(baseType, SymbolEqualityComparer.Default));
}


// Determine if "type" inherits from "baseType", ignoring constructed types and interfaces, dealing
// only with original types.
public static bool InheritsFromOrEquals(
this ITypeSymbol type, ITypeSymbol baseType)
{
return type.GetBaseTypesAndThis().Any(t => t.Equals(baseType, SymbolEqualityComparer.Default));
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<AssemblyOriginatorKeyFile>..\buildtask.snk</AssemblyOriginatorKeyFile>
<SignAssembly>true</SignAssembly>
<IsRoslynComponent>true</IsRoslynComponent>
<CopyLocalLockFileAssemblies>true</CopyLocalLockFileAssemblies>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
Expand Down
45 changes: 24 additions & 21 deletions InterfaceStubGenerator.Core/InterfaceStubGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ public class InterfaceStubGenerator : ISourceGenerator
"Refit",
DiagnosticSeverity.Warning,
true);

static readonly DiagnosticDescriptor RefitNotReferenced = new(
"RF002",
"Refit must be referenced",
"Refit is not referenced. Add a reference to Refit.",
"Refit",
DiagnosticSeverity.Error,
true);
#pragma warning restore RS2008 // Enable analyzer release tracking

public void Execute(GeneratorExecutionContext context)
Expand Down Expand Up @@ -67,36 +75,31 @@ sealed class PreserveAttribute : Attribute

// we're going to create a new compilation that contains the attribute.
// TODO: we should allow source generators to provide source during initialize, so that this step isn't required.
var options = (context.Compilation as CSharpCompilation).SyntaxTrees[0].Options as CSharpParseOptions;
var options = (context.Compilation as CSharpCompilation)!.SyntaxTrees[0].Options as CSharpParseOptions;
var compilation = context.Compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(SourceText.From(attributeText, Encoding.UTF8), options));

// get the newly bound attribute
var preserveAttributeSymbol = compilation.GetTypeByMetadataName($"{refitInternalNamespace}.PreserveAttribute");
var disposableInterfaceSymbol = compilation.GetTypeByMetadataName("System.IDisposable");
var preserveAttributeSymbol = compilation.GetTypeByMetadataName($"{refitInternalNamespace}.PreserveAttribute")!;
var disposableInterfaceSymbol = compilation.GetTypeByMetadataName("System.IDisposable")!;
var httpMethodBaseAttributeSymbol = compilation.GetTypeByMetadataName("Refit.HttpMethodAttribute");

// Get the type names of the attributes we're looking for
var httpMethodAttibutes = new HashSet<ISymbol>(SymbolEqualityComparer.Default)
if(httpMethodBaseAttributeSymbol == null)
{
compilation.GetTypeByMetadataName("Refit.GetAttribute"),
compilation.GetTypeByMetadataName("Refit.HeadAttribute"),
compilation.GetTypeByMetadataName("Refit.PostAttribute"),
compilation.GetTypeByMetadataName("Refit.PutAttribute"),
compilation.GetTypeByMetadataName("Refit.DeleteAttribute"),
compilation.GetTypeByMetadataName("Refit.PatchAttribute"),
compilation.GetTypeByMetadataName("Refit.OptionsAttribute")
};
context.ReportDiagnostic(Diagnostic.Create(RefitNotReferenced, null));
return;
}

// Check the candidates and keep the ones we're actually interested in
var methodSymbols = new List<IMethodSymbol>();
foreach (var method in receiver.CandidateMethods)
{
var model = compilation.GetSemanticModel(method.SyntaxTree);

// Get the symbol being declared by the method
var methodSymbol = model.GetDeclaredSymbol(method);
if (IsRefitMethod(methodSymbol, httpMethodAttibutes))
if (IsRefitMethod(methodSymbol, httpMethodBaseAttributeSymbol))
{
methodSymbols.Add(methodSymbol);
methodSymbols.Add(methodSymbol!);
}
}

Expand All @@ -109,7 +112,7 @@ sealed class PreserveAttribute : Attribute
// with a refit attribute on them. Types may contain other members, without the attribute, which we'll
// need to check for and error out on

var classSource = ProcessInterface(group.Key, group.ToList(), preserveAttributeSymbol, disposableInterfaceSymbol, httpMethodAttibutes, context);
var classSource = ProcessInterface(group.Key, group.ToList(), preserveAttributeSymbol, disposableInterfaceSymbol, httpMethodBaseAttributeSymbol, context);

var keyName = group.Key.Name;
if(keyCount.TryGetValue(keyName, out var value))
Expand All @@ -127,7 +130,7 @@ string ProcessInterface(INamedTypeSymbol interfaceSymbol,
List<IMethodSymbol> refitMethods,
ISymbol preserveAttributeSymbol,
ISymbol disposableInterfaceSymbol,
HashSet<ISymbol> httpMethodAttributeSymbols,
INamedTypeSymbol httpMethodBaseAttributeSymbol,
GeneratorExecutionContext context)
{

Expand Down Expand Up @@ -193,7 +196,7 @@ partial class AutoGenerated{classDeclaration}
}

// Pull out the refit methods from the derived types
var derivedRefitMethods = derivedMethods.Where(m => IsRefitMethod(m, httpMethodAttributeSymbols)).ToList();
var derivedRefitMethods = derivedMethods.Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)).ToList();
var derivedNonRefitMethods = derivedMethods.Except(derivedMethods, SymbolEqualityComparer.Default).Cast<IMethodSymbol>().ToList();

// Handle Refit Methods
Expand Down Expand Up @@ -367,9 +370,9 @@ void WriteMethodOpening(StringBuilder source, IMethodSymbol methodSymbol)
void WriteMethodClosing(StringBuilder source) => source.Append(@" }");


bool IsRefitMethod(IMethodSymbol methodSymbol, HashSet<ISymbol> httpMethodAttibutes)
bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttibute)
{
return methodSymbol.GetAttributes().Any(ad => httpMethodAttibutes.Contains(ad.AttributeClass));
return methodSymbol?.GetAttributes().Any(ad => ad.AttributeClass?.InheritsFromOrEquals(httpMethodAttibute) == true) == true;
}

public void Initialize(GeneratorInitializationContext context)
Expand Down

0 comments on commit 88a8573

Please sign in to comment.