From 38a732b027acdb1b4b0770f42cf582bb595e4a3a Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Wed, 22 Nov 2023 16:24:32 +0100 Subject: [PATCH] Skip interfaces not publicly accessible in authoring scenarios (#1394) * Add internal COM interfaces to authoring test * Add MixedWinRTClassicCOM authoring tests * Skip interface types not publicly accessible * Minor code refactoring * Suppress diagnostics for not publicly accessible types * Use fully qualified name for [Guid] to avoid conflicts * Remmove collection expressions in projection attributes * Skip processing explicit members of internal interfaces * Skip processing symbols nested in internal types * Add ABI types for AOT generator * Restore original order/filtering to gather interfaces * Fix build errors in AuthoringConsumptionTest * Add TestMixedWinRTCOMWrapper to activation manifest * Fix unit test * Fix typos in ABI method names --- .../WinRT.SourceGenerator/DiagnosticUtils.cs | 6 +- .../Extensions/SymbolExtensions.cs | 54 +++++++ .../WinRT.SourceGenerator/WinRTTypeWriter.cs | 59 ++++++-- .../AuthoringConsumptionTest.exe.manifest | 4 + src/Tests/AuthoringConsumptionTest/pch.h | 1 + src/Tests/AuthoringConsumptionTest/test.cpp | 34 +++++ src/Tests/AuthoringTest/Program.cs | 141 ++++++++++++++++++ 7 files changed, 283 insertions(+), 16 deletions(-) create mode 100644 src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs diff --git a/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs b/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs index 73e099d5d..9c2215de1 100644 --- a/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs +++ b/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs @@ -87,10 +87,12 @@ private void CheckDeclarations() foreach (var declaration in syntaxReceiver.Declarations) { var model = _context.Compilation.GetSemanticModel(declaration.SyntaxTree); + var symbol = model.GetDeclaredSymbol(declaration); // Check symbol information for whether it is public to properly detect partial types - // which can leave out modifier. - if (model.GetDeclaredSymbol(declaration).DeclaredAccessibility != Accessibility.Public) + // which can leave out modifier. Also ignore nested types not effectively public + if (symbol.DeclaredAccessibility != Accessibility.Public || + (symbol is ITypeSymbol typeSymbol && !typeSymbol.IsPubliclyAccessible())) { continue; } diff --git a/src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs b/src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs new file mode 100644 index 000000000..5a8f443d2 --- /dev/null +++ b/src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs @@ -0,0 +1,54 @@ +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; + +#nullable enable + +namespace Generator; + +/// +/// Extensions for symbol types. +/// +internal static class SymbolExtensions +{ + /// + /// Checks whether a given type symbol is publicly accessible (ie. it's public and not nested in any non public type). + /// + /// The type symbol to check for public accessibility. + /// Whether is publicly accessible. + public static bool IsPubliclyAccessible(this ITypeSymbol type) + { + for (ITypeSymbol? currentType = type; currentType is not null; currentType = currentType.ContainingType) + { + // If any type in the type hierarchy is not public, the type is not public. + // This makes sure to detect public types nested into eg. a private type. + if (currentType.DeclaredAccessibility is not Accessibility.Public) + { + return false; + } + } + + return true; + } + + /// + /// Checks whether a given symbol is an explicit interface implementation of a member of an internal interface (or more than one). + /// + /// The input member symbol to check. + /// Whether is an explicit interface implementation of internal interfaces. + public static bool IsExplicitInterfaceImplementationOfInternalInterfaces(this ISymbol symbol) + { + static bool IsAnyContainingTypePublic(IEnumerable symbols) + { + return symbols.Any(static symbol => symbol.ContainingType!.IsPubliclyAccessible()); + } + + return symbol switch + { + IMethodSymbol { ExplicitInterfaceImplementations: { Length: > 0 } methods } => !IsAnyContainingTypePublic(methods), + IPropertySymbol { ExplicitInterfaceImplementations: { Length: > 0 } properties } => !IsAnyContainingTypePublic(properties), + IEventSymbol { ExplicitInterfaceImplementations: { Length: > 0 } events } => !IsAnyContainingTypePublic(events), + _ => false + }; + } +} diff --git a/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs b/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs index fcc2e7681..3a64b1e8b 100644 --- a/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs +++ b/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs @@ -1204,21 +1204,38 @@ Symbol GetType(string type, bool isGeneric = false, int genericIndex = -1, bool private IEnumerable GetInterfaces(INamedTypeSymbol symbol, bool includeInterfacesWithoutMappings = false) { - HashSet interfaces = new HashSet(); - foreach (var @interface in symbol.Interfaces) + HashSet interfaces = new(); + + // Gather all interfaces that are publicly accessible. We specifically need to exclude interfaces + // that are not public, as eg. those might be used for additional cloaked WinRT/COM interfaces. + // Ignoring them here makes sure that they're not processed to be part of the .winmd file. + void GatherPubliclyAccessibleInterfaces(ITypeSymbol symbol) { - interfaces.Add(@interface); - interfaces.UnionWith(@interface.AllInterfaces); + foreach (var @interface in symbol.Interfaces) + { + if (@interface.IsPubliclyAccessible()) + { + _ = interfaces.Add(@interface); + } + + // We're not using AllInterfaces on purpose: we only want to gather all interfaces but not + // from the base type. That's handled below to skip types that are already WinRT projections. + foreach (var @interface2 in @interface.AllInterfaces) + { + if (@interface2.IsPubliclyAccessible()) + { + _ = interfaces.Add(@interface2); + } + } + } } + GatherPubliclyAccessibleInterfaces(symbol); + var baseType = symbol.BaseType; while (baseType != null && !GeneratorHelper.IsWinRTType(baseType)) { - interfaces.UnionWith(baseType.Interfaces); - foreach (var @interface in baseType.Interfaces) - { - interfaces.UnionWith(@interface.AllInterfaces); - } + GatherPubliclyAccessibleInterfaces(baseType); baseType = baseType.BaseType; } @@ -1911,6 +1928,13 @@ void AddComponentType(INamedTypeSymbol type, Action visitTypeDeclaration = null) } else { + // Special case: skip members that are explicitly implementing internal interfaces. + // This allows implementing classic COM internal interfaces with non-WinRT signatures. + if (member.IsExplicitInterfaceImplementationOfInternalInterfaces()) + { + continue; + } + if (member is IMethodSymbol method && (method.MethodKind == MethodKind.Ordinary || method.MethodKind == MethodKind.ExplicitInterfaceImplementation || @@ -2683,12 +2707,19 @@ typeDeclaration.Node is INamedTypeSymbol symbol && } } - public bool IsPublic(ISymbol type) + public bool IsPublic(ISymbol symbol) { - return type.DeclaredAccessibility == Accessibility.Public || - type is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty || - type is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty || - type is IEventSymbol @event && !@event.ExplicitInterfaceImplementations.IsDefaultOrEmpty; + // Check that the type has either public accessibility, or is an explicit interface implementation + if (symbol.DeclaredAccessibility == Accessibility.Public || + symbol is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty || + symbol is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty || + symbol is IEventSymbol @event && !@event.ExplicitInterfaceImplementations.IsDefaultOrEmpty) + { + // If we have a containing type, we also check that it's publicly accessible + return symbol.ContainingType is not { } containingType || containingType.IsPubliclyAccessible(); + } + + return false; } public void GetNamespaceAndTypename(string qualifiedName, out string @namespace, out string typename) diff --git a/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest b/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest index 59cc17e2c..9a907d35f 100644 --- a/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest +++ b/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest @@ -74,5 +74,9 @@ name="AuthoringTest.TestClass" threadingModel="both" xmlns="urn:schemas-microsoft-com:winrt.v1" /> + \ No newline at end of file diff --git a/src/Tests/AuthoringConsumptionTest/pch.h b/src/Tests/AuthoringConsumptionTest/pch.h index 1eb7dade4..781afc1b9 100644 --- a/src/Tests/AuthoringConsumptionTest/pch.h +++ b/src/Tests/AuthoringConsumptionTest/pch.h @@ -4,6 +4,7 @@ // conflict with Storyboard::GetCurrentTime #undef GetCurrentTime +#include #include #include diff --git a/src/Tests/AuthoringConsumptionTest/test.cpp b/src/Tests/AuthoringConsumptionTest/test.cpp index d8bbf3c23..585078bc3 100644 --- a/src/Tests/AuthoringConsumptionTest/test.cpp +++ b/src/Tests/AuthoringConsumptionTest/test.cpp @@ -639,4 +639,38 @@ TEST(AuthoringTest, PartialClass) EXPECT_EQ(partialStruct.X, 3); EXPECT_EQ(partialStruct.Y, 4); EXPECT_EQ(partialStruct.Z, 5); +} + +TEST(AuthoringTest, MixedWinRTClassicCOM) +{ + TestMixedWinRTCOMWrapper wrapper; + + // Normal WinRT methods work as you'd expect + EXPECT_EQ(wrapper.HelloWorld(), L"Hello from mixed WinRT/COM"); + + // Verify we can grab the internal interface + IID internalInterface1Iid; + check_hresult(IIDFromString(L"{C7850559-8FF2-4E54-A237-6ED813F20CDC}", &internalInterface1Iid)); + winrt::com_ptr<::IUnknown> unknown1 = wrapper.as<::IUnknown>(); + winrt::com_ptr<::IUnknown> internalInterface1; + EXPECT_EQ(unknown1->QueryInterface(internalInterface1Iid, internalInterface1.put_void()), S_OK); + + // Verify we can grab the nested public interface (in an internal type) + IID internalInterface2Iid; + check_hresult(IIDFromString(L"{8A08E18A-8D20-4E7C-9242-857BFE1E3159}", &internalInterface2Iid)); + winrt::com_ptr<::IUnknown> unknown2 = wrapper.as<::IUnknown>(); + winrt::com_ptr<::IUnknown> internalInterface2; + EXPECT_EQ(unknown2->QueryInterface(internalInterface2Iid, internalInterface2.put_void()), S_OK); + + typedef int (__stdcall* GetNumber)(void*, int*); + + int number; + + // Validate the first call on IInternalInterface1 + EXPECT_EQ(reinterpret_cast((*reinterpret_cast(internalInterface1.get()))[3])(internalInterface1.get(), &number), S_OK); + EXPECT_EQ(number, 42); + + // Validate the second call on IInternalInterface2 + EXPECT_EQ(reinterpret_cast((*reinterpret_cast(internalInterface2.get()))[3])(internalInterface2.get(), &number), S_OK); + EXPECT_EQ(number, 123); } \ No newline at end of file diff --git a/src/Tests/AuthoringTest/Program.cs b/src/Tests/AuthoringTest/Program.cs index 2035d1b76..2b0d863d4 100644 --- a/src/Tests/AuthoringTest/Program.cs +++ b/src/Tests/AuthoringTest/Program.cs @@ -9,12 +9,17 @@ using System.ComponentModel.DataAnnotations; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using System.Windows.Input; using Windows.Foundation; using Windows.Foundation.Collections; using Windows.Foundation.Metadata; +using Windows.Graphics.Effects; +using WinRT; +using WinRT.Interop; #pragma warning disable CA1416 @@ -1569,4 +1574,140 @@ public partial struct PartialStruct { public double Z; } + + public sealed class TestMixedWinRTCOMWrapper : IGraphicsEffectSource, IPublicInterface, IInternalInterface1, SomeInternalType.IInternalInterface2 + { + public string HelloWorld() + { + return "Hello from mixed WinRT/COM"; + } + + unsafe int IInternalInterface1.GetNumber(int* value) + { + *value = 42; + + return 0; + } + + unsafe int SomeInternalType.IInternalInterface2.GetNumber(int* value) + { + *value = 123; + + return 0; + } + } + + public interface IPublicInterface + { + string HelloWorld(); + } + + // Internal, classic COM interface + [global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")] + [WindowsRuntimeType] + [WindowsRuntimeHelperType(typeof(IInternalInterface1))] + internal unsafe interface IInternalInterface1 + { + int GetNumber(int* value); + + [global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")] + public struct Vftbl + { + public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl(); + + private static IntPtr InitVtbl() + { + Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl)); + + lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl; + lpVtbl->GetNumber = &GetNumberFromAbi; + + return (IntPtr)lpVtbl; + } + + private IUnknownVftbl IUnknownVftbl; + private delegate* unmanaged[Stdcall] GetNumber; + + [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] + private static int GetNumberFromAbi(void* thisPtr, int* value) + { + try + { + return ComWrappersSupport.FindObject((IntPtr)thisPtr).GetNumber(value); + } + catch (Exception e) + { + ExceptionHelpers.SetErrorInfo(e); + + return Marshal.GetHRForException(e); + } + } + } + } + + internal struct SomeInternalType + { + // Nested, classic COM interface + [global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")] + [WindowsRuntimeType] + [WindowsRuntimeHelperType(typeof(IInternalInterface2))] + public unsafe interface IInternalInterface2 + { + int GetNumber(int* value); + + [global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")] + public struct Vftbl + { + public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl(); + + private static IntPtr InitVtbl() + { + Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl)); + + lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl; + lpVtbl->GetNumber = &GetNumberFromAbi; + + return (IntPtr)lpVtbl; + } + + private IUnknownVftbl IUnknownVftbl; + private delegate* unmanaged[Stdcall] GetNumber; + + [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })] + private static int GetNumberFromAbi(void* thisPtr, int* value) + { + try + { + return ComWrappersSupport.FindObject((IntPtr)thisPtr).GetNumber(value); + } + catch (Exception e) + { + ExceptionHelpers.SetErrorInfo(e); + + return Marshal.GetHRForException(e); + } + } + } + } + } +} + +namespace ABI.AuthoringTest +{ + internal static class IInternalInterface1Methods + { + public static Guid IID => typeof(global::AuthoringTest.IInternalInterface1).GUID; + + public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.IInternalInterface1.Vftbl.AbiToProjectionVftablePtr; + } + + internal struct SomeInternalType + { + internal static class IInternalInterface2Methods + { + public static Guid IID => typeof(global::AuthoringTest.SomeInternalType.IInternalInterface2).GUID; + + public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.SomeInternalType.IInternalInterface2.Vftbl.AbiToProjectionVftablePtr; + } + } } \ No newline at end of file