Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Commit

Permalink
Add Per-assembly Load Native Library callbacks
Browse files Browse the repository at this point in the history
This Change implements the Native Library resolution
Call-backs proposed in https://github.com/dotnet/corefx/issues/32015

public static bool RegisterDllImportResolver(
    Assembly assembly,
    Func<string, Assembly, DllImportSearchPath, IntPtr> callback
);

This API is not yet approved, and the API contracts in CoreFX
will not be added until the API approval is complete.
In the meantime, we want to have the code reviewed, tested, and
avaiable in CoreCLR.
  • Loading branch information
swaroop-sridhar committed Dec 15, 2018
1 parent ca65764 commit ea09225
Show file tree
Hide file tree
Showing 19 changed files with 546 additions and 213 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public static partial class Marshal
internal static Guid IID_IUnknown = new Guid("00000000-0000-0000-C000-000000000046");
#endif //FEATURE_COMINTEROP

static Marshal()
{
nativeDllResolveMap = new ConditionalWeakTable<Assembly, Func<string, Assembly, DllImportSearchPath?, IntPtr>>();
}

private const int LMEM_FIXED = 0;
private const int LMEM_MOVEABLE = 2;
#if !FEATURE_PAL
Expand Down Expand Up @@ -1805,6 +1810,7 @@ public static bool TryLoadLibrary(string libraryPath, out IntPtr handle)
/// calling assembly (if any) are used.
/// This LoadLibrary() method does not invoke the managed call-backs for native library resolution:
/// * AssemblyLoadContext.LoadUnmanagedDll()
/// * The per-assembly registered callback
/// </summary>
/// <param name="libraryName">The name of the native library to be loaded</param>
/// <param name="assembly">The assembly loading the native library</param>
Expand Down Expand Up @@ -1907,13 +1913,87 @@ public static bool TryGetLibraryExport(IntPtr handle, string name, out IntPtr ad
return address != IntPtr.Zero;
}

/// <summary>
/// Map from assembly to native-library-resolution-callback.
/// Generally interop specific fields and properties are not added to assembly.
/// Therefore, this table uses weak assembly pointers to indirectly achieve
/// similar behavior.
/// </summary>
public static ConditionalWeakTable<Assembly, Func<string, Assembly, DllImportSearchPath?, IntPtr>> nativeDllResolveMap;

/// <summary>
/// Register a callback for resolving native library imports from an assembly
/// This per-assembly callback is the first attempt to resolve native library loads
/// initiated by this assembly.
///
/// Only one callback can be registered per assembly. Trying to register a second
/// callback fails (with the return value false).
///
/// The callback method itself takes the following parameters
/// - The name of the library to be loaded (string)
/// - The assembly initiating the native library load
/// - The DllImportSearchPath flags from the attributes on the PInvoke, if any (null otherwise).
/// This argument doesn't include the flags from any attributes on the assembly itself.
/// and returns
/// - The handle to the loaded native library (on success) or null (on failure)
/// The parameters on this callback are such that they can be directly passed to
/// Marhall.LoadLibrary(libraryName, assembly, dllImportSearchPath) to approximately achieve
/// the default load behavior.
///
/// </summary>
/// <param name="assembly">The assembly for which the callback is registered</param>
/// <param name="callBack">The callback to register</param>
/// <exception cref="System.ArgumentNullException">If assembly or callback is null</exception>
/// <returns>True on success, false otherwise</returns>
public static bool RegisterDllImportResolver(Assembly assembly, Func<string, Assembly, DllImportSearchPath?, IntPtr> callBack)
{
if (assembly == null)
throw new ArgumentNullException(nameof(assembly));
if (callBack == null)
throw new ArgumentNullException(nameof(callBack));
if (!(assembly is RuntimeAssembly))
throw new ArgumentException(SR.Argument_MustBeRuntimeAssembly);

Func<string, Assembly, DllImportSearchPath?, IntPtr> existingCallback;
if (nativeDllResolveMap.TryGetValue(assembly, out existingCallback))
{
return false;
}

nativeDllResolveMap.Add(assembly, callBack);
return true;
}

/// <summary>
/// The helper function that calls the per-assembly native-library resolver
/// if one is registered for this assembly.
/// </summary>
/// <param name="libraryName">The native library to load</param>
/// <param name="assembly">The assembly trying load the native library</param>
/// <param name="hasDllImportSearchPathFlags">If the pInvoke has DefaultDllImportSearchPathAttribute</param>
/// <param name="dllImportSearchPathFlags">If hasdllImportSearchPathFlags is true, the flags in
/// DefaultDllImportSearchPathAttribute; meaningless otherwise </param>
/// <returns>The handle for the loaded library on success. Null on failure.</returns>
internal static IntPtr LoadLibraryCallbackStub(string libraryName, Assembly assembly,
bool hasDllImportSearchPathFlags, uint dllImportSearchPathFlags)
{
Func<string, Assembly, DllImportSearchPath?, IntPtr> callBack;

if (!nativeDllResolveMap.TryGetValue(assembly, out callBack))
{
return IntPtr.Zero;
}

return callBack(libraryName, assembly, hasDllImportSearchPathFlags? (DllImportSearchPath?)dllImportSearchPathFlags : null);
}

/// External functions that implement the NativeLibrary interface

[DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)]
internal static extern IntPtr LoadLibraryFromPath(string libraryName, bool throwOnError);
[DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)]
internal static extern IntPtr LoadLibraryByName(string libraryName, RuntimeAssembly callingAssembly,
bool hasDllImportSearchPathFlag, uint dllImportSearchPathFlag,
bool hasDllImportSearchPathFlags, uint dllImportSearchPathFlags,
bool throwOnError);
[DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)]
internal static extern void FreeNativeLibrary(IntPtr handle);
Expand Down
1 change: 1 addition & 0 deletions src/vm/callhelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ enum DispatchCallSimpleFlags
#define STRINGREF_TO_ARGHOLDER(x) (LPVOID)STRINGREFToObject(x)
#define PTR_TO_ARGHOLDER(x) (LPVOID)x
#define DWORD_TO_ARGHOLDER(x) (LPVOID)(SIZE_T)x
#define BOOL_TO_ARGHOLDER(x) DWORD_TO_ARGHOLDER(x)

#define INIT_VARIABLES(count) \
DWORD __numArgs = count; \
Expand Down
124 changes: 96 additions & 28 deletions src/vm/dllimport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6157,7 +6157,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryFromPath(LPCWSTR libraryPath, BOOL thr

// static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *callingAssembly,
BOOL hasDllImportSearchFlag, DWORD dllImportSearchFlag,
BOOL hasDllImportSearchFlags, DWORD dllImportSearchFlags,
BOOL throwOnError)
{
CONTRACTL
Expand All @@ -6170,15 +6170,15 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *

LoadLibErrorTracker errorTracker;

// First checks if a default DllImportSearchPathFlag was passed in, if so, use that value.
// First checks if a default dllImportSearchPathFlags was passed in, if so, use that value.
// Otherwise checks if the assembly has the DefaultDllImportSearchPathsAttribute attribute. If so, use that value.
BOOL searchAssemblyDirectory = TRUE;
DWORD dllImportSearchPathFlag = 0;
DWORD dllImportSearchPathFlags = 0;

if (hasDllImportSearchFlag)
if (hasDllImportSearchFlags)
{
dllImportSearchPathFlag = dllImportSearchFlag & ~DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
searchAssemblyDirectory = dllImportSearchFlag & DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
dllImportSearchPathFlags = dllImportSearchFlags & ~DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
searchAssemblyDirectory = dllImportSearchFlags & DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;

}
else
Expand All @@ -6187,13 +6187,13 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *

if (pModule->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlag = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
dllImportSearchPathFlags = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pModule->DllImportSearchAssemblyDirectory();
}
}

NATIVE_LIBRARY_HANDLE hmod =
LoadLibraryModuleBySearch(callingAssembly, searchAssemblyDirectory, dllImportSearchPathFlag, &errorTracker, libraryName);
LoadLibraryModuleBySearch(callingAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, &errorTracker, libraryName);

if (throwOnError && (hmod == nullptr))
{
Expand All @@ -6212,11 +6212,11 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(NDirectMethodDesc * pMD
// First checks if the method has DefaultDllImportSearchPathsAttribute. If so, use that value.
// Otherwise checks if the assembly has the attribute. If so, use that value.
BOOL searchAssemblyDirectory = TRUE;
DWORD dllImportSearchPathFlag = 0;
DWORD dllImportSearchPathFlags = 0;

if (pMD->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlag = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
dllImportSearchPathFlags = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pMD->DllImportSearchAssemblyDirectory();
}
else
Expand All @@ -6225,13 +6225,13 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(NDirectMethodDesc * pMD

if (pModule->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlag = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
dllImportSearchPathFlags = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pModule->DllImportSearchAssemblyDirectory();
}
}

Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();
return LoadLibraryModuleBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlag, pErrorTracker, wszLibName);
return LoadLibraryModuleBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, pErrorTracker, wszLibName);
}

// static
Expand Down Expand Up @@ -6280,23 +6280,32 @@ INT_PTR NDirect::GetNativeLibraryExport(NATIVE_LIBRARY_HANDLE handle, LPCWSTR sy
return address;
}

#ifndef PLATFORM_UNIX
BOOL IsWindowsAPI(PCWSTR wszLibName)
{
// This is replicating quick check from the OS implementation of api sets.
return SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 ||
SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0;
}
#endif // !PLATFORM_UNIX

// static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaHost(NDirectMethodDesc * pMD, AppDomain* pDomain, PCWSTR wszLibName)
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaHost(NDirectMethodDesc * pMD, PCWSTR wszLibName)
{
STANDARD_VM_CONTRACT;
//Dynamic Pinvoke Support:
//Check if we need to provide the host a chance to provide the unmanaged dll

#ifndef PLATFORM_UNIX
// Prevent Overriding of Windows API sets.
// This is replicating quick check from the OS implementation of api sets.
if (SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 || SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0)
if (IsWindowsAPI(wszLibName))
{
// Prevent Overriding of Windows API sets.
return NULL;
}
#endif
#endif // !PLATFORM_UNIX

LPVOID hmod = NULL;
AppDomain* pDomain = GetAppDomain();
CLRPrivBinderCoreCLR *pTPABinder = pDomain->GetTPABinderContext();
Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();

Expand Down Expand Up @@ -6362,6 +6371,55 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaHost(NDirectMethodDesc * pMD,
return (NATIVE_LIBRARY_HANDLE)hmod;
}

NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaCallBack(NDirectMethodDesc * pMD, LPCWSTR wszLibName)
{
#ifndef PLATFORM_UNIX
if (IsWindowsAPI(wszLibName))
{
// Prevent Overriding of Windows API sets.
return NULL;
}
#endif // !PLATFORM_UNIX

DWORD dllImportSearchPathFlags = 0;
BOOL hasDllImportSearchPathFlags = pMD->HasDefaultDllImportSearchPathsAttribute();
if (hasDllImportSearchPathFlags)
{
dllImportSearchPathFlags = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
if (pMD->DllImportSearchAssemblyDirectory())
dllImportSearchPathFlags |= DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
}

GCX_COOP();

struct {
STRINGREF libNameRef;
OBJECTREF assemblyRef;
} protect;


Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();
protect.libNameRef = StringObject::NewString(wszLibName);
protect.assemblyRef = pAssembly->GetExposedObject();

NATIVE_LIBRARY_HANDLE handle = NULL;

GCPROTECT_BEGIN(protect);

PREPARE_NONVIRTUAL_CALLSITE(METHOD__MARSHAL__LOADLIBRARYCALLBACKSTUB);
DECLARE_ARGHOLDER_ARRAY(args, 4);
args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(protect.libNameRef);
args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(protect.assemblyRef);
args[ARGNUM_2] = BOOL_TO_ARGHOLDER(hasDllImportSearchPathFlags);
args[ARGNUM_3] = DWORD_TO_ARGHOLDER(dllImportSearchPathFlags);

// Make the call
CALL_MANAGED_METHOD(handle, NATIVE_LIBRARY_HANDLE, args);
GCPROTECT_END();

return handle;
}

// Try to load the module alongside the assembly where the PInvoke was declared.
NATIVE_LIBRARY_HANDLE NDirect::LoadFromPInvokeAssemblyDirectory(Assembly *pAssembly, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker)
{
Expand All @@ -6385,11 +6443,12 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadFromPInvokeAssemblyDirectory(Assembly *pAssem
}

// Try to load the module from the native DLL search directories
NATIVE_LIBRARY_HANDLE NDirect::LoadFromNativeDllSearchDirectories(AppDomain* pDomain, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker)
NATIVE_LIBRARY_HANDLE NDirect::LoadFromNativeDllSearchDirectories(LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker)
{
STANDARD_VM_CONTRACT;

NATIVE_LIBRARY_HANDLE hmod = NULL;
AppDomain* pDomain = GetAppDomain();

if (pDomain->HasNativeDllSearchDirectories())
{
Expand Down Expand Up @@ -6511,7 +6570,7 @@ static void DetermineLibNameVariations(const WCHAR** libNameVariations, int* num
// Search for the library and variants of its name in probing directories.
//static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssembly,
BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlag,
BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlags,
LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName)
{
STANDARD_VM_CONTRACT;
Expand All @@ -6521,7 +6580,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
#if defined(FEATURE_CORESYSTEM) && !defined(PLATFORM_UNIX)
// Try to go straight to System32 for Windows API sets. This is replicating quick check from
// the OS implementation of api sets.
if (SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 || SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0)
if (IsWindowsAPI(wszLibName))
{
hmod = LocalLoadLibraryHelper(wszLibName, LOAD_LIBRARY_SEARCH_SYSTEM32, pErrorTracker);
if (hmod != NULL)
Expand Down Expand Up @@ -6549,7 +6608,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
currLibNameVariation.Printf(prefixSuffixCombinations[i], PLATFORM_SHARED_LIB_PREFIX_W, wszLibName, PLATFORM_SHARED_LIB_SUFFIX_W);

// NATIVE_DLL_SEARCH_DIRECTORIES set by host is considered well known path
hmod = LoadFromNativeDllSearchDirectories(pDomain, currLibNameVariation, loadWithAlteredPathFlags, pErrorTracker);
hmod = LoadFromNativeDllSearchDirectories(currLibNameVariation, loadWithAlteredPathFlags, pErrorTracker);
if (hmod != NULL)
{
return hmod;
Expand All @@ -6558,11 +6617,11 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
if (!libNameIsRelativePath)
{
DWORD flags = loadWithAlteredPathFlags;
if ((dllImportSearchPathFlag & LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) != 0)
if ((dllImportSearchPathFlags & LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) != 0)
{
// LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR is the only flag affecting absolute path. Don't OR the flags
// unconditionally as all absolute path P/Invokes could then lose LOAD_WITH_ALTERED_SEARCH_PATH.
flags |= dllImportSearchPathFlag;
flags |= dllImportSearchPathFlags;
}

hmod = LocalLoadLibraryHelper(currLibNameVariation, flags, pErrorTracker);
Expand All @@ -6573,14 +6632,14 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
}
else if ((callingAssembly != nullptr) && searchAssemblyDirectory)
{
hmod = LoadFromPInvokeAssemblyDirectory(callingAssembly, currLibNameVariation, loadWithAlteredPathFlags | dllImportSearchPathFlag, pErrorTracker);
hmod = LoadFromPInvokeAssemblyDirectory(callingAssembly, currLibNameVariation, loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker);
if (hmod != NULL)
{
return hmod;
}
}

hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlag, pErrorTracker);
hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, pErrorTracker);
if (hmod != NULL)
{
return hmod;
Expand Down Expand Up @@ -6610,7 +6669,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
Assembly *pAssembly = spec.LoadAssembly(FILE_LOADED);
Module *pModule = pAssembly->FindModuleByName(szLibName);

hmod = LocalLoadLibraryHelper(pModule->GetPath(), loadWithAlteredPathFlags | dllImportSearchPathFlag, pErrorTracker);
hmod = LocalLoadLibraryHelper(pModule->GetPath(), loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker);
}
}

Expand All @@ -6631,19 +6690,28 @@ HINSTANCE NDirect::LoadLibraryModule(NDirectMethodDesc * pMD, LoadLibErrorTracke
if ( !name || !*name )
return NULL;

ModuleHandleHolder hmod;

PREFIX_ASSUME( name != NULL );
MAKE_WIDEPTR_FROMUTF8( wszLibName, name );

ModuleHandleHolder hmod = LoadLibraryModuleViaCallBack(pMD, wszLibName);
if (hmod != NULL)
{
#ifdef FEATURE_PAL
// Register the system library handle with PAL and get a PAL library handle
hmod = PAL_RegisterLibraryDirect(hmod, wszLibName);
#endif // FEATURE_PAL
return hmod.Extract();
}

AppDomain* pDomain = GetAppDomain();

// AssemblyLoadContext is not supported in AppX mode and thus,
// we should not perform PInvoke resolution via it when operating in
// AppX mode.
if (!AppX::IsAppXProcess())
{
hmod = LoadLibraryModuleViaHost(pMD, pDomain, wszLibName);
hmod = LoadLibraryModuleViaHost(pMD, wszLibName);
if (hmod != NULL)
{
#ifdef FEATURE_PAL
Expand Down
Loading

0 comments on commit ea09225

Please sign in to comment.