Skip to content

Commit

Permalink
ComWrappers tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AaronRobinsonMSFT committed Dec 5, 2020
1 parent ed9d7c2 commit 7221c10
Show file tree
Hide file tree
Showing 4 changed files with 657 additions and 103 deletions.
126 changes: 99 additions & 27 deletions src/tests/Interop/COM/ComWrappers/API/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,34 @@ class TestComWrappers : ComWrappers
{
protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
Assert.IsTrue(obj is Test);

IntPtr fpQueryInteface = default;
IntPtr fpAddRef = default;
IntPtr fpRelease = default;
ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease);

var vtbl = new ITestVtbl()
ComInterfaceEntry* entryRaw = null;
count = 0;
if (obj is Test)
{
IUnknownImpl = new IUnknownVtbl()
var vtbl = new ITestVtbl()
{
QueryInterface = fpQueryInteface,
AddRef = fpAddRef,
Release = fpRelease
},
SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue)
};
var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl));
Marshal.StructureToPtr(vtbl, vtblRaw, false);

var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry));
entryRaw->IID = typeof(ITest).GUID;
entryRaw->Vtable = vtblRaw;
IUnknownImpl = new IUnknownVtbl()
{
QueryInterface = fpQueryInteface,
AddRef = fpAddRef,
Release = fpRelease
},
SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue)
};
var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl));
Marshal.StructureToPtr(vtbl, vtblRaw, false);

entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry));
entryRaw->IID = typeof(ITest).GUID;
entryRaw->Vtable = vtblRaw;
count = 1;
}

count = 1;
return entryRaw;
}

Expand Down Expand Up @@ -75,6 +78,19 @@ public static void ValidateIUnknownImpls()
}
}

static void ForceGC()
{
// Trigger the GC multiple times and then
// wait for all finalizers since that is where
// most of the cleanup occurs.
GC.Collect();
GC.Collect();
GC.Collect();
GC.Collect();
GC.Collect();
GC.WaitForPendingFinalizers();
}

static void ValidateComInterfaceCreation()
{
Console.WriteLine($"Running {nameof(ValidateComInterfaceCreation)}...");
Expand Down Expand Up @@ -375,11 +391,7 @@ static void ValidateRuntimeTrackerScenario()

Assert.IsTrue(testWrapperIds.Count <= Test.InstanceCount);

GC.Collect();
GC.Collect();
GC.Collect();
GC.Collect();
GC.Collect();
ForceGC();

Assert.IsTrue(testWrapperIds.Count <= Test.InstanceCount);

Expand All @@ -391,11 +403,69 @@ static void ValidateRuntimeTrackerScenario()

testWrapperIds.Clear();

GC.Collect();
GC.Collect();
GC.Collect();
GC.Collect();
GC.Collect();
ForceGC();
}

unsafe class Derived : ITrackerObjectWrapper
{
public Derived(ComWrappers cw, bool aggregateRefTracker)
: base(cw, aggregateRefTracker)
{ }

[MethodImpl(MethodImplOptions.NoInlining)]
public static WeakReference<Derived> AllocateAndUseBaseType(ComWrappers cw, bool aggregateRefTracker)
{
var derived = new Derived(cw, aggregateRefTracker);

// Use the base type
IntPtr testWrapper = cw.GetOrCreateComInterfaceForObject(new Test(), CreateComInterfaceFlags.TrackerSupport);
int id = derived.AddObjectRef(testWrapper);

// Tell the tracker runtime to release its hold on the base instance.
MockReferenceTrackerRuntime.ReleaseAllTrackerObjects();

// Validate the GC is tracking the entire Derived type.
ForceGC();

derived.DropObjectRef(id);

return new WeakReference<Derived>(derived);
}
}

static void ValidateAggregationWithComObject()
{
Console.WriteLine($"Running {nameof(ValidateAggregationWithComObject)}...");

using var allocTracker = MockReferenceTrackerRuntime.CountTrackerObjectAllocations();
var cw = new TestComWrappers();
WeakReference<Derived> weakRef = Derived.AllocateAndUseBaseType(cw, aggregateRefTracker: false);

ForceGC();

// Validate all instances were cleaned up
Assert.IsFalse(weakRef.TryGetTarget(out _));
Assert.AreEqual(0, allocTracker.GetCount());
}

static void ValidateAggregationWithReferenceTrackerObject()
{
Console.WriteLine($"Running {nameof(ValidateAggregationWithReferenceTrackerObject)}...");

using var allocTracker = MockReferenceTrackerRuntime.CountTrackerObjectAllocations();
var cw = new TestComWrappers();
WeakReference<Derived> weakRef = Derived.AllocateAndUseBaseType(cw, aggregateRefTracker: true);

ForceGC();

// Validate all instances were cleaned up.
Assert.IsFalse(weakRef.TryGetTarget(out _));

// Reference counter cleanup requires additional GCs since the Finalizer is used
// to clean up the Reference Tracker runtime references.
ForceGC();

Assert.AreEqual(0, allocTracker.GetCount());
}

static int Main(string[] doNotUse)
Expand All @@ -410,6 +480,8 @@ static int Main(string[] doNotUse)
ValidateIUnknownImpls();
ValidateBadComWrapperImpl();
ValidateRuntimeTrackerScenario();
ValidateAggregationWithComObject();
ValidateAggregationWithReferenceTrackerObject();
}
catch (Exception e)
{
Expand Down
Loading

0 comments on commit 7221c10

Please sign in to comment.