diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp index 64312e1438214..ee632b4ada792 100644 --- a/src/coreclr/src/vm/interopconverter.cpp +++ b/src/coreclr/src/vm/interopconverter.cpp @@ -104,8 +104,15 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, BOOL bSecuri if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) { - pUnk.SuppressRelease(); - RETURN pUnk; + GUID iid; + pMT->GetGuid(&iid, /*bGenerateIfNotFound*/ FALSE, /*bClassic*/ FALSE); + + IUnknown* pvObj; + hr = SafeQueryInterface(pUnk, iid, &pvObj); + if (FAILED(hr)) + COMPlusThrowHR(hr); + + RETURN pvObj; } SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -177,15 +184,20 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType { hr = S_OK; - SafeComHolder pvObj; + IUnknown* pvObj; if (ReqIpType & ComIpType_Dispatch) { - hr = pUnk->QueryInterface(IID_IDispatch, &pvObj); + hr = SafeQueryInterface(pUnk, IID_IDispatch, &pvObj); + pUnk->Release(); } else if (ReqIpType & ComIpType_Inspectable) { - SafeComHolder pvObj; - hr = pUnk->QueryInterface(IID_IInspectable, &pvObj); + hr = SafeQueryInterface(pUnk, IID_IInspectable, &pvObj); + pUnk->Release(); + } + else + { + pvObj = pUnk; } if (FAILED(hr)) @@ -194,7 +206,7 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType if (pFetchedIpType != NULL) *pFetchedIpType = ReqIpType; - RETURN pUnk; + RETURN pvObj; } MethodTable *pMT = (*poref)->GetMethodTable(); @@ -464,12 +476,13 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, REFIID iid, bool throwIfNoComI if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) { - SafeComHolder pvObj; - hr = pUnk->QueryInterface(iid, &pvObj); + IUnknown* pvObj; + hr = SafeQueryInterface(pUnk, iid, &pvObj); + pUnk->Release(); if (FAILED(hr)) COMPlusThrowHR(hr); - RETURN pUnk; + RETURN pvObj; } MethodTable *pMT = (*poref)->GetMethodTable(); diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs index 2f322018f0f80..0e75049200c9b 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs @@ -11,6 +11,7 @@ namespace ComWrappersTests.Common // Managed object with native wrapper definition. // [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] interface ITest { void SetValue(int i); diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs index ef9870b08672d..b362bc641b037 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs @@ -48,9 +48,9 @@ extern public static int UpdateTestObjectAsIInspectable( [DllImport(nameof(MockReferenceTrackerRuntime))] extern public static int UpdateTestObjectAsInterface( - [MarshalAs(UnmanagedType.Interface)] Test testObj, + [MarshalAs(UnmanagedType.Interface)] ITest testObj, int i, - [Out, MarshalAs(UnmanagedType.Interface)] out Test ret); + [Out, MarshalAs(UnmanagedType.Interface)] out ITest ret); } private const string ManagedServerTypeName = "ConsumeNETServerTesting"; @@ -298,6 +298,11 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered) IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj); Assert.AreNotEqual(IntPtr.Zero, dispatchWrapper); Assert.AreEqual(dispatchObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + + Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject != Marshal.GetIUnknownForObject..."); + IntPtr unknownWrapper = Marshal.GetIUnknownForObject(dispatchObj); + Assert.AreNotEqual(IntPtr.Zero, unknownWrapper); + Assert.AreNotEqual(unknownWrapper, dispatchWrapper); } Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown..."); @@ -326,33 +331,33 @@ private static void ValidatePInvokes(bool validateUseRegistered) GlobalComWrappers.Instance.ReturnInvalid = !validateUseRegistered; Console.WriteLine($" -- Validate MarshalAs IUnknown..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIUnknown, validateUseRegistered); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIUnknown, shouldSucceed: validateUseRegistered); object obj = MarshalInterface.CreateTrackerObjectAsIUnknown(); Assert.AreEqual(validateUseRegistered, obj is FakeWrapper, $"Should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); if (validateUseRegistered) { Console.WriteLine($" -- Validate MarshalAs IDispatch..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIDispatch, validateUseRegistered, new TestEx(IID_IDISPATCH)); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIDispatch, shouldSucceed: true, new TestEx(IID_IDISPATCH)); Console.WriteLine($" -- Validate MarshalAs IInspectable..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIInspectable, validateUseRegistered, new TestEx(IID_IINSPECTABLE)); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIInspectable, shouldSucceed: true, new TestEx(IID_IINSPECTABLE)); } Console.WriteLine($" -- Validate MarshalAs Interface..."); - ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsInterface, validateUseRegistered); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsInterface, shouldSucceed: true); if (validateUseRegistered) { Assert.Throws(() => MarshalInterface.CreateTrackerObjectWrongType()); FakeWrapper wrapper = MarshalInterface.CreateTrackerObjectAsInterface(); - Assert.IsNotNull(obj, $"Should have returned {nameof(FakeWrapper)} instance"); + Assert.IsNotNull(wrapper, $"Should have returned {nameof(FakeWrapper)} instance"); } } private delegate int UpdateTestObject(T testObj, int i, out T ret) where T : class; - private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool validateUseRegistered, Test testObj = null) where T : class + private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool shouldSucceed, Test testObj = null) where T : class { const int E_NOINTERFACE = unchecked((int)0x80004002); int value = 10; @@ -363,7 +368,7 @@ private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool T retObj; int hr = func(testObj as T, value, out retObj); Assert.AreEqual(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); - if (validateUseRegistered) + if (shouldSucceed) { Assert.IsTrue(retObj is Test); Assert.AreEqual(value, testObj.GetValue()); diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp index 867efa01866e4..54cf0bb31f14b 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp @@ -374,5 +374,10 @@ extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsInterface(ITest *o if (obj == nullptr) return E_POINTER; - return UpdateTestObjectAsIUnknown(obj, i, (IUnknown**)out); + HRESULT hr; + RETURN_IF_FAILED(obj->SetValue(i)); + + obj->AddRef(); + *out = obj; + return S_OK; }