From 7a5278c1b51e3afb45746e51bf5858a53b15d8f7 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Thu, 7 May 2020 12:55:08 -0700 Subject: [PATCH 1/2] Make global ComWrappers for marshalling respect requested IID --- src/coreclr/src/vm/interopconverter.cpp | 20 ++++++++++++++-- .../src/Interop/COM/ComWrappers/Common.cs | 1 + .../GlobalInstance/GlobalInstance.cs | 23 +++++++++++-------- .../ReferenceTrackerRuntime.cpp | 7 +++++- 4 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp index 64312e1438214..ec331be25758c 100644 --- a/src/coreclr/src/vm/interopconverter.cpp +++ b/src/coreclr/src/vm/interopconverter.cpp @@ -104,8 +104,16 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, BOOL bSecuri if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) { - pUnk.SuppressRelease(); - RETURN pUnk; + GUID iid; + pMT->GetGuid(&iid, /*bGenerateIfNotFound*/ FALSE, /*bClassic*/ FALSE); + + SafeComHolder pvObj; + hr = pUnk->QueryInterface(iid, &pvObj); + if (FAILED(hr)) + COMPlusThrowHR(hr); + + pvObj.SuppressRelease(); + RETURN pvObj; } SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -194,6 +202,12 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType if (pFetchedIpType != NULL) *pFetchedIpType = ReqIpType; + if (pvObj != NULL) + { + pUnk->Release(); + pUnk = pvObj.Extract(); + } + RETURN pUnk; } @@ -466,9 +480,11 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, REFIID iid, bool throwIfNoComI { SafeComHolder pvObj; hr = pUnk->QueryInterface(iid, &pvObj); + pUnk->Release(); if (FAILED(hr)) COMPlusThrowHR(hr); + pUnk = pvObj.Extract(); RETURN pUnk; } diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs index 2f322018f0f80..0e75049200c9b 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs @@ -11,6 +11,7 @@ namespace ComWrappersTests.Common // Managed object with native wrapper definition. // [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] interface ITest { void SetValue(int i); diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs index ef9870b08672d..b362bc641b037 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs @@ -48,9 +48,9 @@ extern public static int UpdateTestObjectAsIInspectable( [DllImport(nameof(MockReferenceTrackerRuntime))] extern public static int UpdateTestObjectAsInterface( - [MarshalAs(UnmanagedType.Interface)] Test testObj, + [MarshalAs(UnmanagedType.Interface)] ITest testObj, int i, - [Out, MarshalAs(UnmanagedType.Interface)] out Test ret); + [Out, MarshalAs(UnmanagedType.Interface)] out ITest ret); } private const string ManagedServerTypeName = "ConsumeNETServerTesting"; @@ -298,6 +298,11 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered) IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj); Assert.AreNotEqual(IntPtr.Zero, dispatchWrapper); Assert.AreEqual(dispatchObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + + Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject != Marshal.GetIUnknownForObject..."); + IntPtr unknownWrapper = Marshal.GetIUnknownForObject(dispatchObj); + Assert.AreNotEqual(IntPtr.Zero, unknownWrapper); + Assert.AreNotEqual(unknownWrapper, dispatchWrapper); } Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown..."); @@ -326,33 +331,33 @@ private static void ValidatePInvokes(bool validateUseRegistered) GlobalComWrappers.Instance.ReturnInvalid = !validateUseRegistered; Console.WriteLine($" -- Validate MarshalAs IUnknown..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIUnknown, validateUseRegistered); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIUnknown, shouldSucceed: validateUseRegistered); object obj = MarshalInterface.CreateTrackerObjectAsIUnknown(); Assert.AreEqual(validateUseRegistered, obj is FakeWrapper, $"Should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); if (validateUseRegistered) { Console.WriteLine($" -- Validate MarshalAs IDispatch..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIDispatch, validateUseRegistered, new TestEx(IID_IDISPATCH)); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIDispatch, shouldSucceed: true, new TestEx(IID_IDISPATCH)); Console.WriteLine($" -- Validate MarshalAs IInspectable..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIInspectable, validateUseRegistered, new TestEx(IID_IINSPECTABLE)); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIInspectable, shouldSucceed: true, new TestEx(IID_IINSPECTABLE)); } Console.WriteLine($" -- Validate MarshalAs Interface..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsInterface, validateUseRegistered); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsInterface, shouldSucceed: true); if (validateUseRegistered) { Assert.Throws(() => MarshalInterface.CreateTrackerObjectWrongType()); FakeWrapper wrapper = MarshalInterface.CreateTrackerObjectAsInterface(); - Assert.IsNotNull(obj, $"Should have returned {nameof(FakeWrapper)} instance"); + Assert.IsNotNull(wrapper, $"Should have returned {nameof(FakeWrapper)} instance"); } } private delegate int UpdateTestObject(T testObj, int i, out T ret) where T : class; - private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool validateUseRegistered, Test testObj = null) where T : class + private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool shouldSucceed, Test testObj = null) where T : class { const int E_NOINTERFACE = unchecked((int)0x80004002); int value = 10; @@ -363,7 +368,7 @@ private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool T retObj; int hr = func(testObj as T, value, out retObj); Assert.AreEqual(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); - if (validateUseRegistered) + if (shouldSucceed) { Assert.IsTrue(retObj is Test); Assert.AreEqual(value, testObj.GetValue()); diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp index 867efa01866e4..54cf0bb31f14b 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp @@ -374,5 +374,10 @@ extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsInterface(ITest *o if (obj == nullptr) return E_POINTER; - return UpdateTestObjectAsIUnknown(obj, i, (IUnknown**)out); + HRESULT hr; + RETURN_IF_FAILED(obj->SetValue(i)); + + obj->AddRef(); + *out = obj; + return S_OK; } From 67631abda4b2c4771d33ba1c67739ee694072d26 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Fri, 8 May 2020 12:55:06 -0700 Subject: [PATCH 2/2] PR feedback --- src/coreclr/src/vm/interopconverter.cpp | 33 +++++++++++-------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp index ec331be25758c..ee632b4ada792 100644 --- a/src/coreclr/src/vm/interopconverter.cpp +++ b/src/coreclr/src/vm/interopconverter.cpp @@ -107,12 +107,11 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, BOOL bSecuri GUID iid; pMT->GetGuid(&iid, /*bGenerateIfNotFound*/ FALSE, /*bClassic*/ FALSE); - SafeComHolder pvObj; - hr = pUnk->QueryInterface(iid, &pvObj); + IUnknown* pvObj; + hr = SafeQueryInterface(pUnk, iid, &pvObj); if (FAILED(hr)) COMPlusThrowHR(hr); - pvObj.SuppressRelease(); RETURN pvObj; } @@ -185,15 +184,20 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType { hr = S_OK; - SafeComHolder pvObj; + IUnknown* pvObj; if (ReqIpType & ComIpType_Dispatch) { - hr = pUnk->QueryInterface(IID_IDispatch, &pvObj); + hr = SafeQueryInterface(pUnk, IID_IDispatch, &pvObj); + pUnk->Release(); } else if (ReqIpType & ComIpType_Inspectable) { - SafeComHolder pvObj; - hr = pUnk->QueryInterface(IID_IInspectable, &pvObj); + hr = SafeQueryInterface(pUnk, IID_IInspectable, &pvObj); + pUnk->Release(); + } + else + { + pvObj = pUnk; } if (FAILED(hr)) @@ -202,13 +206,7 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType if (pFetchedIpType != NULL) *pFetchedIpType = ReqIpType; - if (pvObj != NULL) - { - pUnk->Release(); - pUnk = pvObj.Extract(); - } - - RETURN pUnk; + RETURN pvObj; } MethodTable *pMT = (*poref)->GetMethodTable(); @@ -478,14 +476,13 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, REFIID iid, bool throwIfNoComI if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) { - SafeComHolder pvObj; - hr = pUnk->QueryInterface(iid, &pvObj); + IUnknown* pvObj; + hr = SafeQueryInterface(pUnk, iid, &pvObj); pUnk->Release(); if (FAILED(hr)) COMPlusThrowHR(hr); - pUnk = pvObj.Extract(); - RETURN pUnk; + RETURN pvObj; } MethodTable *pMT = (*poref)->GetMethodTable();