diff --git a/Projections/Test/Test.csproj b/Projections/Test/Test.csproj index ebbe33d9e..6379d4188 100644 --- a/Projections/Test/Test.csproj +++ b/Projections/Test/Test.csproj @@ -7,7 +7,7 @@ - true + true true 8305;0618 diff --git a/TestComponentCSharp/Class.h b/TestComponentCSharp/Class.h index 3a8c5b3f4..fa6a0f21b 100644 --- a/TestComponentCSharp/Class.h +++ b/TestComponentCSharp/Class.h @@ -318,7 +318,11 @@ namespace winrt::TestComponentCSharp::implementation namespace winrt::TestComponentCSharp::factory_implementation { - struct Class : ClassT + struct Class : ClassT { + hstring ToString() + { + return L"Class"; + } }; } diff --git a/TestComponentCSharp/ComImports.h b/TestComponentCSharp/ComImports.h index 2876b7ddd..d1f0597d1 100644 --- a/TestComponentCSharp/ComImports.h +++ b/TestComponentCSharp/ComImports.h @@ -13,7 +13,11 @@ namespace winrt::TestComponentCSharp::implementation } namespace winrt::TestComponentCSharp::factory_implementation { - struct ComImports : ComImportsT + struct ComImports : ComImportsT { + hstring ToString() + { + return L"ComImports"; + } }; } diff --git a/TestComponentCSharp/NonAgileClass.h b/TestComponentCSharp/NonAgileClass.h index add4fbb69..188b35d7c 100644 --- a/TestComponentCSharp/NonAgileClass.h +++ b/TestComponentCSharp/NonAgileClass.h @@ -13,7 +13,11 @@ namespace winrt::TestComponentCSharp::implementation } namespace winrt::TestComponentCSharp::factory_implementation { - struct NonAgileClass : NonAgileClassT + struct NonAgileClass : NonAgileClassT { + hstring ToString() + { + return L"NonAgileClass"; + } }; } diff --git a/UnitTest/TestComponentCSharp_Tests.cs b/UnitTest/TestComponentCSharp_Tests.cs index fa848d423..d94bb3ff4 100644 --- a/UnitTest/TestComponentCSharp_Tests.cs +++ b/UnitTest/TestComponentCSharp_Tests.cs @@ -535,6 +535,54 @@ public void TestGenericCast() Assert.Equal(abiView.ThisPtr, abiView.As().As.Vftbl>().ThisPtr); } + [ComImport] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + [Guid("96369F54-8EB6-48F0-ABCE-C1B211E627C3")] + internal unsafe interface IStringableInterop + { + // Note: Invoking methods on ComInterfaceType.InterfaceIsIInspectable interfaces + // no longer appears supported in the runtime (probably with removal of WinRT support), + // so simulate with IUnknown. + void GetIids(out int iidCount, out IntPtr iids); + void GetRuntimeClassName(out IntPtr className); + void GetTrustLevel(out TrustLevel trustLevel); + + void ToString(out IntPtr hstr); + } + + [ComImport] + [Guid("39E050C3-4E74-441A-8DC0-B81104DF949C")] + // Using ComInterfaceType.InterfaceIsIInspectable here just to test the cast operation, + // not actually invoking RequestVerificationForWindowAsync. + [InterfaceType(ComInterfaceType.InterfaceIsIInspectable)] + public interface IUserConsentVerifierInterop + { + IAsyncOperation RequestVerificationForWindowAsync( + IntPtr appWindow, + out IntPtr message, + ref Guid riid); + } + + [Fact] + public unsafe void TestFactoryCast() + { + IntPtr hstr; + + // Access nonstatic class factory + var instanceFactory = Class.As(); + instanceFactory.ToString(out hstr); + Assert.Equal("Class", MarshalString.FromAbi(hstr)); + + // Access static class factory + var staticFactory = ComImports.As(); + staticFactory.ToString(out hstr); + Assert.Equal("ComImports", MarshalString.FromAbi(hstr)); + + // Test user class + var interop = Windows.Security.Credentials.UI.UserConsentVerifier.As(); + Assert.NotNull(interop); + } + [Fact] public void TestFundamentalGeneric() { diff --git a/WinRT.Runtime/CastExtensions.cs b/WinRT.Runtime/CastExtensions.cs index 5d5661164..eb918b41e 100644 --- a/WinRT.Runtime/CastExtensions.cs +++ b/WinRT.Runtime/CastExtensions.cs @@ -42,33 +42,7 @@ public static TInterface As(this object value) using (var objRef = GetRefForObject(value)) { - if (typeof(TInterface).GetCustomAttribute(typeof(System.Runtime.InteropServices.ComImportAttribute)) is object) - { - unsafe - { - static WinRT.Interop.IUnknownVftbl MarshalIUnknown(IntPtr thisPtr) - { - var vftblPtr = Unsafe.AsRef(thisPtr.ToPointer()); - var vftblIUnknown = Marshal.PtrToStructure(vftblPtr.Vftbl); - return vftblIUnknown; - } - - Guid iid = typeof(TInterface).GUID; - IntPtr comPtr; - MarshalIUnknown(objRef.ThisPtr).QueryInterface(objRef.ThisPtr, ref iid, out comPtr); - try - { - var obj = Marshal.GetObjectForIUnknown(comPtr); - return (TInterface)obj; - } - finally - { - MarshalIUnknown(comPtr).Release(comPtr); - } - } - } - - return (TInterface)typeof(TInterface).GetHelperType().GetConstructor(new[] { typeof(IObjectReference) }).Invoke(new object[] { objRef }); + return objRef.AsInterface(); } } diff --git a/WinRT.Runtime/ObjectReference.cs b/WinRT.Runtime/ObjectReference.cs index 295d0aaf6..ee81dab10 100644 --- a/WinRT.Runtime/ObjectReference.cs +++ b/WinRT.Runtime/ObjectReference.cs @@ -56,6 +56,27 @@ public unsafe ObjectReference As(Guid iid) return ObjectReference.Attach(ref thatPtr); } + public unsafe TInterface AsInterface() + { + if (typeof(TInterface).GetCustomAttribute(typeof(System.Runtime.InteropServices.ComImportAttribute)) is object) + { + Guid iid = typeof(TInterface).GUID; + Marshal.ThrowExceptionForHR(VftblIUnknown.QueryInterface(ThisPtr, ref iid, out IntPtr comPtr)); + try + { + return (TInterface)Marshal.GetObjectForIUnknown(comPtr); + } + finally + { + var vftblPtr = Unsafe.AsRef(comPtr.ToPointer()); + var vftblIUnknown = Marshal.PtrToStructure(vftblPtr.Vftbl); + vftblIUnknown.Release(comPtr); + } + } + + return (TInterface)typeof(TInterface).GetHelperType().GetConstructor(new[] { typeof(IObjectReference) }).Invoke(new object[] { this }); + } + public int TryAs(out ObjectReference objRef) => TryAs(GuidGenerator.GetIID(typeof(T)), out objRef); public virtual unsafe int TryAs(Guid iid, out ObjectReference objRef) diff --git a/cswinrt/code_writers.h b/cswinrt/code_writers.h index 5060c3f72..191c3c5c0 100644 --- a/cswinrt/code_writers.h +++ b/cswinrt/code_writers.h @@ -1033,19 +1033,10 @@ remove => %.% -= value; std::string write_static_cache_object(writer& w, std::string_view cache_type_name, TypeDef const& class_type) { - auto cache_interface = - w.write_temp( - R"((new BaseActivationFactory("%", "%.%"))._As)", - class_type.TypeNamespace(), - class_type.TypeNamespace(), - class_type.TypeName(), - class_type.TypeNamespace(), - cache_type_name); - w.write(R"( internal class _% : ABI.%.% { -public _%() : base(%()) { } +public _%() : base(%._factory._As()) { } private static WeakLazy<_%> _instance = new WeakLazy<_%>(); internal static % Instance => _instance.Value; } @@ -1054,7 +1045,9 @@ internal static % Instance => _instance.Value; class_type.TypeNamespace(), cache_type_name, cache_type_name, - cache_interface, + class_type.TypeName(), + class_type.TypeNamespace(), + cache_type_name, cache_type_name, cache_type_name, cache_type_name); @@ -1205,6 +1198,7 @@ MarshalInspectable.DisposeAbi(ptr); void write_attributed_types(writer& w, TypeDef const& type) { + bool factory_written{}; for (auto&& [interface_name, factory] : get_attributed_types(w, type)) { if (factory.activatable) @@ -1217,6 +1211,44 @@ MarshalInspectable.DisposeAbi(ptr); } else if (factory.statics) { + if (!factory_written) + { + factory_written = true; + + bool has_base_factory{}; + auto extends = type.Extends(); + while(!has_base_factory) + { + auto base_semantics = get_type_semantics(extends); + if (std::holds_alternative(base_semantics)) + { + break; + } + for_typedef(w, base_semantics, [&](auto base_type) + { + for (auto&& [_, base_factory] : get_attributed_types(w, base_type)) + { + if (base_factory.statics) + { + has_base_factory = true; + break; + } + } + extends = base_type.Extends(); + }); + } + + w.write(R"( +internal static %BaseActivationFactory _factory = new BaseActivationFactory("%", "%.%"); +public static %I As() => _factory.AsInterface(); +)", + has_base_factory ? "new " : "", + type.TypeNamespace(), + type.TypeNamespace(), + type.TypeName(), + has_base_factory ? "new " : ""); + } + write_static_members(w, factory.type, type); } } @@ -4257,6 +4289,7 @@ private % AsInternal(InterfaceTag<%> _) => _default; bind([&](writer& w) { bool has_base_type = !std::holds_alternative(get_type_semantics(type.Extends())); + if (!type.Flags().Sealed()) { w.write(R"( @@ -4274,7 +4307,6 @@ default_interface_abi_name, bind(type)); } - std::string_view access_spec = "protected "; std::string_view override_spec = has_base_type ? "override " : "virtual "; diff --git a/cswinrt/strings/WinRT.cs b/cswinrt/strings/WinRT.cs index 400591642..ce0d2840c 100644 --- a/cswinrt/strings/WinRT.cs +++ b/cswinrt/strings/WinRT.cs @@ -239,6 +239,10 @@ internal class BaseActivationFactory { private ObjectReference _IActivationFactory; + public ObjectReference Value { get => _IActivationFactory; } + + public I AsInterface() => _IActivationFactory.AsInterface(); + public BaseActivationFactory(string typeNamespace, string typeFullName) { var runtimeClassId = typeFullName.Replace("WinRT", "Windows"); @@ -289,6 +293,7 @@ internal class ActivationFactory : BaseActivationFactory public ActivationFactory() : base(typeof(T).Namespace, typeof(T).FullName) { } static WeakLazy> _factory = new WeakLazy>(); + public new static I AsInterface() => _factory.Value.Value.AsInterface(); public static ObjectReference As() => _factory.Value._As(); public static ObjectReference ActivateInstance() => _factory.Value._ActivateInstance(); }