Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Unwrap flag from UniqueComInterfaceMarshaller #92599

Merged
merged 7 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -948,33 +948,6 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
out IntPtr inner);

using ComHolder releaseIdentity = new ComHolder(identity);
if (flags.HasFlag(CreateObjectFlags.Unwrap))
{
ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(identity);
if (comInterfaceDispatch != null)
{
// If we found a managed object wrapper in this ComWrappers instance
// and it's has the same identity pointer as the one we're creating a NativeObjectWrapper for,
// unwrap it. We don't AddRef the wrapper as we don't take a reference to it.
//
// A managed object can have multiple managed object wrappers, with a max of one per context.
// Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the
// managed object wrappers for A created with C1 and C2 respectively.
// If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance,
// we will create a new wrapper. In this scenario, we'll only unwrap B2.
object unwrapped = ComInterfaceDispatch.GetInstance<object>(comInterfaceDispatch);
if (_ccwTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext))
{
// The unwrapped object has a CCW in this context. Compare with identity
// so we can see if it's the CCW for the unwrapped object in this context.
if (unwrappedWrapperInThisContext.ComIp == identity)
{
retValue = unwrapped;
return true;
}
}
}
}

if (!flags.HasFlag(CreateObjectFlags.UniqueInstance))
{
Expand Down Expand Up @@ -1018,6 +991,33 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
return true;
}
}
if (flags.HasFlag(CreateObjectFlags.Unwrap))
{
ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(identity);
if (comInterfaceDispatch != null)
{
// If we found a managed object wrapper in this ComWrappers instance
// and it's has the same identity pointer as the one we're creating a NativeObjectWrapper for,
// unwrap it. We don't AddRef the wrapper as we don't take a reference to it.
//
// A managed object can have multiple managed object wrappers, with a max of one per context.
// Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the
// managed object wrappers for A created with C1 and C2 respectively.
// If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance,
// we will create a new wrapper. In this scenario, we'll only unwrap B2.
object unwrapped = ComInterfaceDispatch.GetInstance<object>(comInterfaceDispatch);
if (_ccwTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext))
{
// The unwrapped object has a CCW in this context. Compare with identity
// so we can see if it's the CCW for the unwrapped object in this context.
if (unwrappedWrapperInThisContext.ComIp == identity)
{
retValue = unwrapped;
return true;
}
}
}
}
}

retValue = CreateObject(identity, flags);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static unsafe class UniqueComInterfaceMarshaller<T>
{
return default;
}
return (T)StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateObjectForComInstance((nint)unmanaged, CreateObjectFlags.Unwrap | CreateObjectFlags.UniqueInstance);
return (T)StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateObjectForComInstance((nint)unmanaged, CreateObjectFlags.UniqueInstance);
}


Expand Down
6 changes: 6 additions & 0 deletions src/tests/Interop/COM/ComWrappers/API/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace ComWrappersTests
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

using ComWrappersTests.Common;
using TestLibrary;
Expand Down Expand Up @@ -188,6 +189,11 @@ public void ValidateComInterfaceCreationRoundTrip()
var testObjUnwrapped = wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap);
Assert.Same(testObj, testObjUnwrapped);

// UniqueInstance and Unwrap should always be a new com object, never unwrapped
var testObjUniqueUnwrapped = (ITestObjectWrapper)wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap | CreateObjectFlags.UniqueInstance);
Assert.NotSame(testObj, testObjUniqueUnwrapped);
testObjUniqueUnwrapped.FinalRelease();

// Release the wrapper
int count = Marshal.Release(comWrapper);
Assert.Equal(0, count);
Expand Down
13 changes: 12 additions & 1 deletion src/tests/Interop/COM/ComWrappers/Common.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
namespace ComWrappersTests.Common
{
using System;
using System.Diagnostics;
using System.Threading;
using System.Runtime.InteropServices;

Expand Down Expand Up @@ -97,18 +98,28 @@ public class ITestObjectWrapper : ITest
{
private readonly ITestVtbl._SetValue _setValue;
private readonly IntPtr _ptr;
private bool _released;

public ITestObjectWrapper(IntPtr ptr)
{
_ptr = ptr;
VtblPtr inst = Marshal.PtrToStructure<VtblPtr>(ptr);
ITestVtbl _vtbl = Marshal.PtrToStructure<ITestVtbl>(inst.Vtbl);
_setValue = Marshal.GetDelegateForFunctionPointer<ITestVtbl._SetValue>(_vtbl.SetValue);
_released = false;
}

public int FinalRelease()
{
Debug.Assert(!_released);
int count = Marshal.Release(_ptr);
_released = true;
return count;
}

~ITestObjectWrapper()
{
if (_ptr != IntPtr.Zero)
if (_ptr != IntPtr.Zero && !_released)
{
Marshal.Release(_ptr);
}
Expand Down