Skip to content

osc/ucx: Make wpctx global #8212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ompi/mca/osc/ucx/osc_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
typedef struct ompi_osc_ucx_component {
ompi_osc_base_component_t super;
opal_common_ucx_wpool_t *wpool;
opal_common_ucx_ctx_t *wpctx;
bool enable_mpi_threads;
opal_free_list_t requests; /* request free list for the r* communication variants */
bool env_initialized; /* UCX environment is initialized or not */
Expand Down
6 changes: 5 additions & 1 deletion ompi/mca/osc/ucx/osc_ucx_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ static inline int end_atomicity(
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fetch failed: %d", ret);
return OMPI_ERROR;
}
opal_atomic_wmb();

return ret;
}
Expand Down Expand Up @@ -1031,7 +1032,6 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
ompi_request_complete(&ucx_req->super, true);
}


return end_atomicity(module, target, lock_acquired, free_addr);
}

Expand Down Expand Up @@ -1096,6 +1096,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
return ret;
}

opal_atomic_wmb();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need a write memory barrier at the end of ompi_osc_ucx_rput?

*request = &ucx_req->super;

return ret;
Expand Down Expand Up @@ -1149,6 +1150,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
return ret;
}

opal_atomic_wmb();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: why is a memory barrier needed at the end of rget? (The operation is still pending, after all.)

*request = &ucx_req->super;

return ret;
Expand Down Expand Up @@ -1178,6 +1180,7 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
return ret;
}

opal_atomic_wmb();
*request = &ucx_req->super;

return ret;
Expand Down Expand Up @@ -1212,6 +1215,7 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
return ret;
}

opal_atomic_wmb();
*request = &ucx_req->super;

return ret;
Expand Down
75 changes: 52 additions & 23 deletions ompi/mca/osc/ucx/osc_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ ompi_osc_ucx_component_t mca_osc_ucx_component = {
.osc_finalize = component_finalize,
},
.wpool = NULL,
.wpctx = NULL,
.env_initialized = false,
.num_incomplete_req_ops = 0,
.num_modules = 0,
Expand Down Expand Up @@ -192,13 +193,15 @@ static int progress_callback(void) {
/**
 * Initialization hook of the osc/ucx component.
 *
 * Fills in the global mca_osc_ucx_component state: records the requested
 * MPI threading mode, stores the handles returned by the wpool and wpctx
 * allocators, and then calls opal_common_ucx_mca_register().
 *
 * @param enable_progress_threads  Not referenced by this function.
 * @param enable_mpi_threads       Copied into
 *                                 mca_osc_ucx_component.enable_mpi_threads.
 *
 * @return OMPI_SUCCESS unconditionally.
 *
 * NOTE(review): the results of both allocate calls are stored without being
 * checked here; whether they can return NULL is not visible from this
 * function -- confirm against the allocator implementations.
 */
static int component_init(bool enable_progress_threads, bool enable_mpi_threads) {
mca_osc_ucx_component.enable_mpi_threads = enable_mpi_threads;
/* Only the allocation happens here; the wpctx is passed to
 * opal_common_ucx_wpctx_init() elsewhere in this file. */
mca_osc_ucx_component.wpool = opal_common_ucx_wpool_allocate();
mca_osc_ucx_component.wpctx = opal_common_ucx_wpctx_allocate();
opal_common_ucx_mca_register();
return OMPI_SUCCESS;
}

static int component_finalize(void) {
opal_common_ucx_mca_deregister();
if (mca_osc_ucx_component.env_initialized) {
opal_common_ucx_wpctx_release(mca_osc_ucx_component.wpctx);
opal_common_ucx_wpool_finalize(mca_osc_ucx_component.wpool);
}
opal_common_ucx_wpool_free(mca_osc_ucx_component.wpool);
Expand All @@ -211,20 +214,42 @@ static int component_query(struct ompi_win_t *win, void **base, size_t size, int
return mca_osc_ucx_component.priority;
}

static unsigned int get_proc_vpid(void *metadata, int rank)
{
struct ompi_communicator_t *comm = (struct ompi_communicator_t *)metadata;
ompi_group_t *group = comm->c_local_group;
opal_process_name_t tmp;

/* find the processor of the destination */
ompi_proc_t *proc = ompi_group_get_proc_ptr(group, rank, true);

if( ompi_proc_is_sentinel(proc) ) {
tmp = ompi_proc_sentinel_to_name((uintptr_t)proc);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation is off here (both the function as a whole and this block).

} else {
tmp = proc->super.proc_name;
}

return tmp.vpid;
}

static int exchange_len_info(void *my_info, size_t my_info_len, char **recv_info_ptr,
int **disps_ptr, void *metadata)
int **disps_ptr, int **lens_ptr, void *metadata)
{
int ret = OMPI_SUCCESS;
struct ompi_communicator_t *comm = (struct ompi_communicator_t *)metadata;
int comm_size = ompi_comm_size(comm);
int lens[comm_size];
int total_len, i;
int *lens = calloc(comm_size, sizeof(*lens));

if (NULL != lens_ptr) {
*lens_ptr = lens;
}

ret = comm->c_coll->coll_allgather(&my_info_len, 1, MPI_INT,
lens, 1, MPI_INT, comm,
comm->c_coll->coll_allgather_module);
if (OMPI_SUCCESS != ret) {
return ret;
goto fini;
}

total_len = 0;
Expand All @@ -233,13 +258,14 @@ static int exchange_len_info(void *my_info, size_t my_info_len, char **recv_info
(*disps_ptr)[i] = total_len;
total_len += lens[i];
}

(*recv_info_ptr) = (char *)calloc(total_len, sizeof(char));
ret = comm->c_coll->coll_allgatherv(my_info, my_info_len, MPI_BYTE,
(void *)(*recv_info_ptr), lens, (*disps_ptr), MPI_BYTE,
comm, comm->c_coll->coll_allgatherv_module);
if (OMPI_SUCCESS != ret) {
return ret;

fini:
if (NULL == lens_ptr) {
free(lens);
}

return ret;
Expand Down Expand Up @@ -303,7 +329,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
char *name = NULL;
long values[2];
int ret = OMPI_SUCCESS;
//ucs_status_t status;
int i, comm_size = ompi_comm_size(comm);
bool env_initialized = false;
void *state_base = NULL;
Expand Down Expand Up @@ -348,6 +373,13 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
goto select_unlock;
}

ret = opal_common_ucx_wpctx_init(mca_osc_ucx_component.wpool,
mca_osc_ucx_component.wpctx,
&get_proc_vpid);
if (OMPI_SUCCESS != ret) {
goto select_unlock;
}

/* Make sure that all memory updates performed above are globally
* observable before (mca_osc_ucx_component.env_initialized = true)
*/
Expand Down Expand Up @@ -432,10 +464,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
}
}

ret = opal_common_ucx_wpctx_create(mca_osc_ucx_component.wpool, comm_size,
&exchange_len_info, (void *)module->comm,
&module->ctx);
if (OMPI_SUCCESS != ret) {
/* Populate addr table */
ret = opal_common_ucx_wpool_update_addr(mca_osc_ucx_component.wpool, comm_size,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to check whether all addresses have been exchanged on that particular communicator and avoid the exchange after the first window has been allocated?

&exchange_len_info, &get_proc_vpid, (void *)module->comm);
if (ret != OMPI_SUCCESS) {
goto error;
}

Expand All @@ -449,24 +481,23 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
break;
}

ret = opal_common_ucx_wpmem_create(module->ctx, base, size,
mem_type, &exchange_len_info,
(void *)module->comm,
ret = opal_common_ucx_wpmem_create(mca_osc_ucx_component.wpctx, base, size,
mem_type, &exchange_len_info,
(void *)module->comm, comm_size,
&my_mem_addr, &my_mem_addr_size,
&module->mem);
if (ret != OMPI_SUCCESS) {
goto error;
}

}

state_base = (void *)&(module->state);
ret = opal_common_ucx_wpmem_create(module->ctx, &state_base,
ret = opal_common_ucx_wpmem_create(mca_osc_ucx_component.wpctx, &state_base,
sizeof(ompi_osc_ucx_state_t),
OPAL_COMMON_UCX_MEM_MAP, &exchange_len_info,
(void *)module->comm,
&my_mem_addr, &my_mem_addr_size,
&module->state_mem);
(void *)module->comm, comm_size,
&my_mem_addr, &my_mem_addr_size,
&module->state_mem);
if (ret != OMPI_SUCCESS) {
goto error;
}
Expand Down Expand Up @@ -613,9 +644,9 @@ int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) {
insert_index = 0;
}

ret = opal_common_ucx_wpmem_create(module->ctx, &base, len,
ret = opal_common_ucx_wpmem_create(mca_osc_ucx_component.wpctx, &base, len,
OPAL_COMMON_UCX_MEM_MAP, &exchange_len_info,
(void *)module->comm,
(void *)module->comm, ompi_comm_size(module->comm),
&(module->local_dynamic_win_info[insert_index].my_mem_addr),
&(module->local_dynamic_win_info[insert_index].my_mem_addr_size),
&(module->local_dynamic_win_info[insert_index].mem));
Expand Down Expand Up @@ -693,8 +724,6 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) {
opal_common_ucx_wpmem_free(module->state_mem);
opal_common_ucx_wpmem_free(module->mem);

opal_common_ucx_wpctx_release(module->ctx);

if (module->disp_units) {
free(module->disp_units);
}
Expand Down
Loading