Skip to content

Commit

Permalink
Skip interfaces not publicly accessible in authoring scenarios (#1394)
Browse files Browse the repository at this point in the history
* Add internal COM interfaces to authoring test

* Add MixedWinRTClassicCOM authoring tests

* Skip interface types not publicly accessible

* Minor code refactoring

* Suppress diagnostics for not publicly accessible types

* Use fully qualified name for [Guid] to avoid conflicts

* Remmove collection expressions in projection attributes

* Skip processing explicit members of internal interfaces

* Skip processing symbols nested in internal types

* Add ABI types for AOT generator

* Restore original order/filtering to gather interfaces

* Fix build errors in AuthoringConsumptionTest

* Add TestMixedWinRTCOMWrapper to activation manifest

* Fix unit test

* Fix typos in ABI method names
  • Loading branch information
Sergio0694 authored Nov 22, 2023
1 parent b54a421 commit 38a732b
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 16 deletions.
6 changes: 4 additions & 2 deletions src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ private void CheckDeclarations()
foreach (var declaration in syntaxReceiver.Declarations)
{
var model = _context.Compilation.GetSemanticModel(declaration.SyntaxTree);
var symbol = model.GetDeclaredSymbol(declaration);

// Check symbol information for whether it is public to properly detect partial types
// which can leave out modifier.
if (model.GetDeclaredSymbol(declaration).DeclaredAccessibility != Accessibility.Public)
// which can leave out modifier. Also ignore nested types not effectively public
if (symbol.DeclaredAccessibility != Accessibility.Public ||
(symbol is ITypeSymbol typeSymbol && !typeSymbol.IsPubliclyAccessible()))
{
continue;
}
Expand Down
54 changes: 54 additions & 0 deletions src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;

#nullable enable

namespace Generator;

/// <summary>
/// Extensions for symbol types.
/// </summary>
internal static class SymbolExtensions
{
/// <summary>
/// Checks whether a given type symbol is publicly accessible (ie. it's public and not nested in any non public type).
/// </summary>
/// <param name="type">The type symbol to check for public accessibility.</param>
/// <returns>Whether <paramref name="type"/> is publicly accessible.</returns>
public static bool IsPubliclyAccessible(this ITypeSymbol type)
{
for (ITypeSymbol? currentType = type; currentType is not null; currentType = currentType.ContainingType)
{
// If any type in the type hierarchy is not public, the type is not public.
// This makes sure to detect public types nested into eg. a private type.
if (currentType.DeclaredAccessibility is not Accessibility.Public)
{
return false;
}
}

return true;
}

/// <summary>
/// Checks whether a given symbol is an explicit interface implementation of a member of an internal interface (or more than one).
/// </summary>
/// <param name="symbol">The input member symbol to check.</param>
/// <returns>Whether <paramref name="symbol"/> is an explicit interface implementation of internal interfaces.</returns>
public static bool IsExplicitInterfaceImplementationOfInternalInterfaces(this ISymbol symbol)
{
static bool IsAnyContainingTypePublic(IEnumerable<ISymbol> symbols)
{
return symbols.Any(static symbol => symbol.ContainingType!.IsPubliclyAccessible());
}

return symbol switch
{
IMethodSymbol { ExplicitInterfaceImplementations: { Length: > 0 } methods } => !IsAnyContainingTypePublic(methods),
IPropertySymbol { ExplicitInterfaceImplementations: { Length: > 0 } properties } => !IsAnyContainingTypePublic(properties),
IEventSymbol { ExplicitInterfaceImplementations: { Length: > 0 } events } => !IsAnyContainingTypePublic(events),
_ => false
};
}
}
59 changes: 45 additions & 14 deletions src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1204,21 +1204,38 @@ Symbol GetType(string type, bool isGeneric = false, int genericIndex = -1, bool

private IEnumerable<INamedTypeSymbol> GetInterfaces(INamedTypeSymbol symbol, bool includeInterfacesWithoutMappings = false)
{
HashSet<INamedTypeSymbol> interfaces = new HashSet<INamedTypeSymbol>();
foreach (var @interface in symbol.Interfaces)
HashSet<INamedTypeSymbol> interfaces = new();

// Gather all interfaces that are publicly accessible. We specifically need to exclude interfaces
// that are not public, as eg. those might be used for additional cloaked WinRT/COM interfaces.
// Ignoring them here makes sure that they're not processed to be part of the .winmd file.
void GatherPubliclyAccessibleInterfaces(ITypeSymbol symbol)
{
interfaces.Add(@interface);
interfaces.UnionWith(@interface.AllInterfaces);
foreach (var @interface in symbol.Interfaces)
{
if (@interface.IsPubliclyAccessible())
{
_ = interfaces.Add(@interface);
}

// We're not using AllInterfaces on purpose: we only want to gather all interfaces but not
// from the base type. That's handled below to skip types that are already WinRT projections.
foreach (var @interface2 in @interface.AllInterfaces)
{
if (@interface2.IsPubliclyAccessible())
{
_ = interfaces.Add(@interface2);
}
}
}
}

GatherPubliclyAccessibleInterfaces(symbol);

var baseType = symbol.BaseType;
while (baseType != null && !GeneratorHelper.IsWinRTType(baseType))
{
interfaces.UnionWith(baseType.Interfaces);
foreach (var @interface in baseType.Interfaces)
{
interfaces.UnionWith(@interface.AllInterfaces);
}
GatherPubliclyAccessibleInterfaces(baseType);

baseType = baseType.BaseType;
}
Expand Down Expand Up @@ -1911,6 +1928,13 @@ void AddComponentType(INamedTypeSymbol type, Action visitTypeDeclaration = null)
}
else
{
// Special case: skip members that are explicitly implementing internal interfaces.
// This allows implementing classic COM internal interfaces with non-WinRT signatures.
if (member.IsExplicitInterfaceImplementationOfInternalInterfaces())
{
continue;
}

if (member is IMethodSymbol method &&
(method.MethodKind == MethodKind.Ordinary ||
method.MethodKind == MethodKind.ExplicitInterfaceImplementation ||
Expand Down Expand Up @@ -2683,12 +2707,19 @@ typeDeclaration.Node is INamedTypeSymbol symbol &&
}
}

public bool IsPublic(ISymbol type)
public bool IsPublic(ISymbol symbol)
{
return type.DeclaredAccessibility == Accessibility.Public ||
type is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
type is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
type is IEventSymbol @event && !@event.ExplicitInterfaceImplementations.IsDefaultOrEmpty;
// Check that the type has either public accessibility, or is an explicit interface implementation
if (symbol.DeclaredAccessibility == Accessibility.Public ||
symbol is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
symbol is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
symbol is IEventSymbol @event && !@event.ExplicitInterfaceImplementations.IsDefaultOrEmpty)
{
// If we have a containing type, we also check that it's publicly accessible
return symbol.ContainingType is not { } containingType || containingType.IsPubliclyAccessible();
}

return false;
}

public void GetNamespaceAndTypename(string qualifiedName, out string @namespace, out string typename)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,9 @@
name="AuthoringTest.TestClass"
threadingModel="both"
xmlns="urn:schemas-microsoft-com:winrt.v1" />
<activatableClass
name="AuthoringTest.TestMixedWinRTCOMWrapper"
threadingModel="both"
xmlns="urn:schemas-microsoft-com:winrt.v1" />
</file>
</assembly>
1 change: 1 addition & 0 deletions src/Tests/AuthoringConsumptionTest/pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// conflict with Storyboard::GetCurrentTime
#undef GetCurrentTime

#include <Windows.h>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Foundation.Collections.h>

Expand Down
34 changes: 34 additions & 0 deletions src/Tests/AuthoringConsumptionTest/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,4 +639,38 @@ TEST(AuthoringTest, PartialClass)
EXPECT_EQ(partialStruct.X, 3);
EXPECT_EQ(partialStruct.Y, 4);
EXPECT_EQ(partialStruct.Z, 5);
}

TEST(AuthoringTest, MixedWinRTClassicCOM)
{
TestMixedWinRTCOMWrapper wrapper;

// Normal WinRT methods work as you'd expect
EXPECT_EQ(wrapper.HelloWorld(), L"Hello from mixed WinRT/COM");

// Verify we can grab the internal interface
IID internalInterface1Iid;
check_hresult(IIDFromString(L"{C7850559-8FF2-4E54-A237-6ED813F20CDC}", &internalInterface1Iid));
winrt::com_ptr<::IUnknown> unknown1 = wrapper.as<::IUnknown>();
winrt::com_ptr<::IUnknown> internalInterface1;
EXPECT_EQ(unknown1->QueryInterface(internalInterface1Iid, internalInterface1.put_void()), S_OK);

// Verify we can grab the nested public interface (in an internal type)
IID internalInterface2Iid;
check_hresult(IIDFromString(L"{8A08E18A-8D20-4E7C-9242-857BFE1E3159}", &internalInterface2Iid));
winrt::com_ptr<::IUnknown> unknown2 = wrapper.as<::IUnknown>();
winrt::com_ptr<::IUnknown> internalInterface2;
EXPECT_EQ(unknown2->QueryInterface(internalInterface2Iid, internalInterface2.put_void()), S_OK);

typedef int (__stdcall* GetNumber)(void*, int*);

int number;

// Validate the first call on IInternalInterface1
EXPECT_EQ(reinterpret_cast<GetNumber>((*reinterpret_cast<void***>(internalInterface1.get()))[3])(internalInterface1.get(), &number), S_OK);
EXPECT_EQ(number, 42);

// Validate the second call on IInternalInterface2
EXPECT_EQ(reinterpret_cast<GetNumber>((*reinterpret_cast<void***>(internalInterface2.get()))[3])(internalInterface2.get(), &number), S_OK);
EXPECT_EQ(number, 123);
}
141 changes: 141 additions & 0 deletions src/Tests/AuthoringTest/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Windows.Input;
using Windows.Foundation;
using Windows.Foundation.Collections;
using Windows.Foundation.Metadata;
using Windows.Graphics.Effects;
using WinRT;
using WinRT.Interop;

#pragma warning disable CA1416

Expand Down Expand Up @@ -1569,4 +1574,140 @@ public partial struct PartialStruct
{
public double Z;
}

public sealed class TestMixedWinRTCOMWrapper : IGraphicsEffectSource, IPublicInterface, IInternalInterface1, SomeInternalType.IInternalInterface2
{
public string HelloWorld()
{
return "Hello from mixed WinRT/COM";
}

unsafe int IInternalInterface1.GetNumber(int* value)
{
*value = 42;

return 0;
}

unsafe int SomeInternalType.IInternalInterface2.GetNumber(int* value)
{
*value = 123;

return 0;
}
}

public interface IPublicInterface
{
string HelloWorld();
}

// Internal, classic COM interface
[global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")]
[WindowsRuntimeType]
[WindowsRuntimeHelperType(typeof(IInternalInterface1))]
internal unsafe interface IInternalInterface1
{
int GetNumber(int* value);

[global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")]
public struct Vftbl
{
public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl();

private static IntPtr InitVtbl()
{
Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl));

lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl;
lpVtbl->GetNumber = &GetNumberFromAbi;

return (IntPtr)lpVtbl;
}

private IUnknownVftbl IUnknownVftbl;
private delegate* unmanaged[Stdcall]<void*, int*, int> GetNumber;

[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static int GetNumberFromAbi(void* thisPtr, int* value)
{
try
{
return ComWrappersSupport.FindObject<IInternalInterface1>((IntPtr)thisPtr).GetNumber(value);
}
catch (Exception e)
{
ExceptionHelpers.SetErrorInfo(e);

return Marshal.GetHRForException(e);
}
}
}
}

internal struct SomeInternalType
{
// Nested, classic COM interface
[global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")]
[WindowsRuntimeType]
[WindowsRuntimeHelperType(typeof(IInternalInterface2))]
public unsafe interface IInternalInterface2
{
int GetNumber(int* value);

[global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")]
public struct Vftbl
{
public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl();

private static IntPtr InitVtbl()
{
Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl));

lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl;
lpVtbl->GetNumber = &GetNumberFromAbi;

return (IntPtr)lpVtbl;
}

private IUnknownVftbl IUnknownVftbl;
private delegate* unmanaged[Stdcall]<void*, int*, int> GetNumber;

[UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
private static int GetNumberFromAbi(void* thisPtr, int* value)
{
try
{
return ComWrappersSupport.FindObject<IInternalInterface2>((IntPtr)thisPtr).GetNumber(value);
}
catch (Exception e)
{
ExceptionHelpers.SetErrorInfo(e);

return Marshal.GetHRForException(e);
}
}
}
}
}
}

namespace ABI.AuthoringTest
{
internal static class IInternalInterface1Methods
{
public static Guid IID => typeof(global::AuthoringTest.IInternalInterface1).GUID;

public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.IInternalInterface1.Vftbl.AbiToProjectionVftablePtr;
}

internal struct SomeInternalType
{
internal static class IInternalInterface2Methods
{
public static Guid IID => typeof(global::AuthoringTest.SomeInternalType.IInternalInterface2).GUID;

public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.SomeInternalType.IInternalInterface2.Vftbl.AbiToProjectionVftablePtr;
}
}
}

0 comments on commit 38a732b

Please sign in to comment.