Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable derived COM interfaces to hide base methods with the new keyword #101577

Merged
merged 7 commits into from
May 6, 2024
Original file line number Diff line number Diff line change
@@ -17,9 +17,14 @@ internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfa
/// </summary>
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);

/// <summary>
/// COM methods that require shadowing declarations on the derived interface.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);

/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);
}
}
Original file line number Diff line number Diff line change
@@ -127,7 +127,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.WithComparer(SyntaxEquivalentComparer.Instance)
.SelectNormalized();

var shadowingMethods = interfaceAndMethodsContexts
var shadowingMethodDeclarations = interfaceAndMethodsContexts
.Select((data, ct) =>
{
var context = data.Interface.Info;
@@ -163,7 +163,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Zip(nativeToManagedVtableMethods)
.Zip(nativeToManagedVtables)
.Zip(iUnknownDerivedAttributeApplication)
.Zip(shadowingMethods)
.Zip(shadowingMethodDeclarations)
.Select(static (data, ct) =>
{
var ((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data;
@@ -352,7 +352,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M

var containingSyntaxContext = new ContainingSyntaxContext(syntax);

var methodSyntaxTemplate = new ContainingSyntax(syntax.Modifiers.StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);

ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
suppressGCTransitionAttribute,
@@ -423,11 +423,42 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
var methodList = ImmutableArray.CreateBuilder<ComMethodContext>();
while (methodIndex < methods.Length && methods[methodIndex].OwningInterface == iface)
{
var method = methods[methodIndex];
if (method.MethodInfo.IsUserDefinedShadowingMethod)
{
bool shadowFound = false;
int shadowIndex = -1;
// Don't remove method, but make it so that it doesn't generate any stubs
for (int i = methodList.Count - 1; i > -1; i--)
{
var potentialShadowedMethod = methodList[i];
if (MethodEquals(method, potentialShadowedMethod))
{
shadowFound = true;
shadowIndex = i;
break;
}
}
if (shadowFound)
{
methodList[shadowIndex].IsHiddenOnDerivedInterface = true;
}
// We might not find the shadowed method if it's defined on a non-GeneratedComInterface-attributed interface. Thats okay and we can disregard it.
}
methodList.Add(methods[methodIndex++]);
}
contextList.Add(new(iface, methodList.ToImmutable().ToSequenceEqual()));
}
return contextList.ToImmutable();

static bool MethodEquals(ComMethodContext a, ComMethodContext b)
{
if (a.MethodInfo.MethodName != b.MethodInfo.MethodName)
return false;
if (a.GenerationContext.SignatureContext.ManagedParameters.SequenceEqual(b.GenerationContext.SignatureContext.ManagedParameters))
return true;
return false;
}
}

private static readonly InterfaceDeclarationSyntax ImplementationInterfaceTemplate = InterfaceDeclaration("InterfaceImplementation")
@@ -436,12 +467,12 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
{
var definingType = interfaceGroup.Interface.Info.Type;
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
.WithExplicitInterfaceSpecifier(
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub);
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
return ImplementationInterfaceTemplate
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
.WithMembers(
@@ -560,7 +591,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
ParenthesizedExpression(
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3))))))));
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
}

var vtableSlotAssignments = VirtualMethodPointerStubGenerator.GenerateVirtualMethodTableSlotAssignments(
Original file line number Diff line number Diff line change
@@ -69,13 +69,15 @@ public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, In

public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;

public bool IsHiddenOnDerivedInterface { get; set; }

private GeneratedMethodContextBase? _managedToUnmanagedStub;

public GeneratedMethodContextBase ManagedToUnmanagedStub => _managedToUnmanagedStub ??= CreateManagedToUnmanagedStub();

private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional))
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
{
return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
}
@@ -89,7 +91,7 @@ private GeneratedMethodContextBase CreateManagedToUnmanagedStub()

private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional))
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
{
return new SkippedStubContext(GenerationContext.OriginalDefiningType);
}
Original file line number Diff line number Diff line change
@@ -17,7 +17,8 @@ namespace Microsoft.Interop
internal sealed record ComMethodInfo(
MethodDeclarationSyntax Syntax,
string MethodName,
SequenceEqualImmutableArray<AttributeInfo> Attributes)
SequenceEqualImmutableArray<AttributeInfo> Attributes,
bool IsUserDefinedShadowingMethod)
{
/// <summary>
/// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
@@ -123,7 +124,9 @@ internal sealed record ComMethodInfo(
{
attributeInfos.Add(AttributeInfo.From(attr));
}
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual());

bool shadowsBaseMethod = comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.NewKeyword);
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual(), shadowsBaseMethod);
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((comMethodInfo, method));
}
}
Original file line number Diff line number Diff line change
@@ -386,5 +386,39 @@ public void IStringArrayMarshallingFails_Failing()
obj.ByValueOutParam(strings);
});
}

[Fact]
public unsafe void IHideWorksAsExpected()
{
IHide obj = CreateWrapper<HideBaseMethods, IHide3>();

// IHide.SameMethod should be index 3
Assert.Equal(3, obj.SameMethod());
Assert.Equal(4, obj.DifferentMethod());

IHide2 obj2 = (IHide2)obj;

// IHide2.SameMethod should be index 5
Assert.Equal(5, obj2.SameMethod());
Assert.Equal(4, obj2.DifferentMethod());
Assert.Equal(6, obj2.DifferentMethod2());

IHide3 obj3 = (IHide3)obj;
// IHide3.SameMethod should be index 7
Assert.Equal(7, obj3.SameMethod());
Assert.Equal(4, obj3.DifferentMethod());
Assert.Equal(6, obj3.DifferentMethod2());
Assert.Equal(8, obj3.DifferentMethod3());

// Ensure each VTable method points to the correct method on HideBaseMethods
for (int i = 3; i < 9; i++)
{
var (__this, __vtable_native) = ((global::System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)obj3).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IHide3));
int __retVal;
int __invokeRetVal;
__invokeRetVal = ((delegate* unmanaged[MemberFunction]<void*, int*, int>)__vtable_native[i])(__this, &__retVal);
Assert.Equal(i, __retVal);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
@@ -109,8 +110,6 @@ public static StatelessCollectionAllShapes<TManagedElement> AllocateContainerFor
}
""";

public static readonly string DisableRuntimeMarshalling = "[assembly:System.Runtime.CompilerServices.DisableRuntimeMarshalling]";
public static readonly string UsingSystemRuntimeInteropServicesMarshalling = "using System.Runtime.InteropServices.Marshalling;";
public const string IntMarshaller = """
[CustomMarshaller(typeof(int), MarshalMode.Default, typeof(IntMarshaller))]
internal static class IntMarshaller
@@ -433,6 +432,87 @@ partial interface INativeAPI
{{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
""";

public string DerivedComInterfaceTypeWithShadowingMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method();
}
""";

public string DerivedComInterfaceTypeShadowsNonComMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
interface IOtherInterface
{
void Method2();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method2();
}
""";

public string DerivedComInterfaceTypeShadowsComAndNonComMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
interface IOtherInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method();
}
""";

public string DerivedComInterfaceTypeTwoLevelShadows => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface1: IComInterface
{
new void Method1();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface1
{
new void Method();
}
""";

public string DerivedComInterfaceType => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Original file line number Diff line number Diff line change
@@ -6,7 +6,6 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Microsoft.DotNet.XUnitExtensions.Attributes;
using Microsoft.Interop.UnitTests;
using Xunit;
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator>;
@@ -336,6 +335,10 @@ public static IEnumerable<object[]> ComInterfaceSnippetsToCompile()
{
CodeSnippets codeSnippets = new(new GeneratedComInterfaceAttributeProvider());
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceType };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeWithShadowingMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsNonComMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsComAndNonComMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeTwoLevelShadows};
yield return new object[] { ID(), codeSnippets.DerivedWithParametersDeclaredInOtherNamespace };
yield return new object[] { ID(), codeSnippets.ComInterfaceParameters };
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid("023EA72A-ECAA-4B65-9D96-2122CFADE16C")]
internal partial interface IHide
{
int SameMethod();
int DifferentMethod();
}

[GeneratedComInterface]
[Guid("5293B3B1-4994-425C-803E-A21A5011E077")]
internal partial interface IHide2 : IHide
{
new int SameMethod();
int DifferentMethod2();
}

internal interface UnrelatedInterfaceWithSameMethod
{
int SameMethod();
int DifferentMethod3();
}

[GeneratedComInterface]
[Guid("5DD35432-4987-488D-94F1-7682D7E4405C")]
internal partial interface IHide3 : IHide2, UnrelatedInterfaceWithSameMethod
{
new int SameMethod();
new int DifferentMethod3();
}

[GeneratedComClass]
[Guid("2D36BD6D-C80E-4F00-86E9-8D1B4A0CB59A")]
/// <summary>
/// Implements IHides3 and returns the expected VTable index for each method.
/// </summary>
internal partial class HideBaseMethods : IHide3
{
int IHide.SameMethod() => 3;
int IHide.DifferentMethod() => 4;
int IHide2.SameMethod() => 5;
int IHide2.DifferentMethod2() => 6;
int IHide3.SameMethod() => 7;
int IHide3.DifferentMethod3() => 8;
int UnrelatedInterfaceWithSameMethod.SameMethod() => -1;
int UnrelatedInterfaceWithSameMethod.DifferentMethod3() => -1;
}
}