diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs index 71946fb85dc411..5ad028deda1182 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs @@ -948,33 +948,6 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( out IntPtr inner); using ComHolder releaseIdentity = new ComHolder(identity); - if (flags.HasFlag(CreateObjectFlags.Unwrap)) - { - ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(identity); - if (comInterfaceDispatch != null) - { - // If we found a managed object wrapper in this ComWrappers instance - // and it's has the same identity pointer as the one we're creating a NativeObjectWrapper for, - // unwrap it. We don't AddRef the wrapper as we don't take a reference to it. - // - // A managed object can have multiple managed object wrappers, with a max of one per context. - // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the - // managed object wrappers for A created with C1 and C2 respectively. - // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance, - // we will create a new wrapper. In this scenario, we'll only unwrap B2. - object unwrapped = ComInterfaceDispatch.GetInstance(comInterfaceDispatch); - if (_ccwTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext)) - { - // The unwrapped object has a CCW in this context. Compare with identity - // so we can see if it's the CCW for the unwrapped object in this context. - if (unwrappedWrapperInThisContext.ComIp == identity) - { - retValue = unwrapped; - return true; - } - } - } - } if (!flags.HasFlag(CreateObjectFlags.UniqueInstance)) { @@ -1018,6 +991,33 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( return true; } } + if (flags.HasFlag(CreateObjectFlags.Unwrap)) + { + ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(identity); + if (comInterfaceDispatch != null) + { + // If we found a managed object wrapper in this ComWrappers instance + // and it's has the same identity pointer as the one we're creating a NativeObjectWrapper for, + // unwrap it. We don't AddRef the wrapper as we don't take a reference to it. + // + // A managed object can have multiple managed object wrappers, with a max of one per context. + // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the + // managed object wrappers for A created with C1 and C2 respectively. + // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance, + // we will create a new wrapper. In this scenario, we'll only unwrap B2. + object unwrapped = ComInterfaceDispatch.GetInstance(comInterfaceDispatch); + if (_ccwTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext)) + { + // The unwrapped object has a CCW in this context. Compare with identity + // so we can see if it's the CCW for the unwrapped object in this context. + if (unwrappedWrapperInThisContext.ComIp == identity) + { + retValue = unwrapped; + return true; + } + } + } + } } retValue = CreateObject(identity, flags); diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/UniqueComInterfaceMarshaller.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/UniqueComInterfaceMarshaller.cs index 3608cbd18301c4..59b965dd1d11f1 100644 --- a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/UniqueComInterfaceMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/UniqueComInterfaceMarshaller.cs @@ -55,7 +55,7 @@ public static unsafe class UniqueComInterfaceMarshaller { return default; } - return (T)StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateObjectForComInstance((nint)unmanaged, CreateObjectFlags.Unwrap | CreateObjectFlags.UniqueInstance); + return (T)StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateObjectForComInstance((nint)unmanaged, CreateObjectFlags.UniqueInstance); } diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs index 5f987f567fe3cd..05660383538e37 100644 --- a/src/tests/Interop/COM/ComWrappers/API/Program.cs +++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs @@ -9,6 +9,7 @@ namespace ComWrappersTests using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; using ComWrappersTests.Common; using TestLibrary; @@ -188,6 +189,11 @@ public void ValidateComInterfaceCreationRoundTrip() var testObjUnwrapped = wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap); Assert.Same(testObj, testObjUnwrapped); + // UniqueInstance and Unwrap should always be a new com object, never unwrapped + var testObjUniqueUnwrapped = (ITestObjectWrapper)wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap | CreateObjectFlags.UniqueInstance); + Assert.NotSame(testObj, testObjUniqueUnwrapped); + testObjUniqueUnwrapped.FinalRelease(); + // Release the wrapper int count = Marshal.Release(comWrapper); Assert.Equal(0, count); diff --git a/src/tests/Interop/COM/ComWrappers/Common.cs b/src/tests/Interop/COM/ComWrappers/Common.cs index fb99fb15bc33ad..ed62825070fd62 100644 --- a/src/tests/Interop/COM/ComWrappers/Common.cs +++ b/src/tests/Interop/COM/ComWrappers/Common.cs @@ -4,6 +4,7 @@ namespace ComWrappersTests.Common { using System; + using System.Diagnostics; using System.Threading; using System.Runtime.InteropServices; @@ -97,6 +98,7 @@ public class ITestObjectWrapper : ITest { private readonly ITestVtbl._SetValue _setValue; private readonly IntPtr _ptr; + private bool _released; public ITestObjectWrapper(IntPtr ptr) { @@ -104,11 +106,20 @@ public ITestObjectWrapper(IntPtr ptr) VtblPtr inst = Marshal.PtrToStructure(ptr); ITestVtbl _vtbl = Marshal.PtrToStructure(inst.Vtbl); _setValue = Marshal.GetDelegateForFunctionPointer(_vtbl.SetValue); + _released = false; + } + + public int FinalRelease() + { + Debug.Assert(!_released); + int count = Marshal.Release(_ptr); + _released = true; + return count; } ~ITestObjectWrapper() { - if (_ptr != IntPtr.Zero) + if (_ptr != IntPtr.Zero && !_released) { Marshal.Release(_ptr); }