diff --git a/InterfaceStubGenerator.Core/InterfaceStubGenerator.cs b/InterfaceStubGenerator.Core/InterfaceStubGenerator.cs index 328021c6d..0aa906039 100644 --- a/InterfaceStubGenerator.Core/InterfaceStubGenerator.cs +++ b/InterfaceStubGenerator.Core/InterfaceStubGenerator.cs @@ -106,14 +106,46 @@ sealed class PreserveAttribute : Attribute var keyCount = new Dictionary(); + var interfaces = methodSymbols.GroupBy(m => m.ContainingType).ToDictionary(g => g.Key, v => v.ToList()); + + + + + // Look through the candidate interfaces + var interfaceSymbols = new List(); + foreach(var iface in receiver.CandidateInterfaces) + { + var model = compilation.GetSemanticModel(iface.SyntaxTree); + + // get the symbol belonging to the interface + var ifaceSymbol = model.GetDeclaredSymbol(iface); + + // See if we already know about it, might be a dup + if (ifaceSymbol is null || interfaces.ContainsKey(ifaceSymbol)) + continue; + + // The interface has no refit methods, but its base interfaces might + var hasDerivedRefit = ifaceSymbol.AllInterfaces + .SelectMany(i => i.GetMembers().OfType()) + .Where(m => IsRefitMethod(m, httpMethodBaseAttributeSymbol)) + .Any(); + + if(hasDerivedRefit) + { + // Add the interface to the generation list with an empty set of methods + // The logic already looks for base refit methods + interfaces.Add(ifaceSymbol, new List()); + } + } + // group the fields by interface and generate the source - foreach (var group in methodSymbols.GroupBy(m => m.ContainingType)) + foreach (var group in interfaces) { // each group is keyed by the Interface INamedTypeSymbol and contains the members // with a refit attribute on them. Types may contain other members, without the attribute, which we'll // need to check for and error out on - var classSource = ProcessInterface(group.Key, group.ToList(), preserveAttributeSymbol, disposableInterfaceSymbol, httpMethodBaseAttributeSymbol, context); + var classSource = ProcessInterface(group.Key, group.Value, preserveAttributeSymbol, disposableInterfaceSymbol, httpMethodBaseAttributeSymbol, context); var keyName = group.Key.Name; if(keyCount.TryGetValue(keyName, out var value)) @@ -123,7 +155,7 @@ sealed class PreserveAttribute : Attribute keyCount[keyName] = value; context.AddSource($"{keyName}_refit.cs", SourceText.From(classSource, Encoding.UTF8)); - } + } } @@ -390,6 +422,8 @@ class SyntaxReceiver : ISyntaxReceiver { public List CandidateMethods { get; } = new(); + public List CandidateInterfaces { get; } = new(); + public void OnVisitSyntaxNode(SyntaxNode syntaxNode) { // We're looking for methods with an attribute that are in an interfaces @@ -399,6 +433,13 @@ methodDeclarationSyntax.Parent is InterfaceDeclarationSyntax && { CandidateMethods.Add(methodDeclarationSyntax); } + + // We also look for interfaces that derive from others, so we can see if any base methods contain + // Refit methods + if(syntaxNode is InterfaceDeclarationSyntax iface && iface.BaseList is not null) + { + CandidateInterfaces.Add(iface); + } } } } diff --git a/Refit.Tests/InheritedInterfacesApi.cs b/Refit.Tests/InheritedInterfacesApi.cs index 233d55999..a629979b4 100644 --- a/Refit.Tests/InheritedInterfacesApi.cs +++ b/Refit.Tests/InheritedInterfacesApi.cs @@ -47,6 +47,11 @@ public interface IAmInterfaceF_RequireUsing [Get("/get-requiring-using")] Task Get(List guids); } + + public interface IContainAandB : IAmInterfaceB, IAmInterfaceA + { + + } } namespace Refit.Tests.SeparateNamespaceWithModel diff --git a/Refit.Tests/RestService.cs b/Refit.Tests/RestService.cs index 07e539831..c6b3fbcb9 100644 --- a/Refit.Tests/RestService.cs +++ b/Refit.Tests/RestService.cs @@ -1685,6 +1685,31 @@ public async Task InheritedMethodTest() mockHttp.VerifyNoOutstandingExpectation(); } + + [Fact] + public async Task InheritedInterfaceWithOnlyBaseMethodsTest() + { + var mockHttp = new MockHttpMessageHandler(); + + var settings = new RefitSettings + { + HttpMessageHandlerFactory = () => mockHttp + }; + + var fixture = RestService.For("https://httpbin.org", settings); + + mockHttp.Expect(HttpMethod.Get, "https://httpbin.org/get").Respond("application/json", nameof(IAmInterfaceA.Ping)); + var resp = await fixture.Ping(); + Assert.Equal(nameof(IAmInterfaceA.Ping), resp); + mockHttp.VerifyNoOutstandingExpectation(); + + mockHttp.Expect(HttpMethod.Get, "https://httpbin.org/get") + .Respond("application/json", nameof(IAmInterfaceB.Pong)); + resp = await fixture.Pong(); + Assert.Equal(nameof(IAmInterfaceB.Pong), resp); + mockHttp.VerifyNoOutstandingExpectation(); + } + [Fact] public async Task DictionaryDynamicQueryparametersTest() {