diff --git a/prov/util/src/util_mem_hooks.c b/prov/util/src/util_mem_hooks.c index 05598a259a0..16b0ebc17d6 100644 --- a/prov/util/src/util_mem_hooks.c +++ b/prov/util/src/util_mem_hooks.c @@ -67,7 +67,8 @@ struct ofi_dl_intercept { enum { OFI_INTERCEPT_DLOPEN, OFI_INTERCEPT_MMAP, - OFI_INTERCEPT_MUNMAP + OFI_INTERCEPT_MUNMAP, + OFI_INTERCEPT_MAX }; static void *ofi_intercept_dlopen(const char *filename, int flag); @@ -283,28 +284,21 @@ static void ofi_restore_intercepts(void) { struct ofi_intercept *intercept; - fastlock_acquire(&memhooks_monitor->lock); dlist_foreach_container(&memhooks.intercept_list, struct ofi_intercept, intercept, entry) { dl_iterate_phdr(ofi_restore_phdr_handler, intercept); } - fastlock_release(&memhooks_monitor->lock); } static int ofi_intercept_symbol(struct ofi_intercept *intercept, void **real_func) { int ret; - /* - * Take lock first to handle a possible race where dlopen() is called - * from another thread and we may end up not patching it. - */ FI_DBG(&core_prov, FI_LOG_MR, "intercepting symbol %s\n", intercept->symbol); - fastlock_acquire(&memhooks_monitor->lock); ret = dl_iterate_phdr(ofi_intercept_phdr_handler, intercept); if (ret) - goto unlock; + return ret; *real_func = dlsym(RTLD_DEFAULT, intercept->symbol); if (*real_func == intercept->our_func) { @@ -316,11 +310,10 @@ static int ofi_intercept_symbol(struct ofi_intercept *intercept, void **real_fun FI_DBG(&core_prov, FI_LOG_MR, "could not find symbol %s\n", intercept->symbol); ret = -FI_ENOMEM; - goto unlock; + return ret; } - dlist_insert_tail(&memhooks.intercept_list, &intercept->entry); -unlock: - fastlock_release(&memhooks_monitor->lock); + dlist_insert_tail(&intercept->entry, &memhooks.intercept_list); + return ret; } @@ -365,7 +358,7 @@ static void ofi_memhooks_unsubscribe(struct ofi_mem_monitor *monitor, int ofi_memhooks_init(void) { - int ret; + int i, ret; /* TODO: remove once cleanup is written */ if (memhooks_monitor->subscribe == ofi_memhooks_subscribe) @@ -375,6 +368,9 @@ int ofi_memhooks_init(void) memhooks_monitor->unsubscribe = ofi_memhooks_unsubscribe; dlist_init(&memhooks.intercept_list); + for (i = 0; i < OFI_INTERCEPT_MAX; ++i) + dlist_init(&intercepts[i].dl_intercept_list); + ret = ofi_intercept_symbol(&intercepts[OFI_INTERCEPT_DLOPEN], (void **) &real_calls.dlopen); if (ret) { diff --git a/prov/util/src/util_mem_monitor.c b/prov/util/src/util_mem_monitor.c index 44e3277fe9d..3e7cc31d200 100644 --- a/prov/util/src/util_mem_monitor.c +++ b/prov/util/src/util_mem_monitor.c @@ -34,7 +34,6 @@ #include - static struct ofi_uffd uffd; struct ofi_mem_monitor *uffd_monitor = &uffd.monitor; @@ -53,9 +52,9 @@ void ofi_monitor_init(void) dlist_init(&memhooks_monitor->list); #if HAVE_UFFD_UNMAP -struct ofi_mem_monitor *default_monitor = uffd_monitor; +default_monitor = uffd_monitor; #else -struct ofi_mem_monitor *default_monitor = memhooks_monitor; +default_monitor = memhooks_monitor; #endif fi_param_define(NULL, "mr_cache_max_size", FI_PARAM_SIZE_T, @@ -93,6 +92,14 @@ struct ofi_mem_monitor *default_monitor = memhooks_monitor; if (!cache_params.max_size) cache_params.max_size = SIZE_MAX; + + if(cache_params.monitor != NULL) { + if (!strcmp(cache_params.monitor, "userfaultfd") && + default_monitor == uffd_monitor) /* check that userfaultfd supported at all */ + default_monitor = uffd_monitor; + else if (!strcmp(cache_params.monitor, "memhooks")) + default_monitor = memhooks_monitor; + } } void ofi_monitor_cleanup(void) @@ -140,8 +147,10 @@ void ofi_monitor_del_cache(struct ofi_mr_cache *cache) dlist_remove(&cache->notify_entry); if (dlist_empty(&monitor->list)) { - ofi_uffd_cleanup(); - ofi_memhooks_cleanup(); + if (monitor == uffd_monitor) + ofi_uffd_cleanup(); + else if (monitor == memhooks_monitor) + ofi_memhooks_cleanup(); } fastlock_release(&monitor->lock);