Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfa
/// </summary>
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);

/// <summary>
/// COM methods that require shadowing declarations on the derived interface.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface);

/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.WithComparer(SyntaxEquivalentComparer.Instance)
.SelectNormalized();

var shadowingMethods = interfaceAndMethodsContexts
var shadowingMethodDeclarations = interfaceAndMethodsContexts
.Select((data, ct) =>
{
var context = data.Interface.Info;
Expand Down Expand Up @@ -163,7 +163,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Zip(nativeToManagedVtableMethods)
.Zip(nativeToManagedVtables)
.Zip(iUnknownDerivedAttributeApplication)
.Zip(shadowingMethods)
.Zip(shadowingMethodDeclarations)
.Select(static (data, ct) =>
{
var ((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data;
Expand Down Expand Up @@ -352,7 +352,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M

var containingSyntaxContext = new ContainingSyntaxContext(syntax);

var methodSyntaxTemplate = new ContainingSyntax(syntax.Modifiers.StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);
var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList);

ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes(
suppressGCTransitionAttribute,
Expand Down Expand Up @@ -423,11 +423,42 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
var methodList = ImmutableArray.CreateBuilder<ComMethodContext>();
while (methodIndex < methods.Length && methods[methodIndex].OwningInterface == iface)
{
var method = methods[methodIndex];
if (method.MethodInfo.IsUserDefinedShadowingMethod)
{
bool shadowFound = false;
int shadowIndex = -1;
// Don't remove method, but make it so that it doesn't generate any stubs
for (int i = methodList.Count - 1; i > -1; i--)
{
var potentialShadowedMethod = methodList[i];
if (MethodEquals(method, potentialShadowedMethod))
{
shadowFound = true;
shadowIndex = i;
break;
}
}
if (shadowFound)
{
methodList[shadowIndex].IsHiddenOnDerivedInterface = true;
}
// We might not find the shadowed method if it's defined on a non-GeneratedComInterface-attributed interface. Thats okay and we can disregard it.
}
methodList.Add(methods[methodIndex++]);
}
contextList.Add(new(iface, methodList.ToImmutable().ToSequenceEqual()));
}
return contextList.ToImmutable();

static bool MethodEquals(ComMethodContext a, ComMethodContext b)
{
if (a.MethodInfo.MethodName != b.MethodInfo.MethodName)
return false;
if (a.GenerationContext.SignatureContext.ManagedParameters.SequenceEqual(b.GenerationContext.SignatureContext.ManagedParameters))
return true;
return false;
}
}

private static readonly InterfaceDeclarationSyntax ImplementationInterfaceTemplate = InterfaceDeclaration("InterfaceImplementation")
Expand All @@ -436,12 +467,12 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
{
var definingType = interfaceGroup.Interface.Info.Type;
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
.WithExplicitInterfaceSpecifier(
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub);
var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub);
return ImplementationInterfaceTemplate
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
.WithMembers(
Expand Down Expand Up @@ -560,7 +591,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
ParenthesizedExpression(
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3))))))));
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3))))))));
}

var vtableSlotAssignments = VirtualMethodPointerStubGenerator.GenerateVirtualMethodTableSlotAssignments(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,15 @@ public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, In

public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface;

public bool IsHiddenOnDerivedInterface { get; set; }

private GeneratedMethodContextBase? _managedToUnmanagedStub;

public GeneratedMethodContextBase ManagedToUnmanagedStub => _managedToUnmanagedStub ??= CreateManagedToUnmanagedStub();

private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional))
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
{
return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
}
Expand All @@ -89,7 +91,7 @@ private GeneratedMethodContextBase CreateManagedToUnmanagedStub()

private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
{
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional))
if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface)
{
return new SkippedStubContext(GenerationContext.OriginalDefiningType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace Microsoft.Interop
internal sealed record ComMethodInfo(
MethodDeclarationSyntax Syntax,
string MethodName,
SequenceEqualImmutableArray<AttributeInfo> Attributes)
SequenceEqualImmutableArray<AttributeInfo> Attributes,
bool IsUserDefinedShadowingMethod)
{
/// <summary>
/// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
Expand Down Expand Up @@ -123,7 +124,9 @@ internal sealed record ComMethodInfo(
{
attributeInfos.Add(AttributeInfo.From(attr));
}
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual());

bool shadowsBaseMethod = comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.NewKeyword);
var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual(), shadowsBaseMethod);
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((comMethodInfo, method));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,5 +386,39 @@ public void IStringArrayMarshallingFails_Failing()
obj.ByValueOutParam(strings);
});
}

[Fact]
public unsafe void IHideWorksAsExpected()
{
IHide obj = CreateWrapper<HideBaseMethods, IHide3>();

// IHide.SameMethod should be index 3
Assert.Equal(3, obj.SameMethod());
Assert.Equal(4, obj.DifferentMethod());

IHide2 obj2 = (IHide2)obj;

// IHide2.SameMethod should be index 5
Assert.Equal(5, obj2.SameMethod());
Assert.Equal(4, obj2.DifferentMethod());
Assert.Equal(6, obj2.DifferentMethod2());

IHide3 obj3 = (IHide3)obj;
// IHide3.SameMethod should be index 7
Assert.Equal(7, obj3.SameMethod());
Assert.Equal(4, obj3.DifferentMethod());
Assert.Equal(6, obj3.DifferentMethod2());
Assert.Equal(8, obj3.DifferentMethod3());

// Ensure each VTable method points to the correct method on HideBaseMethods
for (int i = 3; i < 9; i++)
{
var (__this, __vtable_native) = ((global::System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)obj3).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IHide3));
int __retVal;
int __invokeRetVal;
__invokeRetVal = ((delegate* unmanaged[MemberFunction]<void*, int*, int>)__vtable_native[i])(__this, &__retVal);
Assert.Equal(i, __retVal);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -109,8 +110,6 @@ public static StatelessCollectionAllShapes<TManagedElement> AllocateContainerFor
}
""";

public static readonly string DisableRuntimeMarshalling = "[assembly:System.Runtime.CompilerServices.DisableRuntimeMarshalling]";
public static readonly string UsingSystemRuntimeInteropServicesMarshalling = "using System.Runtime.InteropServices.Marshalling;";
public const string IntMarshaller = """
[CustomMarshaller(typeof(int), MarshalMode.Default, typeof(IntMarshaller))]
internal static class IntMarshaller
Expand Down Expand Up @@ -433,6 +432,87 @@ partial interface INativeAPI
{{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
""";

public string DerivedComInterfaceTypeWithShadowingMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method();
}
""";

public string DerivedComInterfaceTypeShadowsNonComMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
interface IOtherInterface
{
void Method2();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method2();
}
""";

public string DerivedComInterfaceTypeShadowsComAndNonComMethod => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
interface IOtherInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface
{
new void Method();
}
""";

public string DerivedComInterfaceTypeTwoLevelShadows => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

{{GeneratedComInterface()}}
partial interface IComInterface
{
void Method();
}
{{GeneratedComInterface()}}
partial interface IComInterface1: IComInterface
{
new void Method1();
}
{{GeneratedComInterface()}}
partial interface IComInterface2 : IComInterface1
{
new void Method();
}
""";

public string DerivedComInterfaceType => $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Microsoft.DotNet.XUnitExtensions.Attributes;
using Microsoft.Interop.UnitTests;
using Xunit;
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator>;
Expand Down Expand Up @@ -336,6 +335,10 @@ public static IEnumerable<object[]> ComInterfaceSnippetsToCompile()
{
CodeSnippets codeSnippets = new(new GeneratedComInterfaceAttributeProvider());
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceType };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeWithShadowingMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsNonComMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsComAndNonComMethod };
yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeTwoLevelShadows};
yield return new object[] { ID(), codeSnippets.DerivedWithParametersDeclaredInOtherNamespace };
yield return new object[] { ID(), codeSnippets.ComInterfaceParameters };
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid("023EA72A-ECAA-4B65-9D96-2122CFADE16C")]
internal partial interface IHide
{
int SameMethod();
int DifferentMethod();
}

[GeneratedComInterface]
[Guid("5293B3B1-4994-425C-803E-A21A5011E077")]
internal partial interface IHide2 : IHide
{
new int SameMethod();
int DifferentMethod2();
}

internal interface UnrelatedInterfaceWithSameMethod
{
int SameMethod();
int DifferentMethod3();
}

[GeneratedComInterface]
[Guid("5DD35432-4987-488D-94F1-7682D7E4405C")]
internal partial interface IHide3 : IHide2, UnrelatedInterfaceWithSameMethod
{
new int SameMethod();
new int DifferentMethod3();
}

[GeneratedComClass]
[Guid("2D36BD6D-C80E-4F00-86E9-8D1B4A0CB59A")]
/// <summary>
/// Implements IHides3 and returns the expected VTable index for each method.
/// </summary>
internal partial class HideBaseMethods : IHide3
{
int IHide.SameMethod() => 3;
int IHide.DifferentMethod() => 4;
int IHide2.SameMethod() => 5;
int IHide2.DifferentMethod2() => 6;
int IHide3.SameMethod() => 7;
int IHide3.DifferentMethod3() => 8;
int UnrelatedInterfaceWithSameMethod.SameMethod() => -1;
int UnrelatedInterfaceWithSameMethod.DifferentMethod3() => -1;
}
}