diff --git a/include/dpusm/provider.h b/include/dpusm/provider.h index 1e1c1fd..cd90443 100644 --- a/include/dpusm/provider.h +++ b/include/dpusm/provider.h @@ -4,7 +4,7 @@ #include <linux/atomic.h> #include <linux/list.h> #include <linux/module.h> -#include <linux/spinlock.h> +#include <linux/mutex.h> #include <dpusm/provider_api.h> @@ -20,7 +20,7 @@ typedef struct dpusm_provider_handle { typedef struct { struct list_head providers; /* list of providers */ size_t count; /* count of registered providers */ - rwlock_t lock; + struct mutex lock; atomic_t active; /* how many providers are active (may be larger than count) */ /* this is not tied to the provider/count */ } dpusm_t; @@ -34,9 +34,6 @@ int dpusm_provider_unregister(dpusm_t *dpusm, struct module *module); dpusm_ph_t **dpusm_provider_get(dpusm_t *dpusm, const char *name); int dpusm_provider_put(dpusm_t *dpusm, void *handle); -void dpusm_provider_write_lock(dpusm_t *dpusm); -void dpusm_provider_write_unlock(dpusm_t *dpusm); - /* * call when backing DPU goes down unexpectedly * diff --git a/src/dpusm.c b/src/dpusm.c index 9d661e3..712f6f7 100644 --- a/src/dpusm.c +++ b/src/dpusm.c @@ -66,7 +66,7 @@ static int __init dpusm_init(void) { INIT_LIST_HEAD(&dpusm.providers); dpusm.count = 0; - rwlock_init(&dpusm.lock); + mutex_init(&dpusm.lock); atomic_set(&dpusm.active, 0); dpusm_mem_init(); @@ -77,7 +77,7 @@ dpusm_init(void) { static void __exit dpusm_exit(void) { - dpusm_provider_write_lock(&dpusm); + mutex_lock(&dpusm.lock); const int active = atomic_read(&dpusm.active); if (unlikely(active)) { @@ -100,7 +100,7 @@ dpusm_exit(void) { dpusm_provider_unregister_handle(&dpusm, &provider->self); } - dpusm_provider_write_unlock(&dpusm); + mutex_unlock(&dpusm.lock); #if DPUSM_TRACK_ALLOCS size_t alloc_count = 0; diff --git a/src/provider.c b/src/provider.c index 19572a9..0b28c44 100644 --- a/src/provider.c +++ b/src/provider.c @@ -265,19 +265,19 @@ dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t return -EINVAL; } - dpusm_provider_write_lock(dpusm); + mutex_lock(&dpusm->lock);; dpusm_ph_t **found = find_provider(dpusm, module_name(module)); if (found) { printk("%s: DPUSM Provider with the name \"%s\" (%p) already exists. %zu providers registered.\n", __func__, module_name(module), *found, dpusm->count); - dpusm_provider_write_unlock(dpusm); + mutex_unlock(&dpusm->lock);; return -EEXIST; } dpusm_ph_t *provider = dpusmph_init(module, funcs); if (!provider) { - dpusm_provider_write_unlock(dpusm); + mutex_unlock(&dpusm->lock);; return -ECANCELED; } @@ -286,7 +286,7 @@ dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t printk("%s: DPUSM Provider \"%s\" (%p) added. Now %zu providers registered.\n", __func__, module_name(module), provider, dpusm->count); - dpusm_provider_write_unlock(dpusm); + mutex_unlock(&dpusm->lock);; return 0; } @@ -320,12 +320,12 @@ dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { int dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) { - dpusm_provider_write_lock(dpusm); + mutex_lock(&dpusm->lock);; dpusm_ph_t **provider = find_provider(dpusm, module_name(module)); if (!provider) { printk("%s: Could not find provider with name \"%s\"\n", __func__, module_name(module)); - dpusm_provider_write_unlock(dpusm); + mutex_unlock(&dpusm->lock);; return DPUSM_ERROR; } @@ -333,7 +333,7 @@ dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) { const int rc = dpusm_provider_unregister_handle(dpusm, provider); printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, module_name(module), addr, rc); - dpusm_provider_write_unlock(dpusm); + mutex_unlock(&dpusm->lock);; return rc; } @@ -345,7 +345,7 @@ dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) { /* get a provider by name */ dpusm_ph_t ** dpusm_provider_get(dpusm_t *dpusm, const char *name) { - read_lock(&dpusm->lock); + mutex_lock(&dpusm->lock); dpusm_ph_t **provider = find_provider(dpusm, name); if (provider) { struct module *module = (*provider)->module; @@ -404,18 +404,8 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) { return DPUSM_OK; } -void -dpusm_provider_write_lock(dpusm_t *dpusm) { - write_lock(&dpusm->lock); -} - -void -dpusm_provider_write_unlock(dpusm_t *dpusm) { - write_unlock(&dpusm->lock); -} - void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name) { - dpusm_provider_write_lock(dpusm); + mutex_lock(&dpusm->lock);; dpusm_ph_t **provider = find_provider(dpusm, name); if (provider && *provider) { (*provider)->funcs = NULL; @@ -428,5 +418,5 @@ void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name) { printk("%s: Error: Did not find provider \"%s\"\n", __func__, name); } - dpusm_provider_write_unlock(dpusm); + mutex_unlock(&dpusm->lock);; }