Skip to content

Commit 73c804c

Browse files
jtschusterAaronRobinsonMSFT
authored andcommitted
Enable derived COM interfaces to hide base methods with the new keyword (dotnet#101577)
When a COM interface method is defined with the new keyword, we should avoid creating a new shadowing method on the derived interface, and remove the new keyword when creating implementations of methods. This change tracks whether a method was declared with the new keyword. If it was, we search all the inherited methods for one with the same name and signature. If found, we mark that method as "HiddenOnDerivedInterface". The hidden method may not be found if the method hides a non-COM method. When generating methods, we only generate new shadowing method declarations for inherited methods that aren't already hidden, and we don't generate any stubs for inherited hidden methods. The diff for the generated code for the new test interfaces is shown below. --------- Co-authored-by: Aaron Robinson <arobins@microsoft.com>
1 parent 575604f commit 73c804c

File tree

8 files changed

+228
-14
lines changed

8 files changed

+228
-14
lines changed

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs

+6-1
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@ internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfa
1717
/// </summary>
1818
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);
1919

20+
/// <summary>
21+
/// COM methods that require shadowing declarations on the derived interface.
22+
/// </summary>
23+
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);
24+
2025
/// <summary>
2126
/// COM methods that are declared on an interface the interface inherits from.
2227
/// </summary>
23-
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
28+
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);
2429
}
2530
}

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs

+37-6
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
127127
.WithComparer(SyntaxEquivalentComparer.Instance)
128128
.SelectNormalized();
129129

130-
var shadowingMethods = interfaceAndMethodsContexts
130+
var shadowingMethodDeclarations = interfaceAndMethodsContexts
131131
.Select((data, ct) =>
132132
{
133133
var context = data.Interface.Info;
@@ -163,7 +163,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
163163
.Zip(nativeToManagedVtableMethods)
164164
.Zip(nativeToManagedVtables)
165165
.Zip(iUnknownDerivedAttributeApplication)
166-
.Zip(shadowingMethods)
166+
.Zip(shadowingMethodDeclarations)
167167
.Select(static (data, ct) =>
168168
{
169169
var ((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data;
@@ -352,7 +352,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
352352

353353
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
354354

355-
var methodSyntaxTemplate = new ContainingSyntax(syntax.Modifiers.StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
355+
var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
356356

357357
ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
358358
suppressGCTransitionAttribute,
@@ -423,11 +423,42 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
423423
var methodList = ImmutableArray.CreateBuilder<ComMethodContext>();
424424
while (methodIndex < methods.Length && methods[methodIndex].OwningInterface == iface)
425425
{
426+
var method = methods[methodIndex];
427+
if (method.MethodInfo.IsUserDefinedShadowingMethod)
428+
{
429+
bool shadowFound = false;
430+
int shadowIndex = -1;
431+
// Don't remove method, but make it so that it doesn't generate any stubs
432+
for (int i = methodList.Count - 1; i > -1; i--)
433+
{
434+
var potentialShadowedMethod = methodList[i];
435+
if (MethodEquals(method, potentialShadowedMethod))
436+
{
437+
shadowFound = true;
438+
shadowIndex = i;
439+
break;
440+
}
441+
}
442+
if (shadowFound)
443+
{
444+
methodList[shadowIndex].IsHiddenOnDerivedInterface = true;
445+
}
446+
// We might not find the shadowed method if it's defined on a non-GeneratedComInterface-attributed interface. Thats okay and we can disregard it.
447+
}
426448
methodList.Add(methods[methodIndex++]);
427449
}
428450
contextList.Add(new(iface, methodList.ToImmutable().ToSequenceEqual()));
429451
}
430452
return contextList.ToImmutable();
453+
454+
static bool MethodEquals(ComMethodContext a, ComMethodContext b)
455+
{
456+
if (a.MethodInfo.MethodName != b.MethodInfo.MethodName)
457+
return false;
458+
if (a.GenerationContext.SignatureContext.ManagedParameters.SequenceEqual(b.GenerationContext.SignatureContext.ManagedParameters))
459+
return true;
460+
return false;
461+
}
431462
}
432463

433464
private static readonly InterfaceDeclarationSyntax ImplementationInterfaceTemplate = InterfaceDeclaration("InterfaceImplementation")
@@ -436,12 +467,12 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
436467
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
437468
{
438469
var definingType = interfaceGroup.Interface.Info.Type;
439-
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
470+
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
440471
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
441472
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
442473
.WithExplicitInterfaceSpecifier(
443474
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
444-
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub);
475+
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
445476
return ImplementationInterfaceTemplate
446477
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
447478
.WithMembers(
@@ -560,7 +591,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
560591
ParenthesizedExpression(
561592
BinaryExpression(SyntaxKind.MultiplyExpression,
562593
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
563-
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3))))))));
594+
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
564595
}
565596

566597
var vtableSlotAssignments = VirtualMethodPointerStubGenerator.GenerateVirtualMethodTableSlotAssignments(

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,15 @@ public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, In
6969

7070
public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;
7171

72+
public bool IsHiddenOnDerivedInterface { get; set; }
73+
7274
private GeneratedMethodContextBase? _managedToUnmanagedStub;
7375

7476
public GeneratedMethodContextBase ManagedToUnmanagedStub => _managedToUnmanagedStub ??= CreateManagedToUnmanagedStub();
7577

7678
private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
7779
{
78-
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional))
80+
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
7981
{
8082
return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
8183
}
@@ -89,7 +91,7 @@ private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
8991

9092
private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
9193
{
92-
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional))
94+
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
9395
{
9496
return new SkippedStubContext(GenerationContext.OriginalDefiningType);
9597
}

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ namespace Microsoft.Interop
1717
internal sealed record ComMethodInfo(
1818
MethodDeclarationSyntax Syntax,
1919
string MethodName,
20-
SequenceEqualImmutableArray<AttributeInfo> Attributes)
20+
SequenceEqualImmutableArray<AttributeInfo> Attributes,
21+
bool IsUserDefinedShadowingMethod)
2122
{
2223
/// <summary>
2324
/// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
@@ -123,7 +124,9 @@ internal sealed record ComMethodInfo(
123124
{
124125
attributeInfos.Add(AttributeInfo.From(attr));
125126
}
126-
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual());
127+
128+
bool shadowsBaseMethod = comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.NewKeyword);
129+
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual(), shadowsBaseMethod);
127130
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((comMethodInfo, method));
128131
}
129132
}

src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs

+34
Original file line numberDiff line numberDiff line change
@@ -386,5 +386,39 @@ public void IStringArrayMarshallingFails_Failing()
386386
obj.ByValueOutParam(strings);
387387
});
388388
}
389+
390+
[Fact]
391+
public unsafe void IHideWorksAsExpected()
392+
{
393+
IHide obj = CreateWrapper<HideBaseMethods, IHide3>();
394+
395+
// IHide.SameMethod should be index 3
396+
Assert.Equal(3, obj.SameMethod());
397+
Assert.Equal(4, obj.DifferentMethod());
398+
399+
IHide2 obj2 = (IHide2)obj;
400+
401+
// IHide2.SameMethod should be index 5
402+
Assert.Equal(5, obj2.SameMethod());
403+
Assert.Equal(4, obj2.DifferentMethod());
404+
Assert.Equal(6, obj2.DifferentMethod2());
405+
406+
IHide3 obj3 = (IHide3)obj;
407+
// IHide3.SameMethod should be index 7
408+
Assert.Equal(7, obj3.SameMethod());
409+
Assert.Equal(4, obj3.DifferentMethod());
410+
Assert.Equal(6, obj3.DifferentMethod2());
411+
Assert.Equal(8, obj3.DifferentMethod3());
412+
413+
// Ensure each VTable method points to the correct method on HideBaseMethods
414+
for (int i = 3; i < 9; i++)
415+
{
416+
var (__this, __vtable_native) = ((global::System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)obj3).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IHide3));
417+
int __retVal;
418+
int __invokeRetVal;
419+
__invokeRetVal = ((delegate* unmanaged[MemberFunction]<void*, int*, int>)__vtable_native[i])(__this, &__retVal);
420+
Assert.Equal(i, __retVal);
421+
}
422+
}
389423
}
390424
}

src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs

+82-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Collections.Generic;
66
using System.Diagnostics;
7+
using System.Diagnostics.CodeAnalysis;
78
using System.Linq;
89
using System.Runtime.CompilerServices;
910
using System.Runtime.InteropServices;
@@ -109,8 +110,6 @@ public static StatelessCollectionAllShapes<TManagedElement> AllocateContainerFor
109110
}
110111
""";
111112

112-
public static readonly string DisableRuntimeMarshalling = "[assembly:System.Runtime.CompilerServices.DisableRuntimeMarshalling]";
113-
public static readonly string UsingSystemRuntimeInteropServicesMarshalling = "using System.Runtime.InteropServices.Marshalling;";
114113
public const string IntMarshaller = """
115114
[CustomMarshaller(typeof(int), MarshalMode.Default, typeof(IntMarshaller))]
116115
internal static class IntMarshaller
@@ -433,6 +432,87 @@ partial interface INativeAPI
433432
{{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
434433
""";
435434

435+
public string DerivedComInterfaceTypeWithShadowingMethod => $$"""
436+
using System.Runtime.CompilerServices;
437+
using System.Runtime.InteropServices;
438+
using System.Runtime.InteropServices.Marshalling;
439+
440+
{{GeneratedComInterface()}}
441+
partial interface IComInterface
442+
{
443+
void Method();
444+
}
445+
{{GeneratedComInterface()}}
446+
partial interface IComInterface2 : IComInterface
447+
{
448+
new void Method();
449+
}
450+
""";
451+
452+
public string DerivedComInterfaceTypeShadowsNonComMethod => $$"""
453+
using System.Runtime.CompilerServices;
454+
using System.Runtime.InteropServices;
455+
using System.Runtime.InteropServices.Marshalling;
456+
457+
{{GeneratedComInterface()}}
458+
partial interface IComInterface
459+
{
460+
void Method();
461+
}
462+
interface IOtherInterface
463+
{
464+
void Method2();
465+
}
466+
{{GeneratedComInterface()}}
467+
partial interface IComInterface2 : IComInterface
468+
{
469+
new void Method2();
470+
}
471+
""";
472+
473+
public string DerivedComInterfaceTypeShadowsComAndNonComMethod => $$"""
474+
using System.Runtime.CompilerServices;
475+
using System.Runtime.InteropServices;
476+
using System.Runtime.InteropServices.Marshalling;
477+
478+
{{GeneratedComInterface()}}
479+
partial interface IComInterface
480+
{
481+
void Method();
482+
}
483+
interface IOtherInterface
484+
{
485+
void Method();
486+
}
487+
{{GeneratedComInterface()}}
488+
partial interface IComInterface2 : IComInterface
489+
{
490+
new void Method();
491+
}
492+
""";
493+
494+
public string DerivedComInterfaceTypeTwoLevelShadows => $$"""
495+
using System.Runtime.CompilerServices;
496+
using System.Runtime.InteropServices;
497+
using System.Runtime.InteropServices.Marshalling;
498+
499+
{{GeneratedComInterface()}}
500+
partial interface IComInterface
501+
{
502+
void Method();
503+
}
504+
{{GeneratedComInterface()}}
505+
partial interface IComInterface1: IComInterface
506+
{
507+
new void Method1();
508+
}
509+
{{GeneratedComInterface()}}
510+
partial interface IComInterface2 : IComInterface1
511+
{
512+
new void Method();
513+
}
514+
""";
515+
436516
public string DerivedComInterfaceType => $$"""
437517
using System.Runtime.CompilerServices;
438518
using System.Runtime.InteropServices;

src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
using System.Diagnostics;
77
using System.Runtime.CompilerServices;
88
using System.Threading.Tasks;
9-
using Microsoft.DotNet.XUnitExtensions.Attributes;
109
using Microsoft.Interop.UnitTests;
1110
using Xunit;
1211
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator>;
@@ -336,6 +335,10 @@ public static IEnumerable<object[]> ComInterfaceSnippetsToCompile()
336335
{
337336
CodeSnippets codeSnippets = new(new GeneratedComInterfaceAttributeProvider());
338337
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceType };
338+
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeWithShadowingMethod };
339+
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsNonComMethod };
340+
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsComAndNonComMethod };
341+
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeTwoLevelShadows};
339342
yield return new object[] { ID(), codeSnippets.DerivedWithParametersDeclaredInOtherNamespace };
340343
yield return new object[] { ID(), codeSnippets.ComInterfaceParameters };
341344
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Runtime.InteropServices;
6+
using System.Runtime.InteropServices.Marshalling;
7+
8+
namespace SharedTypes.ComInterfaces
9+
{
10+
[GeneratedComInterface]
11+
[Guid("023EA72A-ECAA-4B65-9D96-2122CFADE16C")]
12+
internal partial interface IHide
13+
{
14+
int SameMethod();
15+
int DifferentMethod();
16+
}
17+
18+
[GeneratedComInterface]
19+
[Guid("5293B3B1-4994-425C-803E-A21A5011E077")]
20+
internal partial interface IHide2 : IHide
21+
{
22+
new int SameMethod();
23+
int DifferentMethod2();
24+
}
25+
26+
internal interface UnrelatedInterfaceWithSameMethod
27+
{
28+
int SameMethod();
29+
int DifferentMethod3();
30+
}
31+
32+
[GeneratedComInterface]
33+
[Guid("5DD35432-4987-488D-94F1-7682D7E4405C")]
34+
internal partial interface IHide3 : IHide2, UnrelatedInterfaceWithSameMethod
35+
{
36+
new int SameMethod();
37+
new int DifferentMethod3();
38+
}
39+
40+
[GeneratedComClass]
41+
[Guid("2D36BD6D-C80E-4F00-86E9-8D1B4A0CB59A")]
42+
/// <summary>
43+
/// Implements IHides3 and returns the expected VTable index for each method.
44+
/// </summary>
45+
internal partial class HideBaseMethods : IHide3
46+
{
47+
int IHide.SameMethod() => 3;
48+
int IHide.DifferentMethod() => 4;
49+
int IHide2.SameMethod() => 5;
50+
int IHide2.DifferentMethod2() => 6;
51+
int IHide3.SameMethod() => 7;
52+
int IHide3.DifferentMethod3() => 8;
53+
int UnrelatedInterfaceWithSameMethod.SameMethod() => -1;
54+
int UnrelatedInterfaceWithSameMethod.DifferentMethod3() => -1;
55+
}
56+
}

0 commit comments

Comments
 (0)