diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c index a6e4fe661e1..5fc88a9e102 100644 --- a/ompi/mca/osc/ucx/osc_ucx_component.c +++ b/ompi/mca/osc/ucx/osc_ucx_component.c @@ -415,8 +415,12 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in assert(mca_osc_ucx_component.ucp_worker == NULL); memset(&worker_params, 0, sizeof(worker_params)); worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = (mca_osc_ucx_component.enable_mpi_threads == true) - ? UCS_THREAD_MODE_MULTI : UCS_THREAD_MODE_SINGLE; + if (mca_osc_ucx_component.enable_mpi_threads) { + worker_params.thread_mode = UCS_THREAD_MODE_MULTI; + } else { + worker_params.thread_mode = + opal_common_ucx_thread_mode(ompi_mpi_thread_provided); + } status = ucp_worker_create(mca_osc_ucx_component.ucp_context, &worker_params, &(mca_osc_ucx_component.ucp_worker)); if (UCS_OK != status) { diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c index 79e2d149abb..5eb621990c0 100644 --- a/ompi/mca/pml/ucx/pml_ucx.c +++ b/ompi/mca/pml/ucx/pml_ucx.c @@ -288,12 +288,12 @@ int mca_pml_ucx_init(int enable_mpi_threads) PML_UCX_VERBOSE(1, "mca_pml_ucx_init"); - /* TODO check MPI thread mode */ params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; if (enable_mpi_threads) { params.thread_mode = UCS_THREAD_MODE_MULTI; } else { - params.thread_mode = UCS_THREAD_MODE_SINGLE; + params.thread_mode = + opal_common_ucx_thread_mode(ompi_mpi_thread_provided); } #if HAVE_DECL_UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK diff --git a/opal/mca/common/ucx/common_ucx.c b/opal/mca/common/ucx/common_ucx.c index 1cd250e8944..b953fcdcf27 100644 --- a/opal/mca/common/ucx/common_ucx.c +++ b/opal/mca/common/ucx/common_ucx.c @@ -25,6 +25,8 @@ #include "opal/util/argv.h" #include "opal/util/printf.h" +#include "mpi.h" + #include #include #include @@ -49,6 +51,23 @@ static void opal_common_ucx_mem_release_cb(void *buf, size_t length, ucm_vm_munmap(buf, length); } +ucs_thread_mode_t opal_common_ucx_thread_mode(int ompi_mode) +{ + switch (ompi_mode) { + case MPI_THREAD_MULTIPLE: + return UCS_THREAD_MODE_MULTI; + case MPI_THREAD_SERIALIZED: + return UCS_THREAD_MODE_SERIALIZED; + case MPI_THREAD_FUNNELED: + case MPI_THREAD_SINGLE: + return UCS_THREAD_MODE_SINGLE; + default: + MCA_COMMON_UCX_WARN("Unknown MPI thread mode %d, using multithread", + ompi_mode); + return UCS_THREAD_MODE_MULTI; + } +} + OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component) { char *default_tls = "rc_verbs,ud_verbs,rc_mlx5,dc_mlx5,ud_mlx5,cuda_ipc,rocm_ipc"; diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h index 4b78bc66587..afd322b9add 100644 --- a/opal/mca/common/ucx/common_ucx.h +++ b/opal/mca/common/ucx/common_ucx.h @@ -124,6 +124,7 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, s OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t *procs, size_t count, size_t my_rank, size_t max_disconnect, ucp_worker_h worker); OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component); +OPAL_DECLSPEC ucs_thread_mode_t opal_common_ucx_thread_mode(int ompi_mode); static inline ucs_status_t opal_common_ucx_request_status(ucs_status_ptr_t request) diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 21a1d817bc1..33bbda01f27 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -1016,8 +1016,11 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx ucx_ctx->strong_sync = mca_spml_ucx_ctx_default.strong_sync; params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE || options & SHMEM_CTX_PRIVATE || options & SHMEM_CTX_SERIALIZED) { + if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE || + oshmem_mpi_thread_provided == SHMEM_THREAD_FUNNELED || options & SHMEM_CTX_PRIVATE) { params.thread_mode = UCS_THREAD_MODE_SINGLE; + } else if (oshmem_mpi_thread_provided == SHMEM_THREAD_SERIALIZED || options & SHMEM_CTX_SERIALIZED) { + params.thread_mode = UCS_THREAD_MODE_SERIALIZED; } else { params.thread_mode = UCS_THREAD_MODE_MULTI; } diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c index 89c04e2f9eb..41f6581dc94 100644 --- a/oshmem/mca/spml/ucx/spml_ucx_component.c +++ b/oshmem/mca/spml/ucx/spml_ucx_component.c @@ -322,6 +322,8 @@ static int spml_ucx_init(void) wkr_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; if (oshmem_mpi_thread_requested == SHMEM_THREAD_MULTIPLE) { wkr_params.thread_mode = UCS_THREAD_MODE_MULTI; + } else if (oshmem_mpi_thread_requested == SHMEM_THREAD_SERIALIZED) { + wkr_params.thread_mode = UCS_THREAD_MODE_SERIALIZED; } else { wkr_params.thread_mode = UCS_THREAD_MODE_SINGLE; }