From 8d8cd82c6bcf59cabc85f289b98ba4455ccdce6d Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Mon, 1 Jul 2024 12:16:29 -0600 Subject: [PATCH] use THIS_MODULE instead of the provider's name more context is available also incrementing provider module's refcount to prevent unloading when in use --- examples/providers/bsd/provider.c | 4 +-- examples/providers/gpl/provider.c | 4 +-- include/dpusm/provider.h | 9 ++--- include/dpusm/provider_api.h | 8 ++--- src/dpusm.c | 19 +++++------ src/provider.c | 57 ++++++++++++++++++++----------- src/user.c | 4 +-- 7 files changed, 62 insertions(+), 43 deletions(-) diff --git a/examples/providers/bsd/provider.c b/examples/providers/bsd/provider.c index dc9bedf..69ae0a8 100644 --- a/examples/providers/bsd/provider.c +++ b/examples/providers/bsd/provider.c @@ -6,7 +6,7 @@ static int __init dpusm_bsd_provider_init(void) { - const int rc = dpusm_register_bsd(module_name(THIS_MODULE), + const int rc = dpusm_register_bsd(THIS_MODULE, &example_dpusm_provider_functions); printk("%s init: %d\n", module_name(THIS_MODULE), rc); return rc; @@ -14,7 +14,7 @@ dpusm_bsd_provider_init(void) { static void __exit dpusm_bsd_provider_exit(void) { - dpusm_unregister_bsd(module_name(THIS_MODULE)); + dpusm_unregister_bsd(THIS_MODULE); printk("%s exit\n", module_name(THIS_MODULE)); } diff --git a/examples/providers/gpl/provider.c b/examples/providers/gpl/provider.c index 89b83c1..821eea2 100644 --- a/examples/providers/gpl/provider.c +++ b/examples/providers/gpl/provider.c @@ -6,7 +6,7 @@ static int __init dpusm_gpl_provider_init(void) { - const int rc = dpusm_register_gpl(module_name(THIS_MODULE), + const int rc = dpusm_register_gpl(THIS_MODULE, &example_dpusm_provider_functions); printk("%s init: %d\n", module_name(THIS_MODULE), rc); return rc; @@ -14,7 +14,7 @@ dpusm_gpl_provider_init(void) { static void __exit dpusm_gpl_provider_exit(void) { - dpusm_unregister_gpl(module_name(THIS_MODULE)); + dpusm_unregister_gpl(THIS_MODULE); printk("%s exit\n", module_name(THIS_MODULE)); } diff --git a/include/dpusm/provider.h b/include/dpusm/provider.h index db092f4..7bfd1c1 100644 --- a/include/dpusm/provider.h +++ b/include/dpusm/provider.h @@ -3,13 +3,14 @@ #include #include +#include #include #include /* single provider data */ typedef struct dpusm_provider_handle { - const char *name; /* reference to a string */ + struct module *module; dpusm_pc_t capabilities; /* constant set of capabilities */ const dpusm_pf_t *funcs; /* reference to a struct */ atomic_t refs; /* how many users are holding this provider */ @@ -22,14 +23,14 @@ typedef struct { size_t count; /* count of registered providers */ rwlock_t lock; atomic_t active; /* how many providers are active (may be larger than count) */ - /* this is not tied to the provider/count */ + /* this is an independent record of the count of all provder handles */ } dpusm_t; -int dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *funcs); +int dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t *funcs); /* can't prevent provider module from unloading */ int dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider); -int dpusm_provider_unregister(dpusm_t *dpusm, const char *name); +int dpusm_provider_unregister(dpusm_t *dpusm, struct module *module); dpusm_ph_t **dpusm_provider_get(dpusm_t *dpusm, const char *name); int dpusm_provider_put(dpusm_t *dpusm, void *handle); diff --git a/include/dpusm/provider_api.h b/include/dpusm/provider_api.h index 3ddb238..cb9c9ca 100644 --- a/include/dpusm/provider_api.h +++ b/include/dpusm/provider_api.h @@ -183,10 +183,10 @@ typedef struct dpusm_provider_functions { } dpusm_pf_t; /* returns -ERRNO instead of DPUSM_* */ -int dpusm_register_bsd(const char *name, const dpusm_pf_t *funcs); -int dpusm_unregister_bsd(const char *name); -int dpusm_register_gpl(const char *name, const dpusm_pf_t *funcs); -int dpusm_unregister_gpl(const char *name); +int dpusm_register_bsd(struct module *module, const dpusm_pf_t *funcs); +int dpusm_unregister_bsd(struct module *module); +int dpusm_register_gpl(struct module *module, const dpusm_pf_t *funcs); +int dpusm_unregister_gpl(struct module *module); /* * call when backing DPU goes down unexpectedly diff --git a/src/dpusm.c b/src/dpusm.c index df1e6bc..9d661e3 100644 --- a/src/dpusm.c +++ b/src/dpusm.c @@ -11,13 +11,13 @@ static dpusm_t dpusm; int -dpusm_register_bsd(const char *name, const dpusm_pf_t *funcs) { - return dpusm_provider_register(&dpusm, name, funcs); +dpusm_register_bsd(struct module *module, const dpusm_pf_t *funcs) { + return dpusm_provider_register(&dpusm, module, funcs); } int -dpusm_unregister_bsd(const char *name) { - return dpusm_provider_unregister(&dpusm, name); +dpusm_unregister_bsd(struct module *module) { + return dpusm_provider_unregister(&dpusm, module); } /* provider facing functions */ @@ -25,13 +25,13 @@ EXPORT_SYMBOL(dpusm_register_bsd); EXPORT_SYMBOL(dpusm_unregister_bsd); int -dpusm_register_gpl(const char *name, const dpusm_pf_t *funcs) { - return dpusm_provider_register(&dpusm, name, funcs); +dpusm_register_gpl(struct module *module, const dpusm_pf_t *funcs) { + return dpusm_provider_register(&dpusm, module, funcs); } int -dpusm_unregister_gpl(const char *name) { - return dpusm_provider_unregister(&dpusm, name); +dpusm_unregister_gpl(struct module *module) { + return dpusm_provider_unregister(&dpusm, module); } /* provider facing functions */ @@ -76,8 +76,7 @@ dpusm_init(void) { } static void __exit -dpusm_exit(void) -{ +dpusm_exit(void) { dpusm_provider_write_lock(&dpusm); const int active = atomic_read(&dpusm.active); diff --git a/src/provider.c b/src/provider.c index 15ce09a..9b740a5 100644 --- a/src/provider.c +++ b/src/provider.c @@ -84,7 +84,7 @@ find_provider(dpusm_t *dpusm, const char *name) { struct list_head *it = NULL; list_for_each(it, &dpusm->providers) { dpusm_ph_t *dpusmph = list_entry(it, dpusm_ph_t, list); - const char *p_name = dpusmph->name; + const char *p_name = module_name(dpusmph->module); const size_t p_name_len = strlen(p_name); if (name_len == p_name_len) { if (memcmp(name, p_name, p_name_len) == 0) { @@ -108,8 +108,9 @@ static void print_supported(const char *name, const char *func) } static dpusm_ph_t * -dpusmph_init(const char *name, const dpusm_pf_t *funcs) +dpusmph_init(struct module *module, const dpusm_pf_t *funcs) { + const char *name = module_name(module); dpusm_ph_t *dpusmph = dpusm_mem_alloc(sizeof(dpusm_ph_t)); if (dpusmph) { /* fill in capabilities bitmasks */ @@ -223,7 +224,7 @@ dpusmph_init(const char *name, const dpusm_pf_t *funcs) dpusmph->capabilities.io &= ~DPUSM_IO_DISK; } - dpusmph->name = name; + dpusmph->module = module; dpusmph->funcs = funcs; dpusmph->self = dpusmph; atomic_set(&dpusmph->refs, 0); @@ -234,7 +235,7 @@ dpusmph_init(const char *name, const dpusm_pf_t *funcs) /* add a new provider */ int -dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *funcs) { +dpusm_provider_register(dpusm_t *dpusm, struct module *module, const dpusm_pf_t *funcs) { const int rc = dpusm_provider_sane_at_load(funcs); if (rc != DPUSM_OK) { static const size_t max = @@ -258,7 +259,7 @@ dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *func printk("%s: DPUSM Provider \"%s\" does not provide " "a valid set of functions. Bad function groups: %s\n", - __func__, name, buf); + __func__, module_name(module), buf); dpusm_mem_free(buf, size); @@ -267,24 +268,33 @@ dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *func dpusm_provider_write_lock(dpusm); - dpusm_ph_t **found = find_provider(dpusm, name); + dpusm_ph_t **found = find_provider(dpusm, module_name(module)); if (found) { printk("%s: DPUSM Provider with the name \"%s\" (%p) already exists. %zu providers registered.\n", - __func__, name, *found, dpusm->count); + __func__, module_name(module), *found, dpusm->count); dpusm_provider_write_unlock(dpusm); return -EEXIST; } - dpusm_ph_t *provider = dpusmph_init(name, funcs); + dpusm_ph_t *provider = dpusmph_init(module, funcs); if (!provider) { dpusm_provider_write_unlock(dpusm); return -ECANCELED; } + if (!try_module_get(module)) { + printk("%s: DPUSM Provider Error: Could not increment reference count of %s\n", + __func__, module_name(module)); + dpusmph_destroy(provider); + return -ECANCELED; + } + list_add(&provider->list, &dpusm->providers); + dpusm->count++; + printk("%s: DPUSM Provider \"%s\" (%p) added. Now %zu providers registered.\n", - __func__, name, provider, dpusm->count); + __func__, module_name(module), provider, dpusm->count); dpusm_provider_write_unlock(dpusm); @@ -292,8 +302,6 @@ dpusm_provider_register(dpusm_t *dpusm, const char *name, const dpusm_pf_t *func } /* remove provider from list */ -/* can't prevent provider module from unloading */ -/* locking is done by caller */ int dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { if (!provider || !*provider) { @@ -305,7 +313,7 @@ dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { const int refs = atomic_read(&(*provider)->refs); if (refs) { printk("%s: Unregistering provider \"%s\" with %d references remaining.\n", - __func__, (*provider)->name, refs); + __func__, module_name((*provider)->module), refs); rc = -EBUSY; } @@ -321,19 +329,19 @@ dpusm_provider_unregister_handle(dpusm_t *dpusm, dpusm_ph_t **provider) { } int -dpusm_provider_unregister(dpusm_t *dpusm, const char *name) { +dpusm_provider_unregister(dpusm_t *dpusm, struct module *module) { dpusm_provider_write_lock(dpusm); - dpusm_ph_t **provider = find_provider(dpusm, name); + dpusm_ph_t **provider = find_provider(dpusm, module_name(module)); if (!provider) { - printk("%s: Could not find provider with name \"%s\"\n", __func__, name); + printk("%s: Could not find provider with name \"%s\"\n", __func__, module_name(module)); dpusm_provider_write_unlock(dpusm); return DPUSM_ERROR; } void *addr = *provider; const int rc = dpusm_provider_unregister_handle(dpusm, provider); - printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, name, addr, rc); + printk("%s: Unregistered \"%s\" (%p): %d\n", __func__, module_name(module), addr, rc); dpusm_provider_write_unlock(dpusm); return rc; @@ -350,10 +358,17 @@ dpusm_provider_get(dpusm_t *dpusm, const char *name) { read_lock(&dpusm->lock); dpusm_ph_t **provider = find_provider(dpusm, name); if (provider) { + /* make sure provider can't be unloaded before user */ + if (!try_module_get((*provider)->module)) { + printk("Error: Could not increment reference count of %s\n", name); + return NULL; + } + atomic_inc(&(*provider)->refs); atomic_inc(&dpusm->active); + printk("%s: User has been given a handle to \"%s\" (%p) (now %d users).\n", - __func__, (*provider)->name, *provider, atomic_read(&(*provider)->refs)); + __func__, name, *provider, atomic_read(&(*provider)->refs)); if ((*provider)->funcs->at_connect) { (*provider)->funcs->at_connect(); @@ -376,12 +391,15 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) { return DPUSM_ERROR; } + struct module *module = (*provider)->module; + if (!atomic_read(&(*provider)->refs)) { printk("%s Error: Cannot decrement provider \"%s\" user count already at 0.\n", - __func__, (*provider)->name); + __func__, module_name(module)); return DPUSM_ERROR; } + module_put(module); atomic_dec(&(*provider)->refs); atomic_dec(&dpusm->active); @@ -392,7 +410,7 @@ dpusm_provider_put(dpusm_t *dpusm, void *handle) { } printk("%s: User has returned a handle to \"%s\" (%p) (now %d users).\n", - __func__, (*provider)->name, *provider, atomic_read(&(*provider)->refs)); + __func__, module_name(module), *provider, atomic_read(&(*provider)->refs)); return DPUSM_OK; } @@ -414,6 +432,7 @@ void dpusm_provider_invalidate(dpusm_t *dpusm, const char *name) { memset(&(*provider)->capabilities, 0, sizeof((*provider)->capabilities)); printk("%s: Provider \"%s\" has been invalidated with %d users active.\n", __func__, name, atomic_read(&(*provider)->refs)); + /* not decrementing module reference count here - provider is still registered */ } else { printk("%s: Error: Did not find provider \"%s\"\n", diff --git a/src/user.c b/src/user.c index ca3b727..52b063a 100644 --- a/src/user.c +++ b/src/user.c @@ -78,7 +78,7 @@ dpusm_provider_sane(dpusm_ph_t **provider) { } if (!FUNCS(provider)) { - printk("Error: Invalidated provider: %s\n", (*provider)->name); + printk("Error: Invalidated provider: %s\n", module_name((*provider)->module)); return DPUSM_PROVIDER_INVALIDATED; } @@ -144,7 +144,7 @@ dpusm_get_provider(const char *name) { static const char * dpusm_get_provider_name(void *provider) { dpusm_ph_t **dpusmph = (dpusm_ph_t **) provider; - return (dpusmph && *dpusmph)?(*dpusmph)->name:NULL; + return (dpusmph && *dpusmph)?module_name((*dpusmph)->module):NULL; } static int