Skip to content

Commit 670d8d5

Browse files
authored
feat: support inherited interfaces (#59)
This PR adds support for inherited interfaces in the Mockerade source generator, allowing it to generate mocks for types that implement multiple interfaces or inherit from other interfaces. ### Key changes: - Enhanced type traversal to include all base types and interfaces when collecting mockable members - Added explicit interface implementation support for handling method name conflicts - Extended indexer property support with parameter handling
1 parent c4b045b commit 670d8d5

File tree

8 files changed

+143
-29
lines changed

8 files changed

+143
-29
lines changed

Source/Mockerade.SourceGenerators/Entities/Class.cs

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,54 @@ namespace Mockerade.SourceGenerators.Entities;
55

66
internal record Class
77
{
8+
private string GetTypeName(ITypeSymbol type, List<string> additionalNamespaces)
9+
{
10+
if (type is INamedTypeSymbol namedType)
11+
{
12+
if (namedType.IsGenericType)
13+
{
14+
additionalNamespaces.AddRange(namedType.TypeArguments
15+
.Select(t => t.ContainingNamespace.ToString()));
16+
return namedType.Name + "<" + string.Join(",", namedType.TypeArguments.Select(t => GetTypeName(t, additionalNamespaces))) + ">";
17+
}
18+
return namedType.SpecialType switch
19+
{
20+
SpecialType.System_Int32 => "int",
21+
SpecialType.System_Int64 => "long",
22+
SpecialType.System_Int16 => "short",
23+
SpecialType.System_UInt32 => "uint",
24+
SpecialType.System_UInt64 => "ulong",
25+
SpecialType.System_UInt16 => "ushort",
26+
_ => type.Name
27+
};
28+
}
29+
else
30+
{
31+
return type.Name;
32+
}
33+
}
34+
public static IEnumerable<ITypeSymbol> GetBaseTypesAndThis(ITypeSymbol type)
35+
{
36+
var current = type;
37+
while (current != null)
38+
{
39+
yield return current;
40+
if (current.TypeKind == TypeKind.Interface)
41+
{
42+
foreach (var @interface in current.Interfaces)
43+
{
44+
yield return @interface;
45+
}
46+
}
47+
current = current.BaseType;
48+
}
49+
}
50+
851
public Class(ITypeSymbol type)
952
{
53+
List<string> additionalNamespaces = [];
1054
Namespace = type.ContainingNamespace.ToString();
11-
ClassName = type.Name;
55+
ClassName = GetTypeName(type, additionalNamespaces);
1256

1357
if (type.ContainingType is not null)
1458
{
@@ -17,30 +61,53 @@ public Class(ITypeSymbol type)
1761
}
1862

1963
IsInterface = type.TypeKind == TypeKind.Interface;
20-
Methods = new EquatableArray<Method>(
21-
type.GetMembers().OfType<IMethodSymbol>()
64+
var y = GetBaseTypesAndThis(type).SelectMany(t => t.GetMembers().OfType<IMethodSymbol>())
65+
// Exclude getter/setter methods
66+
.Where(x => x.AssociatedSymbol is null && !x.IsSealed)
67+
.Where(x => IsInterface || x.IsVirtual || x.IsAbstract).ToList();
68+
var methods = GetBaseTypesAndThis(type).SelectMany(t => t.GetMembers().OfType<IMethodSymbol>())
2269
// Exclude getter/setter methods
23-
.Where(x => x.AssociatedSymbol is null)
70+
.Where(x => x.AssociatedSymbol is null && !x.IsSealed)
2471
.Where(x => IsInterface || x.IsVirtual || x.IsAbstract)
2572
.Select(x => new Method(x))
26-
.ToArray());
73+
.Distinct()
74+
.ToList();
75+
for (int i = 0; i < methods.Count; i++)
76+
{
77+
var method = methods[i];
78+
if (methods.Take(i)
79+
.Any(m =>
80+
m.Name == method.Name &&
81+
m.Parameters.Count == method.Parameters.Count &&
82+
m.Parameters.SequenceEqual(method.Parameters)))
83+
{
84+
methods[i] = method with { ExplicitImplementation = method.ContainingType };
85+
}
86+
}
87+
Methods = new EquatableArray<Method>(methods.ToArray());
2788
Properties = new EquatableArray<Property>(
28-
type.GetMembers().OfType<IPropertySymbol>()
89+
GetBaseTypesAndThis(type).SelectMany(t => t.GetMembers().OfType<IPropertySymbol>())
90+
.Where(x => !x.IsSealed)
2991
.Where(x => IsInterface || x.IsVirtual || x.IsAbstract)
3092
.Select(x => new Property(x))
93+
.Distinct()
3194
.ToArray());
3295
Events = new EquatableArray<Event>(
33-
type.GetMembers().OfType<IEventSymbol>()
96+
GetBaseTypesAndThis(type).SelectMany(t => t.GetMembers().OfType<IEventSymbol>())
97+
.Where(x => !x.IsSealed)
3498
.Where(x => IsInterface || x.IsVirtual || x.IsAbstract)
3599
.Select(x => (x, (x.Type as INamedTypeSymbol)?.DelegateInvokeMethod))
36100
.Where(x => x.DelegateInvokeMethod is not null)
37101
.Select(x => new Event(x.x, x.DelegateInvokeMethod!))
102+
.Distinct()
38103
.ToArray());
104+
AdditionalNamespaces = new EquatableArray<string>(additionalNamespaces.Distinct().ToArray());
39105
}
40106

41107
public Type? ContainingType { get; }
42108

43109
public EquatableArray<Method> Methods { get; }
110+
public EquatableArray<string> AdditionalNamespaces { get; }
44111

45112
public EquatableArray<Property> Properties { get; }
46113

@@ -51,7 +118,10 @@ public Class(ITypeSymbol type)
51118
public string ClassName { get; }
52119

53120
public string GetClassNameWithoutDots()
54-
=> ClassName.Replace(".", "");
121+
=> ClassName
122+
.Replace(".", "")
123+
.Replace("<", "")
124+
.Replace(">", "");
55125

56126
public string[] GetClassNamespaces() => EnumerateNamespaces().Distinct().OrderBy(n => n).ToArray();
57127
internal IEnumerable<string> EnumerateNamespaces()
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
using System.Text;
2-
using Mockerade.SourceGenerators.Internals;
1+
using Mockerade.SourceGenerators.Internals;
32
using Microsoft.CodeAnalysis;
43

54
namespace Mockerade.SourceGenerators.Entities;
65

7-
internal readonly record struct Method
6+
internal record struct Method
87
{
98
public Method(IMethodSymbol methodSymbol)
109
{
1110
Accessibility = methodSymbol.DeclaredAccessibility;
1211
UseOverride = methodSymbol.IsVirtual || methodSymbol.IsAbstract;
1312
ReturnType = methodSymbol.ReturnsVoid ? Type.Void : new Type(methodSymbol.ReturnType);
1413
Name = methodSymbol.Name;
14+
ContainingType = methodSymbol.ContainingType.Name;
1515
Parameters = new EquatableArray<MethodParameter>(
1616
methodSymbol.Parameters.Select(x => new MethodParameter(x)).ToArray());
1717
}
@@ -21,5 +21,7 @@ public Method(IMethodSymbol methodSymbol)
2121
public Accessibility Accessibility { get; }
2222
public Type ReturnType { get; }
2323
public string Name { get; }
24+
public string ContainingType { get; }
2425
public EquatableArray<MethodParameter> Parameters { get; }
26+
public string? ExplicitImplementation { get; set; }
2527
}

Source/Mockerade.SourceGenerators/Entities/Property.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@ public Property(IPropertySymbol propertySymbol)
1212
UseOverride = propertySymbol.IsVirtual || propertySymbol.IsAbstract;
1313
Name = propertySymbol.Name;
1414
Type = new Type(propertySymbol.Type);
15+
IsIndexer = propertySymbol.IsIndexer;
16+
if (IsIndexer && propertySymbol.Parameters.Length > 0)
17+
{
18+
IndexerParameter = new MethodParameter(propertySymbol.Parameters[0]);
19+
}
1520
Getter = propertySymbol.GetMethod is null ? null : new Method(propertySymbol.GetMethod);
1621
Setter = propertySymbol.SetMethod is null ? null : new Method(propertySymbol.SetMethod);
1722
}
1823

24+
public bool IsIndexer { get; }
25+
public MethodParameter? IndexerParameter { get; }
1926
public Type Type { get; }
2027

2128
public Method? Setter { get; }

Source/Mockerade.SourceGenerators/Entities/Type.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ private Type(string fullname)
1212
internal Type(ITypeSymbol typeSymbol)
1313
{
1414
Fullname = typeSymbol.ToDisplayString();
15-
Namespace = typeSymbol.ContainingNamespace.ToString();
15+
Namespace = typeSymbol.ContainingNamespace?.ToString();
1616
}
1717

1818
public string? Namespace { get; }

Source/Mockerade.SourceGenerators/Internals/GeneratorHelpers.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ internal static bool IsMockForInvocationExpressionSyntax(this SyntaxNode node)
1111
{
1212
Expression: MemberAccessExpressionSyntax
1313
{
14-
Expression: IdentifierNameSyntax { Identifier.Text: "Mock", },
1514
Name: GenericNameSyntax { Identifier.Text : "For", }
1615
}
1716
};
@@ -26,7 +25,8 @@ internal static bool TryExtractGenericNameSyntax(this SyntaxNode syntaxNode,
2625
ISymbol? symbol = semanticModel.GetSymbolInfo(syntaxNode).Symbol;
2726
genericNameSyntax = value;
2827
return symbol?.ContainingType.ContainingNamespace.ContainingNamespace.IsGlobalNamespace == true &&
29-
symbol.ContainingType.ContainingNamespace.Name == "Mockerade";
28+
symbol.ContainingType.ContainingNamespace.Name == "Mockerade" &&
29+
symbol.ContainingType.Name == "Mock";
3030
}
3131

3232
genericNameSyntax = null;

Source/Mockerade.SourceGenerators/Internals/SourceGeneration.ExtensionsClass.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ private static void AppendRaisesExtensions(StringBuilder sb, Class @class, strin
104104
private static void AppendSetupExtensions(StringBuilder sb, Class @class, string[] namespaces, bool isProtected = false)
105105
{
106106
var methodPredicate = isProtected
107-
? new Func<Method, bool>(e => e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
108-
: new Func<Method, bool>(e => e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
107+
? new Func<Method, bool>(e => e.ExplicitImplementation is null && e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
108+
: new Func<Method, bool>(e => e.ExplicitImplementation is null && e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
109109
var propertyPredicate = isProtected
110-
? new Func<Property, bool>(e => e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
111-
: new Func<Property, bool>(e => e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
110+
? new Func<Property, bool>(e => !e.IsIndexer && e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
111+
: new Func<Property, bool>(e => !e.IsIndexer && e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
112112
if (!@class.Properties.Any(propertyPredicate) &&
113113
!@class.Methods.Any(methodPredicate))
114114
{
@@ -128,7 +128,7 @@ private static void AppendSetupExtensions(StringBuilder sb, Class @class, string
128128
sb.Append("\t\t/// Setup for the property <see cref=\"").Append(@class.ClassName).Append(".").Append(property.Name).Append("\"/>.").AppendLine();
129129
sb.Append("\t\t/// </summary>").AppendLine();
130130
sb.Append("\t\tpublic PropertySetup<").Append(property.Type.GetMinimizedString(namespaces)).Append("> ")
131-
.Append(property.Name).AppendLine();
131+
.Append((property.IndexerParameter is not null ? property.Name.Replace("[]", $"[With.Parameter<{property.IndexerParameter.Value.Type.GetMinimizedString(namespaces)}> {property.IndexerParameter.Value.Name}]") : property.Name)).AppendLine();
132132

133133
sb.AppendLine("\t\t{");
134134
sb.AppendLine("\t\t\tget");
@@ -260,8 +260,8 @@ private static void AppendSetupExtensions(StringBuilder sb, Class @class, string
260260
private static void AppendInvokedExtensions(StringBuilder sb, Class @class, string[] namespaces, bool isProtected = false)
261261
{
262262
var predicate = isProtected
263-
? new Func<Method, bool>(e => e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
264-
: new Func<Method, bool>(e => e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
263+
? new Func<Method, bool>(e => e.ExplicitImplementation is null && e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
264+
: new Func<Method, bool>(e => e.ExplicitImplementation is null && e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
265265
if (!@class.Methods.Any(predicate))
266266
{
267267
return;
@@ -313,8 +313,8 @@ private static void AppendInvokedExtensions(StringBuilder sb, Class @class, stri
313313
private static void AppendAccessedExtensions(StringBuilder sb, Class @class, string[] namespaces, bool isProtected = false)
314314
{
315315
var predicate = isProtected
316-
? new Func<Property, bool>(e => e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
317-
: new Func<Property, bool>(e => e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
316+
? new Func<Property, bool>(e => !e.IsIndexer && e.Accessibility is Accessibility.Protected or Accessibility.ProtectedOrInternal)
317+
: new Func<Property, bool>(e => !e.IsIndexer && e.Accessibility is not (Accessibility.Protected or Accessibility.ProtectedOrInternal));
318318
if (!@class.Properties.Any(predicate))
319319
{
320320
return;

Source/Mockerade.SourceGenerators/Internals/SourceGeneration.MockClass.cs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ private static void ImplementClass(StringBuilder sb, Class @class, string[] name
241241
sb.Append("override ");
242242
}
243243
sb.Append(property.Type.GetMinimizedString(namespaces))
244-
.Append(" ").Append(property.Name).AppendLine();
244+
.Append(" ").Append((property.IndexerParameter is not null ? property.Name.Replace("[]", $"[{property.IndexerParameter.Value.Type.GetMinimizedString(namespaces)} {property.IndexerParameter.Value.Name}]") : property.Name)).AppendLine();
245245
}
246246
sb.AppendLine("\t\t{");
247247
if (property.Getter != null && property.Getter.Value.Accessibility != Microsoft.CodeAnalysis.Accessibility.Private)
@@ -292,13 +292,22 @@ private static void ImplementClass(StringBuilder sb, Class @class, string[] name
292292
}
293293
else
294294
{
295-
sb.Append("\t\t").Append(method.Accessibility.ToVisibilityString()).Append(' ');
296-
if (!@class.IsInterface && method.UseOverride)
295+
sb.Append("\t\t");
296+
if (method.ExplicitImplementation is null)
297297
{
298-
sb.Append("override ");
298+
sb.Append(method.Accessibility.ToVisibilityString()).Append(' ');
299+
if (!@class.IsInterface && method.UseOverride)
300+
{
301+
sb.Append("override ");
302+
}
303+
sb.Append(method.ReturnType.GetMinimizedString(namespaces)).Append(' ')
304+
.Append(method.Name).Append('(');
305+
}
306+
else
307+
{
308+
sb.Append(method.ReturnType.GetMinimizedString(namespaces)).Append(' ')
309+
.Append(method.ExplicitImplementation).Append('.').Append(method.Name).Append('(');
299310
}
300-
sb.Append(method.ReturnType.GetMinimizedString(namespaces)).Append(' ')
301-
.Append(method.Name).Append('(');
302311
}
303312
int index = 0;
304313
foreach (MethodParameter parameter in method.Parameters)

Tests/Mockerade.SourceGenerators.Tests/Test.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,32 @@ public static void Main(string[] args)
2121

2222
await That(result.Diagnostics).IsEmpty();
2323

24+
await That(result.Sources)!.All().AreEquivalentTo(new
25+
{
26+
Source = It.Is<string>().That.Contains("<auto-generated>").And.Contains("</auto-generated>"),
27+
});
28+
}
29+
[Fact]
30+
public async Task XXX()
31+
{
32+
GeneratorResult result = Generator
33+
.Run("""
34+
using System.Collections.Generic;
35+
36+
namespace MyCode
37+
{
38+
public class Program
39+
{
40+
public static void Main(string[] args)
41+
{
42+
var x = Mockerade.Mock.For<IList<int>>();
43+
}
44+
}
45+
}
46+
""");
47+
48+
await That(result.Diagnostics).IsEmpty();
49+
2450
await That(result.Sources)!.All().AreEquivalentTo(new
2551
{
2652
Source = It.Is<string>().That.Contains("<auto-generated>").And.Contains("</auto-generated>"),

0 commit comments

Comments
 (0)