core: Don't create an endpoint per thread
Remove the functionality to create an endpoint per thread.  Originally,
this was used to create a full domain/endpoint/cq set of objects for
each thread that created a communicator, but that got complicated once
the rdma transport started creating multiple endpoints per thread based
on the comm in use.  We ended up creating a domain per thread
(sometimes) and an endpoint per thread (always), which was messy but
worked.

With the switch from requesting FI_THREAD_SAFE to FI_THREAD_DOMAIN and
the concurrent switch to domain-level locking for plugin operations,
there really isn't much advantage to the endpoint per thread model,
so this patch removes all that logic.

Signed-off-by: Brian Barrett <bbarrett@amazon.com>
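
To illustrate the model the message describes, here is a minimal sketch of a domain that caches a single endpoint and hands the same one to every caller under the domain lock. It is not the plugin's code: the `sketch_*` names, the `create_calls` counter, the use of `calloc` in place of the real `create_endpoint` callback, and the reference-count increment on lookup are all assumptions made for illustration.

```c
#include <assert.h>
#include <pthread.h>
#include <stddef.h>
#include <stdlib.h>

/* Hypothetical, simplified stand-ins for the plugin's endpoint and domain. */
typedef struct sketch_ep {
	int ref_cnt;
} sketch_ep_t;

typedef struct sketch_domain {
	pthread_mutex_t lock;   /* plays the role of domain_lock */
	sketch_ep_t *endpoint;  /* single cached endpoint, NULL until first use */
	int create_calls;       /* counts creations; illustration only */
} sketch_domain_t;

static void sketch_domain_init(sketch_domain_t *domain)
{
	assert(domain != NULL);
	pthread_mutex_init(&domain->lock, NULL);
	domain->endpoint = NULL;
	domain->create_calls = 0;
}

/* Return the domain's endpoint, creating it on first use.  Every caller,
 * regardless of thread, receives the same endpoint. */
static int sketch_domain_get_ep(sketch_domain_t *domain, sketch_ep_t **ep_p)
{
	int ret = 0;
	sketch_ep_t *ep;

	pthread_mutex_lock(&domain->lock);

	ep = domain->endpoint;
	if (ep == NULL) {
		/* Stand-in for the real endpoint constructor. */
		ep = calloc(1, sizeof(*ep));
		if (ep == NULL) {
			ret = -1;
			goto unlock;
		}
		domain->endpoint = ep;
		domain->create_calls++;
	}

	/* Assumed here: each successful lookup takes one reference. */
	ep->ref_cnt++;
	*ep_p = ep;

unlock:
	pthread_mutex_unlock(&domain->lock);
	return ret;
}
```

Compared with a per-thread hash table, the lookup needs no thread id and no hash key; the trade-off is that all threads sharing a domain also share its endpoint, which the message argues is acceptable once locking is done at the domain level.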
bwbarrett committed Dec 6, 2024
1 parent 46adaec commit 34ddea2
Showing 2 changed files with 11 additions and 15 deletions.
10 changes: 6 additions & 4 deletions include/nccl_ofi.h
@@ -349,10 +349,12 @@ struct nccl_net_ofi_domain {
int (*create_endpoint)(nccl_net_ofi_domain_t *domain,
nccl_net_ofi_ep_t **ep);

-/* hash table of active endpoints. We reuse endpoints based
- * on the thread that calls get_ep().
- */
-nccl_net_ofi_ep_t *endpoint_table;
+/* endpoint used for (at a minimum) receiving connection
+messages. Send/Recv protocol uses this for all
+communication. The rdma protocol uses this for all tx
+requests and all connection-establishment requests, but may
+have additional endpoints for the rx side of rdma writes. */
+nccl_net_ofi_ep_t *endpoint;

/* thread id of the thread that called get_domain(). Used as
the hash key for the domain hash */
16 changes: 5 additions & 11 deletions src/nccl_ofi_net.c
@@ -869,14 +869,11 @@ static int nccl_net_ofi_domain_get_ep(nccl_net_ofi_domain_t *domain,
nccl_net_ofi_ep_t **ep_p)
{
int ret = 0;
-long thread_id;
nccl_net_ofi_ep_t *ep = NULL;

nccl_net_ofi_mutex_lock(&domain->domain_lock);

-thread_id = nccl_net_ofi_gettid();
-HASH_FIND(hh, domain->endpoint_table, &thread_id,
-sizeof(ep->creating_thread_id), ep);
+ep = domain->endpoint;

if (ep == NULL) {
ret = domain->create_endpoint(domain, &ep);
@@ -886,10 +883,7 @@ static int nccl_net_ofi_domain_get_ep(nccl_net_ofi_domain_t *domain,
goto unlock;
}

-ep->creating_thread_id = thread_id;
-
-HASH_ADD(hh, domain->endpoint_table, creating_thread_id,
-sizeof(ep->creating_thread_id), ep);
+domain->endpoint = ep;

NCCL_OFI_TRACE(NCCL_NET, "Eendpoint %p for domain %p is created",
ep, domain);
@@ -915,7 +909,7 @@ static int nccl_net_ofi_domain_release(nccl_net_ofi_domain_t *domain)

nccl_net_ofi_mutex_lock(&domain->domain_lock);

-if (HASH_COUNT(domain->endpoint_table) == 0) {
+if (domain->endpoint == NULL) {
nccl_net_ofi_mutex_lock(&device->device_lock);
HASH_DEL(device->domain_table, domain);

@@ -954,7 +948,7 @@ int nccl_net_ofi_domain_init(nccl_net_ofi_device_t *device, nccl_net_ofi_domain_

domain->get_ep = nccl_net_ofi_domain_get_ep;
domain->release = nccl_net_ofi_domain_release;
-domain->endpoint_table = NULL;
+domain->endpoint = NULL;
domain->creating_thread_id = 0;

domain->mr_cache = NULL;
@@ -1022,7 +1016,7 @@ int nccl_net_ofi_endpoint_release(nccl_net_ofi_ep_t *ep)
ep->ref_cnt--;

if (ep->ref_cnt == 0) {
-HASH_DEL(domain->endpoint_table, ep);
+domain->endpoint = NULL;

ret = ep->free_ep(ep);
if (ret != 0) {
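
The release-side hunks above interact: dropping the last endpoint reference clears the domain's endpoint pointer, and only then does the domain's own release check (`domain->endpoint == NULL`) pass. A minimal sketch of that lifecycle follows. The `demo_*` names are hypothetical, locking is omitted for brevity (the real functions take the domain lock), and `free` stands in for the endpoint's `free_ep` callback.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>

/* Hypothetical, simplified stand-ins; not the plugin's real types. */
typedef struct demo_domain demo_domain_t;

typedef struct demo_ep {
	demo_domain_t *domain;  /* owning domain */
	int ref_cnt;            /* outstanding references to this endpoint */
} demo_ep_t;

struct demo_domain {
	demo_ep_t *endpoint;    /* single cached endpoint, or NULL */
};

/* Drop one reference.  On the last reference, detach the endpoint from
 * its domain and free it.  Returns true if the endpoint was freed. */
static bool demo_endpoint_release(demo_ep_t *ep)
{
	assert(ep != NULL && ep->ref_cnt > 0);

	ep->ref_cnt--;
	if (ep->ref_cnt == 0) {
		ep->domain->endpoint = NULL;
		free(ep);
		return true;
	}
	return false;
}

/* A domain can be torn down only once its endpoint slot is empty. */
static bool demo_domain_can_release(const demo_domain_t *domain)
{
	return domain->endpoint == NULL;
}
```

With a single pointer, the "no endpoints left" test reduces to a NULL comparison, which is what replaces the `HASH_COUNT` check in the diff.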