Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for IMallocSpy #71106

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/coreclr/vm/olecontexthelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ HRESULT GetCurrentObjCtx(IUnknown **ppObjCtx)
CONTRACTL
{
NOTHROW;
GC_NOTRIGGER;
GC_TRIGGERS; // This can occur if IMallocSpy is implemented in managed code.
MODE_ANY;
PRECONDITION(CheckPointer(ppObjCtx));
#ifdef FEATURE_COMINTEROP
Expand All @@ -33,7 +33,7 @@ LPVOID SetupOleContext()
CONTRACT (LPVOID)
{
NOTHROW;
GC_NOTRIGGER;
GC_TRIGGERS;
MODE_ANY;
ENTRY_POINT;
POSTCONDITION(CheckPointer(RETVAL, NULL_OK));
Expand Down
37 changes: 0 additions & 37 deletions src/coreclr/vm/stubhelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,6 @@ FORCEINLINE static SOleTlsData *GetOrCreateOleTlsData()
return pOleTlsData;
}

FORCEINLINE static void *GetCOMIPFromRCW_GetTargetNoInterception(IUnknown *pUnk, ComPlusCallInfo *pComInfo)
{
LIMITED_METHOD_CONTRACT;

LPVOID *lpVtbl = *(LPVOID **)pUnk;
return lpVtbl[pComInfo->m_cachedComSlot];
}

FORCEINLINE static IUnknown *GetCOMIPFromRCW_GetIUnknownFromRCWCache(RCW *pRCW, MethodTable * pItfMT)
{
LIMITED_METHOD_CONTRACT;
Expand All @@ -250,35 +242,6 @@ FORCEINLINE static IUnknown *GetCOMIPFromRCW_GetIUnknownFromRCWCache(RCW *pRCW,
return NULL;
}

// Like GetCOMIPFromRCW_GetIUnknownFromRCWCache but also computes the target. This is a couple of instructions
// faster than GetCOMIPFromRCW_GetIUnknownFromRCWCache + GetCOMIPFromRCW_GetTargetNoInterception.
FORCEINLINE static IUnknown *GetCOMIPFromRCW_GetIUnknownFromRCWCache_NoInterception(RCW *pRCW, ComPlusCallInfo *pComInfo, void **ppTarget)
{
LIMITED_METHOD_CONTRACT;

// The code in this helper is the "fast path" that used to be generated directly
// to compiled ML stubs. The idea is to aim for an efficient RCW cache hit.
SOleTlsData *pOleTlsData = GetOrCreateOleTlsData();
MethodTable *pItfMT = pComInfo->m_pInterfaceMT;

// test for free-threaded after testing for context match to optimize for apartment-bound objects
if (pOleTlsData->pCurrentCtx == pRCW->GetWrapperCtxCookie() || pRCW->IsFreeThreaded())
{
for (int i = 0; i < INTERFACE_ENTRY_CACHE_SIZE; i++)
{
if (pRCW->m_aInterfaceEntries[i].m_pMT == pItfMT)
{
IUnknown *pUnk = pRCW->m_aInterfaceEntries[i].m_pUnknown;
_ASSERTE(pUnk != NULL);
*ppTarget = GetCOMIPFromRCW_GetTargetNoInterception(pUnk, pComInfo);
return pUnk;
}
}
}

return NULL;
}

FORCEINLINE static void *GetCOMIPFromRCW_GetTarget(IUnknown *pUnk, ComPlusCallInfo *pComInfo)
{
LIMITED_METHOD_CONTRACT;
Expand Down
82 changes: 82 additions & 0 deletions src/tests/Interop/COM/ExtensionPoints/ExtensionPoints.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;

using COM;

using TestLibrary;
using Xunit;

public class ExtensionPoints
{
unsafe class MallocSpy : IMallocSpy
{
private int _called = 0;
public int Called => _called;

public virtual nuint PreAlloc(nuint cbRequest)
{
_called++;
return cbRequest;
}

public virtual unsafe void* PostAlloc(void* pActual) => pActual;
public virtual unsafe void* PreFree(void* pRequest, [MarshalAs(UnmanagedType.Bool)] bool fSpyed)
{
_called++;
return pRequest;
}

public virtual void PostFree([MarshalAs(UnmanagedType.Bool)] bool fSpyed) { }
public virtual unsafe nuint PreRealloc(void* pRequest, nuint cbRequest, void** ppNewRequest, [MarshalAs(UnmanagedType.Bool)] bool fSpyed) => cbRequest;
public virtual unsafe void* PostRealloc(void* pActual, [MarshalAs(UnmanagedType.Bool)] bool fSpyed) => pActual;
public virtual unsafe void* PreGetSize(void* pRequest, [MarshalAs(UnmanagedType.Bool)] bool fSpyed) => pRequest;
public virtual nuint PostGetSize(nuint cbActual, [MarshalAs(UnmanagedType.Bool)] bool fSpyed) => cbActual;
public virtual unsafe void* PreDidAlloc(void* pRequest, [MarshalAs(UnmanagedType.Bool)] bool fSpyed) => pRequest;
public virtual unsafe int PostDidAlloc(void* pRequest, [MarshalAs(UnmanagedType.Bool)] bool fSpyed, int fActual) => fActual;
public virtual void PreHeapMinimize() { }
public virtual void PostHeapMinimize() { }
}

[Fact]
public static unsafe void Validate_Managed_IMallocSpy()
{
Console.WriteLine($"Running {nameof(Validate_Managed_IMallocSpy)}...");
var mallocSpy = new MallocSpy();
int result = Ole32.CoRegisterMallocSpy(mallocSpy);
Assert.Equal(0, result);
try
{
var arr = new [] { "", "", "", "", null };

// The goal of this test is to trigger paths in which CoTaskMemAlloc
// will be implicitly used and validate that the registered managed
// IMallocSpy can be called successful. The validation is for confirming
// the transition to Preemptive mode was performed.
//
// Casting the function pointer to one in which an IL stub will be
// used to marshal the string[].
var fptr = (delegate*unmanaged<string[], int>)(delegate*unmanaged<char**, int>)&ArrayLen;
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
int len = fptr(arr);
Assert.Equal(arr.Length - 1, len);

// Allocate 1 for the array, 1 for each non-null element, then double it for Free.
Assert.Equal((1 + (arr.Length - 1)) * 2, mallocSpy.Called);
}
finally
{
Ole32.CoRevokeMallocSpy();
}

[UnmanagedCallersOnly]
static int ArrayLen(char** ptr)
{
char** begin = ptr;
while (*ptr != null)
ptr++;
return (int)(ptr - begin);
}
}
}
10 changes: 10 additions & 0 deletions src/tests/Interop/COM/ExtensionPoints/ExtensionPoints.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RequiresProcessIsolation>true</RequiresProcessIsolation>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<Compile Include="ExtensionPoints.cs" />
<Compile Include="Interfaces.cs" />
</ItemGroup>
</Project>
78 changes: 78 additions & 0 deletions src/tests/Interop/COM/ExtensionPoints/Interfaces.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;

namespace COM
{
static class Ole32
{
[DllImport(nameof(Ole32), ExactSpelling = true)]
public static extern int CoRegisterMallocSpy(IMallocSpy mallocSpy);

[DllImport(nameof(Ole32), ExactSpelling = true)]
public static extern int CoRevokeMallocSpy();
}

[ComImport]
[Guid("0000001d-0000-0000-C000-000000000046")]
[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
public unsafe interface IMallocSpy
{
[PreserveSig]
nuint PreAlloc(
nuint cbRequest);

[PreserveSig]
void* PostAlloc(
void* pActual);

[PreserveSig]
void* PreFree(
void* pRequest,
[MarshalAs(UnmanagedType.Bool)] bool fSpyed);

[PreserveSig]
void PostFree(
[MarshalAs(UnmanagedType.Bool)] bool fSpyed);

[PreserveSig]
nuint PreRealloc(
void* pRequest,
nuint cbRequest,
void** ppNewRequest,
[MarshalAs(UnmanagedType.Bool)] bool fSpyed);

[PreserveSig]
void* PostRealloc(
void* pActual,
[MarshalAs(UnmanagedType.Bool)] bool fSpyed);

[PreserveSig]
void* PreGetSize(
void* pRequest,
[MarshalAs(UnmanagedType.Bool)] bool fSpyed);

[PreserveSig]
nuint PostGetSize(
nuint cbActual,
[MarshalAs(UnmanagedType.Bool)] bool fSpyed);

[PreserveSig]
void* PreDidAlloc(
void* pRequest,
[MarshalAs(UnmanagedType.Bool)] bool fSpyed);

[PreserveSig]
int PostDidAlloc(
void* pRequest,
[MarshalAs(UnmanagedType.Bool)] bool fSpyed,
int fActual);

[PreserveSig]
void PreHeapMinimize();

[PreserveSig]
void PostHeapMinimize();
}
}