Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Commit

Permalink
Add Per-assembly Load Native Library callbacks
Browse files Browse the repository at this point in the history
This Change implements the Native Library resolution
Call-backs proposed in https://github.com/dotnet/corefx/issues/32015
  • Loading branch information
swaroop-sridhar authored and swaroop-sridhar committed Jan 15, 2019
1 parent 60557a9 commit 82c8738
Show file tree
Hide file tree
Showing 13 changed files with 319 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@

namespace System.Runtime.InteropServices
{

/// <summary>
/// A delegate used to resolve native libraries via callback.
/// </summary>
/// <param name="libraryName">The native library to resolve</param>
/// <param name="assembly">The assembly requesting the resolution</param>
/// <param name="DllImportSearchPath?">
/// The DllImportSearchPathsAttribute on the PInvoke, if any.
/// Otherwise, the DllImportSearchPathsAttribute on the assembly, if any.
/// Otherwise null.
/// </param>
/// <returns>The handle for the loaded native library on success, null on failure</returns>
public delegate IntPtr DllImportResolver(string libraryName,
Assembly assembly,
DllImportSearchPath? searchPath);

/// <summary>
/// APIs for managing Native Libraries
/// </summary>
Expand Down Expand Up @@ -58,7 +74,9 @@ public static bool TryLoad(string libraryPath, out IntPtr handle)
/// Otherwise, the flags specified by the DefaultDllImportSearchPaths attribute on the
/// calling assembly (if any) are used.
/// This LoadLibrary() method does not invoke the managed call-backs for native library resolution:
/// * The per-assembly registered callback
/// * AssemblyLoadContext.LoadUnmanagedDll()
/// * AssemblyLoadContext.ResolvingUnmanagedDllEvent
/// </summary>
/// <param name="libraryName">The name of the native library to be loaded</param>
/// <param name="assembly">The assembly loading the native library</param>
Expand Down Expand Up @@ -117,7 +135,6 @@ public static bool TryLoad(string libraryName, Assembly assembly, DllImportSearc
/// No action if the input handle is null.
/// </summary>
/// <param name="handle">The native library handle to be freed</param>
/// <exception cref="System.InvalidOperationException">If the operation fails</exception>
public static void Free(IntPtr handle)
{
FreeLib(handle);
Expand Down Expand Up @@ -161,6 +178,74 @@ public static bool TryGetExport(IntPtr handle, string name, out IntPtr address)
return address != IntPtr.Zero;
}

/// <summary>
/// Map from assembly to native-library resolver.
/// Interop specific fields and properties are generally not added to Assembly class.
/// Therefore, this table uses weak assembly pointers to indirectly achieve
/// similar behavior.
/// </summary>
public static ConditionalWeakTable<Assembly, DllImportResolver> s_nativeDllResolveMap = null;

/// <summary>
/// Set a callback for resolving native library imports from an assembly.
/// This per-assembly resolver is the first attempt to resolve native library loads
/// initiated by this assembly.
///
/// Only one resolver can be registered per assembly.
/// Trying to register a second resolver fails with InvalidOperationException.
/// </summary>
/// <param name="assembly">The assembly for which the resolver is registered</param>
/// <param name="resolver">The resolver callback to register</param>
/// <exception cref="System.ArgumentNullException">If assembly or resolver is null</exception>
/// <exception cref="System.ArgumentException">If a resolver is already set for this assembly</exception>
public static void SetDllImportResolver(Assembly assembly, DllImportResolver resolver)
{
if (assembly == null)
throw new ArgumentNullException(nameof(assembly));
if (resolver == null)
throw new ArgumentNullException(nameof(resolver));
if (!(assembly is RuntimeAssembly))
throw new ArgumentException(SR.Argument_MustBeRuntimeAssembly);

if (s_nativeDllResolveMap == null)
{
s_nativeDllResolveMap = new ConditionalWeakTable<Assembly, DllImportResolver>();
}

try
{
s_nativeDllResolveMap.Add(assembly, resolver);
}
catch (ArgumentException e)
{
// ConditionalWealTable throws ArgumentException if the Key already exists
throw new InvalidOperationException("Resolver is alerady Set for the Assembly");
}
}

/// <summary>
/// The helper function that calls the per-assembly native-library resolver
/// if one is registered for this assembly.
/// </summary>
/// <param name="libraryName">The native library to load</param>
/// <param name="assembly">The assembly trying load the native library</param>
/// <param name="hasDllImportSearchPathFlags">If the pInvoke has DefaultDllImportSearchPathAttribute</param>
/// <param name="dllImportSearchPathFlags">If hasdllImportSearchPathFlags is true, the flags in
/// DefaultDllImportSearchPathAttribute; meaningless otherwise </param>
/// <returns>The handle for the loaded library on success. Null on failure.</returns>
internal static IntPtr LoadLibraryCallbackStub(string libraryName, Assembly assembly,
bool hasDllImportSearchPathFlags, uint dllImportSearchPathFlags)
{
DllImportResolver resolver;

if (!s_nativeDllResolveMap.TryGetValue(assembly, out resolver))
{
return IntPtr.Zero;
}

return resolver(libraryName, assembly, hasDllImportSearchPathFlags ? (DllImportSearchPath?)dllImportSearchPathFlags : null);
}

/// External functions that implement the NativeLibrary interface

[DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)]
Expand Down
1 change: 1 addition & 0 deletions src/vm/callhelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ enum DispatchCallSimpleFlags
#define STRINGREF_TO_ARGHOLDER(x) (LPVOID)STRINGREFToObject(x)
#define PTR_TO_ARGHOLDER(x) (LPVOID)x
#define DWORD_TO_ARGHOLDER(x) (LPVOID)(SIZE_T)x
#define BOOL_TO_ARGHOLDER(x) DWORD_TO_ARGHOLDER(!!(x))

#define INIT_VARIABLES(count) \
DWORD __numArgs = count; \
Expand Down
197 changes: 72 additions & 125 deletions src/vm/dllimport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@
#include "clr/fs/path.h"
using namespace clr::fs;

// The Bit 0x2 has different semantics in DllImportSearchPath and LoadLibraryExA flags.
// In DllImportSearchPath enum, bit 0x2 represents SearchAssemblyDirectory -- which is performed by CLR.
// Unlike other bits in this enum, this bit shouldn't be directly passed on to LoadLibrary()
#define DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY 0x2

// remove when we get an updated SDK
#define LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR 0x00000100
#define LOAD_LIBRARY_SEARCH_DEFAULT_DIRS 0x00001000
Expand Down Expand Up @@ -6127,6 +6122,57 @@ bool NDirect::s_fSecureLoadLibrarySupported = false;
#define PLATFORM_SHARED_LIB_PREFIX_W W("")
#endif // !FEATURE_PAL

// The Bit 0x2 has different semantics in DllImportSearchPath and LoadLibraryExA flags.
// In DllImportSearchPath enum, bit 0x2 represents SearchAssemblyDirectory -- which is performed by CLR.
// Unlike other bits in this enum, this bit shouldn't be directly passed on to LoadLibrary()
#define DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY 0x2

// DllImportSearchPathFlags is a special enumeration, whose values are tied closely with LoadLibrary flags.
// There is no "default" value DllImportSearchPathFlags. In the absence of DllImportSearchPath attribute,
// CoreCLR's LoadLibrary implementation uses the following defaults.
// Other implementations of LoadLibrary callbacks/events are free to use other default conventions.
void GetDefaultDllImportSearchPathFlags(DWORD *dllImportSearchPathFlags, BOOL *searchAssemblyDirectory)
{
STANDARD_VM_CONTRACT;

*searchAssemblyDirectory = TRUE;
*dllImportSearchPathFlags = 0;
}

// If a module has the DllImportSearchPathAttribute, get DllImportSearchPathFlags from it, and return true.
// Otherwise, get the default value for the flags, and return false.
BOOL GetDllImportSearchPathFlags(Module *pModule, DWORD *dllImportSearchPathFlags, BOOL *searchAssemblyDirectory)
{
STANDARD_VM_CONTRACT;

if (pModule->HasDefaultDllImportSearchPathsAttribute())
{
*dllImportSearchPathFlags = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
*searchAssemblyDirectory = pModule->DllImportSearchAssemblyDirectory();
return true;
}

GetDefaultDllImportSearchPathFlags(dllImportSearchPathFlags, searchAssemblyDirectory);
return false;
}

// If a pInvoke has DllImportSearchPathAttribute, get DllImportSearchPathFlags from it, and returns true.
// Otherwise, if the containing assembly has the DllImportSearchPathAttribute, get DllImportSearchPathFlags from it, and returns true.
// Otherwise, return false (out parameters are untouched).
BOOL GetDllImportSearchPathFlags(NDirectMethodDesc * pMD, DWORD *dllImportSearchPathFlags, BOOL *searchAssemblyDirectory)
{
STANDARD_VM_CONTRACT;

if (pMD->HasDefaultDllImportSearchPathsAttribute())
{
*dllImportSearchPathFlags = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
*searchAssemblyDirectory = pMD->DllImportSearchAssemblyDirectory();
return true;
}

return GetDllImportSearchPathFlags(pMD->GetModule(), dllImportSearchPathFlags, searchAssemblyDirectory);
}

// static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryFromPath(LPCWSTR libraryPath, BOOL throwOnError)
{
Expand Down Expand Up @@ -6165,25 +6211,21 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *
LoadLibErrorTracker errorTracker;

// First checks if a default dllImportSearchPathFlags was passed in, if so, use that value.
// Otherwise checks if the assembly has the DefaultDllImportSearchPathsAttribute attribute. If so, use that value.
BOOL searchAssemblyDirectory = TRUE;
DWORD dllImportSearchPathFlags = 0;
// Otherwise checks if the assembly has the DefaultDllImportSearchPathsAttribute attribute.
// If so, use that value.
BOOL searchAssemblyDirectory;
DWORD dllImportSearchPathFlags;

if (hasDllImportSearchFlags)
{
dllImportSearchPathFlags = dllImportSearchFlags & ~DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
searchAssemblyDirectory = dllImportSearchFlags & DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;

}
else
else
{
Module * pModule = callingAssembly->GetManifestModule();

if (pModule->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlags = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pModule->DllImportSearchAssemblyDirectory();
}
GetDllImportSearchPathFlags(callingAssembly->GetManifestModule(),
&dllImportSearchPathFlags, &searchAssemblyDirectory);
}

NATIVE_LIBRARY_HANDLE hmod =
Expand All @@ -6203,26 +6245,10 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(NDirectMethodDesc * pMD
{
STANDARD_VM_CONTRACT;

// First checks if the method has DefaultDllImportSearchPathsAttribute. If so, use that value.
// Otherwise checks if the assembly has the attribute. If so, use that value.
BOOL searchAssemblyDirectory = TRUE;
DWORD dllImportSearchPathFlags = 0;
BOOL searchAssemblyDirectory;
DWORD dllImportSearchPathFlags;

if (pMD->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlags = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pMD->DllImportSearchAssemblyDirectory();
}
else
{
Module * pModule = pMD->GetModule();

if (pModule->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlags = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pModule->DllImportSearchAssemblyDirectory();
}
}
GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory);

Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();
return LoadLibraryModuleBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, pErrorTracker, wszLibName);
Expand Down Expand Up @@ -6452,18 +6478,20 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaCallback(NDirectMethodDesc *
{
STANDARD_VM_CONTRACT;

NATIVE_LIBRARY_HANDLE handle = NULL;

DWORD dllImportSearchPathFlags = 0;
BOOL hasDllImportSearchPathFlags = pMD->HasDefaultDllImportSearchPathsAttribute();
if (hasDllImportSearchPathFlags)
if (pMD->GetModule()->IsSystem())
{
dllImportSearchPathFlags = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
if (pMD->DllImportSearchAssemblyDirectory())
dllImportSearchPathFlags |= DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
// Don't attempt to callback on Corelib itself.
// The LoadLibrary callback stub is managed code that requires CoreLib
return NULL;
}

DWORD dllImportSearchPathFlags;
BOOL searchAssemblyDirectory;
BOOL hasDllImportSearchPathFlags = GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory);
dllImportSearchPathFlags |= searchAssemblyDirectory ? DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY : 0;

Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();
NATIVE_LIBRARY_HANDLE handle = NULL;

GCX_COOP();

Expand Down Expand Up @@ -6491,87 +6519,6 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaCallback(NDirectMethodDesc *
return handle;
}

// Return the AssemblyLoadContext for an assembly
INT_PTR GetManagedAssemblyLoadContext(Assembly* pAssembly)
{
STANDARD_VM_CONTRACT;

PTR_ICLRPrivBinder pBindingContext = pAssembly->GetManifestFile()->GetBindingContext();
if (pBindingContext == NULL)
{
// GetBindingContext() returns NULL for System.Private.CoreLib
return NULL;
}

UINT_PTR assemblyBinderID = 0;
IfFailThrow(pBindingContext->GetBinderID(&assemblyBinderID));

AppDomain *pDomain = GetAppDomain();
ICLRPrivBinder *pCurrentBinder = reinterpret_cast<ICLRPrivBinder *>(assemblyBinderID);

#ifdef FEATURE_COMINTEROP
if (AreSameBinderInstance(pCurrentBinder, pDomain->GetWinRtBinder()))
{
// No ALC associated handle with WinRT Binders.
return NULL;
}
#endif // FEATURE_COMINTEROP

// The code here deals with two implementations of ICLRPrivBinder interface:
// - CLRPrivBinderCoreCLR for the TPA binder in the default ALC, and
// - CLRPrivBinderAssemblyLoadContext for custom ALCs.
// in order obtain the associated ALC handle.
INT_PTR ptrManagedAssemblyLoadContext = AreSameBinderInstance(pCurrentBinder, pDomain->GetTPABinderContext())
? ((CLRPrivBinderCoreCLR *)pCurrentBinder)->GetManagedAssemblyLoadContext()
: ((CLRPrivBinderAssemblyLoadContext *)pCurrentBinder)->GetManagedAssemblyLoadContext();

return ptrManagedAssemblyLoadContext;
}

// static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaEvent(NDirectMethodDesc * pMD, PCWSTR wszLibName)
{
STANDARD_VM_CONTRACT;

NATIVE_LIBRARY_HANDLE hmod = NULL;
Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();
INT_PTR ptrManagedAssemblyLoadContext = GetManagedAssemblyLoadContext(pAssembly);

if (ptrManagedAssemblyLoadContext == NULL)
{
return NULL;
}

GCX_COOP();

struct {
STRINGREF DllName;
OBJECTREF AssemblyRef;
} gc = { NULL, NULL };

GCPROTECT_BEGIN(gc);

gc.DllName = StringObject::NewString(wszLibName);
gc.AssemblyRef = pAssembly->GetExposedObject();

// Prepare to invoke System.Runtime.Loader.AssemblyLoadContext.ResolveUnmanagedDllUsingEvent method
// While ResolveUnmanagedDllUsingEvent() could compute the AssemblyLoadContext using the AssemblyRef
// argument, it will involve another pInvoke to the runtime. So AssemblyLoadContext is passed in
// as an additional argument.
PREPARE_NONVIRTUAL_CALLSITE(METHOD__ASSEMBLYLOADCONTEXT__RESOLVEUNMANAGEDDLLUSINGEVENT);
DECLARE_ARGHOLDER_ARRAY(args, 3);
args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(gc.DllName);
args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(gc.AssemblyRef);
args[ARGNUM_2] = PTR_TO_ARGHOLDER(ptrManagedAssemblyLoadContext);

// Make the call
CALL_MANAGED_METHOD(hmod, NATIVE_LIBRARY_HANDLE, args);

GCPROTECT_END();

return hmod;
}

// Try to load the module alongside the assembly where the PInvoke was declared.
NATIVE_LIBRARY_HANDLE NDirect::LoadFromPInvokeAssemblyDirectory(Assembly *pAssembly, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker)
{
Expand Down Expand Up @@ -6842,7 +6789,7 @@ HINSTANCE NDirect::LoadLibraryModule(NDirectMethodDesc * pMD, LoadLibErrorTracke
if ( !name || !*name )
return NULL;

PREFIX_ASSUME( name != NULL );
PREFIX_ASSUME( name != NULL );
MAKE_WIDEPTR_FROMUTF8( wszLibName, name );

ModuleHandleHolder hmod = LoadLibraryModuleViaCallback(pMD, wszLibName);
Expand Down
Loading

0 comments on commit 82c8738

Please sign in to comment.