Skip to content
This repository was archived by the owner on Sep 30, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions ompi/mca/pml/ucx/pml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,44 @@ int mca_pml_ucx_cleanup(void)
return OMPI_SUCCESS;
}

/*
 * Lazily establish a UCX endpoint to rank 'dst' of 'comm'.
 *
 * Verifies that the peer selected the UCX PML, fetches the peer's UCP
 * worker address out-of-band, connects to it, and caches the resulting
 * endpoint on the proc structure so mca_pml_ucx_get_ep() finds it on
 * subsequent calls.
 *
 * Returns the new endpoint on success, or NULL on any failure.
 */
ucp_ep_h mca_pml_ucx_add_proc(ompi_communicator_t *comm, int dst)
{
    ucp_address_t *address;
    ucs_status_t status;
    size_t addrlen;
    ucp_ep_h ep;
    int ret;

    ompi_proc_t *proc0     = ompi_comm_peer_lookup(comm, 0);
    ompi_proc_t *proc_peer = ompi_comm_peer_lookup(comm, dst);

    /* Note, mca_pml_base_pml_check_selected, doesn't use 3rd argument */
    if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected("ucx",
                                                               &proc0,
                                                               dst))) {
        /* Log here for consistency with the other failure paths below,
         * which all report before returning NULL. */
        PML_UCX_ERROR("PML check failed for rank %d (not running ucx pml?)", dst);
        return NULL;
    }

    ret = mca_pml_ucx_recv_worker_address(proc_peer, &address, &addrlen);
    if (ret < 0) {
        PML_UCX_ERROR("Failed to receive worker address from proc: %d",
                      proc_peer->super.proc_name.vpid);
        return NULL;
    }

    PML_UCX_VERBOSE(2, "connecting to proc. %d", proc_peer->super.proc_name.vpid);
    status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
    /* The address buffer is only needed for ucp_ep_create(); release it
     * regardless of the connection outcome. */
    free(address);

    if (UCS_OK != status) {
        PML_UCX_ERROR("Failed to connect to proc: %d, %s",
                      proc_peer->super.proc_name.vpid,
                      ucs_status_string(status));
        return NULL;
    }

    /* Cache the endpoint so the next lookup for this peer is O(1). */
    proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;

    return ep;
}

int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
{
ucp_address_t *address;
Expand All @@ -225,6 +263,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
for (i = 0; i < nprocs; ++i) {
ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
if (ret < 0) {
PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
return ret;
}

Expand All @@ -238,7 +277,8 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
free(address);

if (UCS_OK != status) {
PML_UCX_ERROR("Failed to connect");
PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
ucs_status_string(status));
return OMPI_ERROR;
}

Expand Down Expand Up @@ -426,7 +466,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
struct ompi_request_t **request)
{
mca_pml_ucx_persistent_request_t *req;

ucp_ep_h ep;

req = (mca_pml_ucx_persistent_request_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.persistent_reqs);
if (req == NULL) {
Expand All @@ -436,14 +476,20 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
PML_UCX_TRACE_SEND("isend_init request *%p=%p", buf, count, datatype, dst,
tag, mode, comm, (void*)request, (void*)req)

ep = mca_pml_ucx_get_ep(comm, dst);
if (OPAL_UNLIKELY(NULL == ep)) {
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
return OMPI_ERROR;
}

req->ompi.req_state = OMPI_REQUEST_INACTIVE;
req->flags = MCA_PML_UCX_REQUEST_FLAG_SEND;
req->buffer = (void *)buf;
req->count = count;
req->datatype = mca_pml_ucx_get_datatype(datatype);
req->tag = PML_UCX_MAKE_SEND_TAG(tag, comm);
req->send.mode = mode;
req->send.ep = mca_pml_ucx_get_ep(comm, dst);
req->send.ep = ep;

*request = &req->ompi;
return OMPI_SUCCESS;
Expand All @@ -455,13 +501,20 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
struct ompi_request_t **request)
{
ompi_request_t *req;
ucp_ep_h ep;

PML_UCX_TRACE_SEND("isend request *%p", buf, count, datatype, dst, tag, mode,
comm, (void*)request)

/* TODO special care to sync/buffered send */

req = (ompi_request_t*)ucp_tag_send_nb(mca_pml_ucx_get_ep(comm, dst), buf, count,
ep = mca_pml_ucx_get_ep(comm, dst);
if (OPAL_UNLIKELY(NULL == ep)) {
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
return OMPI_ERROR;
}

req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
mca_pml_ucx_get_datatype(datatype),
PML_UCX_MAKE_SEND_TAG(tag, comm),
mca_pml_ucx_send_completion);
Expand All @@ -484,12 +537,19 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
struct ompi_communicator_t* comm)
{
ompi_request_t *req;
ucp_ep_h ep;

PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send");

/* TODO special care to sync/buffered send */

req = (ompi_request_t*)ucp_tag_send_nb(mca_pml_ucx_get_ep(comm, dst), buf, count,
ep = mca_pml_ucx_get_ep(comm, dst);
if (OPAL_UNLIKELY(NULL == ep)) {
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
return OMPI_ERROR;
}

req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
mca_pml_ucx_get_datatype(datatype),
PML_UCX_MAKE_SEND_TAG(tag, comm),
mca_pml_ucx_send_completion);
Expand Down
2 changes: 2 additions & 0 deletions ompi/mca/pml/ucx/pml_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ int mca_pml_ucx_close(void);
int mca_pml_ucx_init(void);
int mca_pml_ucx_cleanup(void);

ucp_ep_h mca_pml_ucx_add_proc(ompi_communicator_t *comm, int dst);
int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs);
int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs);

Expand Down Expand Up @@ -146,4 +147,5 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests);

int mca_pml_ucx_dump(struct ompi_communicator_t* comm, int verbose);


#endif /* PML_UCX_H_ */
7 changes: 6 additions & 1 deletion ompi/mca/pml/ucx/pml_ucx_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ void mca_pml_ucx_request_cleanup(void *request);

/* Look up the UCX endpoint for rank 'dst' in 'comm'.  The endpoint is
 * cached on the proc structure; on a cache miss the connection is
 * established lazily via mca_pml_ucx_add_proc().  Returns NULL if the
 * connection cannot be established. */
static inline ucp_ep_h mca_pml_ucx_get_ep(ompi_communicator_t *comm, int dst)
{
    ompi_proc_t *proc = ompi_comm_peer_lookup(comm, dst);
    ucp_ep_h ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];

    if (OPAL_UNLIKELY(NULL == ep)) {
        /* First communication with this peer: connect now. */
        return mca_pml_ucx_add_proc(comm, dst);
    }

    return ep;
}

static inline void mca_pml_ucx_request_reset(ompi_request_t *req)
Expand Down