diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs index 13f91ff7bca34b..e1ee7e6d535c98 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs @@ -17,9 +17,14 @@ internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interfa /// </summary> public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod); + /// <summary> + /// COM methods that require shadowing declarations on the derived interface. + /// </summary> + public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod && !m.IsHiddenOnDerivedInterface); + /// <summary> /// COM methods that are declared on an interface the interface inherits from. /// </summary> - public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod); + public IEnumerable<ComMethodContext> InheritedMethods => Methods.Where(m => m.IsInheritedMethod); } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index d3c5c631bb0b5f..1c44a5ab0060e9 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -127,7 +127,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithComparer(SyntaxEquivalentComparer.Instance) .SelectNormalized(); - var shadowingMethods = interfaceAndMethodsContexts + var shadowingMethodDeclarations = interfaceAndMethodsContexts .Select((data, ct) => { var context = data.Interface.Info; @@ -163,7 +163,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Zip(nativeToManagedVtableMethods) .Zip(nativeToManagedVtables) .Zip(iUnknownDerivedAttributeApplication) - .Zip(shadowingMethods) + .Zip(shadowingMethodDeclarations) .Select(static (data, ct) => { var ((((((interfaceContext, interfaceInfo), managedToNativeStubs), nativeToManagedStubs), nativeToManagedVtable), iUnknownDerivedAttribute), shadowingMethod) = data; @@ -352,7 +352,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M var containingSyntaxContext = new ContainingSyntaxContext(syntax); - var methodSyntaxTemplate = new ContainingSyntax(syntax.Modifiers.StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList); + var methodSyntaxTemplate = new ContainingSyntax(new SyntaxTokenList(syntax.Modifiers.Where(static m => !m.IsKind(SyntaxKind.NewKeyword))).StripAccessibilityModifiers(), SyntaxKind.MethodDeclaration, syntax.Identifier, syntax.TypeParameterList); ImmutableArray<FunctionPointerUnmanagedCallingConventionSyntax> callConv = VirtualMethodPointerStubGenerator.GenerateCallConvSyntaxFromAttributes( suppressGCTransitionAttribute, @@ -423,11 +423,42 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor var methodList = ImmutableArray.CreateBuilder<ComMethodContext>(); while (methodIndex < methods.Length && methods[methodIndex].OwningInterface == iface) { + var method = methods[methodIndex]; + if (method.MethodInfo.IsUserDefinedShadowingMethod) + { + bool shadowFound = false; + int shadowIndex = -1; + // Don't remove method, but make it so that it doesn't generate any stubs + for (int i = methodList.Count - 1; i > -1; i--) + { + var potentialShadowedMethod = methodList[i]; + if (MethodEquals(method, potentialShadowedMethod)) + { + shadowFound = true; + shadowIndex = i; + break; + } + } + if (shadowFound) + { + methodList[shadowIndex].IsHiddenOnDerivedInterface = true; + } + // We might not find the shadowed method if it's defined on a non-GeneratedComInterface-attributed interface. Thats okay and we can disregard it. + } methodList.Add(methods[methodIndex++]); } contextList.Add(new(iface, methodList.ToImmutable().ToSequenceEqual())); } return contextList.ToImmutable(); + + static bool MethodEquals(ComMethodContext a, ComMethodContext b) + { + if (a.MethodInfo.MethodName != b.MethodInfo.MethodName) + return false; + if (a.GenerationContext.SignatureContext.ManagedParameters.SequenceEqual(b.GenerationContext.SignatureContext.ManagedParameters)) + return true; + return false; + } } private static readonly InterfaceDeclarationSyntax ImplementationInterfaceTemplate = InterfaceDeclaration("InterfaceImplementation") @@ -436,12 +467,12 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _) { var definingType = interfaceGroup.Interface.Info.Type; - var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub)) + var shadowImplementations = interfaceGroup.InheritedMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub)) .Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext) .Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node .WithExplicitInterfaceSpecifier( ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName)))); - var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub); + var inheritedStubs = interfaceGroup.InheritedMethods.Select(m => m.UnreachableExceptionStub); return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) .WithMembers( @@ -560,7 +591,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf ParenthesizedExpression( BinaryExpression(SyntaxKind.MultiplyExpression, SizeOfExpression(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword)))), - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.ShadowingMethods.Count() + 3)))))))); + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(interfaceMethods.InheritedMethods.Count() + 3)))))))); } var vtableSlotAssignments = VirtualMethodPointerStubGenerator.GenerateVirtualMethodTableSlotAssignments( diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index c67b5029ee756c..25e49f8f817451 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -69,13 +69,15 @@ public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, In public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface; + public bool IsHiddenOnDerivedInterface { get; set; } + private GeneratedMethodContextBase? _managedToUnmanagedStub; public GeneratedMethodContextBase ManagedToUnmanagedStub => _managedToUnmanagedStub ??= CreateManagedToUnmanagedStub(); private GeneratedMethodContextBase CreateManagedToUnmanagedStub() { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface) { return new SkippedStubContext(OriginalDeclaringInterface.Info.Type); } @@ -89,7 +91,7 @@ private GeneratedMethodContextBase CreateManagedToUnmanagedStub() private GeneratedMethodContextBase CreateUnmanagedToManagedStub() { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional) || IsHiddenOnDerivedInterface) { return new SkippedStubContext(GenerationContext.OriginalDefiningType); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index 5e84e461483a8d..39377c5cd3b44c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -17,7 +17,8 @@ namespace Microsoft.Interop internal sealed record ComMethodInfo( MethodDeclarationSyntax Syntax, string MethodName, - SequenceEqualImmutableArray<AttributeInfo> Attributes) + SequenceEqualImmutableArray<AttributeInfo> Attributes, + bool IsUserDefinedShadowingMethod) { /// <summary> /// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa. @@ -123,7 +124,9 @@ internal sealed record ComMethodInfo( { attributeInfos.Add(AttributeInfo.From(attr)); } - var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual()); + + bool shadowsBaseMethod = comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.NewKeyword); + var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name, attributeInfos.MoveToImmutable().ToSequenceEqual(), shadowsBaseMethod); return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From((comMethodInfo, method)); } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs index 765c26ffac59c5..8ba4511fe93d9e 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs @@ -386,5 +386,39 @@ public void IStringArrayMarshallingFails_Failing() obj.ByValueOutParam(strings); }); } + + [Fact] + public unsafe void IHideWorksAsExpected() + { + IHide obj = CreateWrapper<HideBaseMethods, IHide3>(); + + // IHide.SameMethod should be index 3 + Assert.Equal(3, obj.SameMethod()); + Assert.Equal(4, obj.DifferentMethod()); + + IHide2 obj2 = (IHide2)obj; + + // IHide2.SameMethod should be index 5 + Assert.Equal(5, obj2.SameMethod()); + Assert.Equal(4, obj2.DifferentMethod()); + Assert.Equal(6, obj2.DifferentMethod2()); + + IHide3 obj3 = (IHide3)obj; + // IHide3.SameMethod should be index 7 + Assert.Equal(7, obj3.SameMethod()); + Assert.Equal(4, obj3.DifferentMethod()); + Assert.Equal(6, obj3.DifferentMethod2()); + Assert.Equal(8, obj3.DifferentMethod3()); + + // Ensure each VTable method points to the correct method on HideBaseMethods + for (int i = 3; i < 9; i++) + { + var (__this, __vtable_native) = ((global::System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider)obj3).GetVirtualMethodTableInfoForKey(typeof(global::SharedTypes.ComInterfaces.IHide3)); + int __retVal; + int __invokeRetVal; + __invokeRetVal = ((delegate* unmanaged[MemberFunction]<void*, int*, int>)__vtable_native[i])(__this, &__retVal); + Assert.Equal(i, __retVal); + } + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs index 12b59f274425a6..c3115b7f1dbafb 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -109,8 +110,6 @@ public static StatelessCollectionAllShapes<TManagedElement> AllocateContainerFor } """; - public static readonly string DisableRuntimeMarshalling = "[assembly:System.Runtime.CompilerServices.DisableRuntimeMarshalling]"; - public static readonly string UsingSystemRuntimeInteropServicesMarshalling = "using System.Runtime.InteropServices.Marshalling;"; public const string IntMarshaller = """ [CustomMarshaller(typeof(int), MarshalMode.Default, typeof(IntMarshaller))] internal static class IntMarshaller @@ -433,6 +432,87 @@ partial interface INativeAPI {{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}} """; + public string DerivedComInterfaceTypeWithShadowingMethod => $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + {{GeneratedComInterface()}} + partial interface IComInterface + { + void Method(); + } + {{GeneratedComInterface()}} + partial interface IComInterface2 : IComInterface + { + new void Method(); + } + """; + + public string DerivedComInterfaceTypeShadowsNonComMethod => $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + {{GeneratedComInterface()}} + partial interface IComInterface + { + void Method(); + } + interface IOtherInterface + { + void Method2(); + } + {{GeneratedComInterface()}} + partial interface IComInterface2 : IComInterface + { + new void Method2(); + } + """; + + public string DerivedComInterfaceTypeShadowsComAndNonComMethod => $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + {{GeneratedComInterface()}} + partial interface IComInterface + { + void Method(); + } + interface IOtherInterface + { + void Method(); + } + {{GeneratedComInterface()}} + partial interface IComInterface2 : IComInterface + { + new void Method(); + } + """; + + public string DerivedComInterfaceTypeTwoLevelShadows => $$""" + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + {{GeneratedComInterface()}} + partial interface IComInterface + { + void Method(); + } + {{GeneratedComInterface()}} + partial interface IComInterface1: IComInterface + { + new void Method1(); + } + {{GeneratedComInterface()}} + partial interface IComInterface2 : IComInterface1 + { + new void Method(); + } + """; + public string DerivedComInterfaceType => $$""" using System.Runtime.CompilerServices; using System.Runtime.InteropServices; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs index feb69cc01da9c6..1522ab3478ac61 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs @@ -6,7 +6,6 @@ using System.Diagnostics; using System.Runtime.CompilerServices; using System.Threading.Tasks; -using Microsoft.DotNet.XUnitExtensions.Attributes; using Microsoft.Interop.UnitTests; using Xunit; using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator>; @@ -336,6 +335,10 @@ public static IEnumerable<object[]> ComInterfaceSnippetsToCompile() { CodeSnippets codeSnippets = new(new GeneratedComInterfaceAttributeProvider()); yield return new object[] { ID(), codeSnippets.DerivedComInterfaceType }; + yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeWithShadowingMethod }; + yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsNonComMethod }; + yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeShadowsComAndNonComMethod }; + yield return new object[] { ID(), codeSnippets.DerivedComInterfaceTypeTwoLevelShadows}; yield return new object[] { ID(), codeSnippets.DerivedWithParametersDeclaredInOtherNamespace }; yield return new object[] { ID(), codeSnippets.ComInterfaceParameters }; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IHide.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IHide.cs new file mode 100644 index 00000000000000..9d2fc102ba5df4 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IHide.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace SharedTypes.ComInterfaces +{ + [GeneratedComInterface] + [Guid("023EA72A-ECAA-4B65-9D96-2122CFADE16C")] + internal partial interface IHide + { + int SameMethod(); + int DifferentMethod(); + } + + [GeneratedComInterface] + [Guid("5293B3B1-4994-425C-803E-A21A5011E077")] + internal partial interface IHide2 : IHide + { + new int SameMethod(); + int DifferentMethod2(); + } + + internal interface UnrelatedInterfaceWithSameMethod + { + int SameMethod(); + int DifferentMethod3(); + } + + [GeneratedComInterface] + [Guid("5DD35432-4987-488D-94F1-7682D7E4405C")] + internal partial interface IHide3 : IHide2, UnrelatedInterfaceWithSameMethod + { + new int SameMethod(); + new int DifferentMethod3(); + } + + [GeneratedComClass] + [Guid("2D36BD6D-C80E-4F00-86E9-8D1B4A0CB59A")] + /// <summary> + /// Implements IHides3 and returns the expected VTable index for each method. + /// </summary> + internal partial class HideBaseMethods : IHide3 + { + int IHide.SameMethod() => 3; + int IHide.DifferentMethod() => 4; + int IHide2.SameMethod() => 5; + int IHide2.DifferentMethod2() => 6; + int IHide3.SameMethod() => 7; + int IHide3.DifferentMethod3() => 8; + int UnrelatedInterfaceWithSameMethod.SameMethod() => -1; + int UnrelatedInterfaceWithSameMethod.DifferentMethod3() => -1; + } +}