Skip to content

Commit

Permalink
v4.1.x: coll/ucc: add reduce scatter block
Browse files Browse the repository at this point in the history
Signed-off-by: Sergey Lebedev <sergeyle@nvidia.com>
bot:notacherrypick
  • Loading branch information
Sergei-Lebedev committed Oct 30, 2024
1 parent cbf1a6d commit 5e05bae
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 79 deletions.
31 changes: 16 additions & 15 deletions ompi/mca/coll/ucc/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@

AM_CPPFLAGS = $(coll_ucc_CPPFLAGS)

coll_ucc_sources = \
coll_ucc.h \
coll_ucc_debug.h \
coll_ucc_dtypes.h \
coll_ucc_common.h \
coll_ucc_module.c \
coll_ucc_component.c \
coll_ucc_barrier.c \
coll_ucc_bcast.c \
coll_ucc_allreduce.c \
coll_ucc_reduce.c \
coll_ucc_alltoall.c \
coll_ucc_alltoallv.c \
coll_ucc_allgather.c \
coll_ucc_allgatherv.c
coll_ucc_sources = \
coll_ucc.h \
coll_ucc_debug.h \
coll_ucc_dtypes.h \
coll_ucc_common.h \
coll_ucc_module.c \
coll_ucc_component.c \
coll_ucc_barrier.c \
coll_ucc_bcast.c \
coll_ucc_allreduce.c \
coll_ucc_reduce.c \
coll_ucc_alltoall.c \
coll_ucc_alltoallv.c \
coll_ucc_allgather.c \
coll_ucc_allgatherv.c \
coll_ucc_reduce_scatter_block.c

# Make the output library in this directory, and name it either
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la
Expand Down
96 changes: 57 additions & 39 deletions ompi/mca/coll/ucc/coll_ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ BEGIN_C_DECLS
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV)
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV | \
UCC_COLL_TYPE_REDUCE_SCATTER)

#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce," \
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce"
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce,reduce_scatter_block," \
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce,ireduce_scatter_block"

typedef struct mca_coll_ucc_req {
ompi_request_t super;
Expand Down Expand Up @@ -64,42 +65,46 @@ OMPI_MODULE_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
* UCC enabled communicator
*/
struct mca_coll_ucc_module_t {
mca_coll_base_module_t super;
ompi_communicator_t* comm;
int rank;
ucc_team_h ucc_team;
mca_coll_base_module_allreduce_fn_t previous_allreduce;
mca_coll_base_module_t* previous_allreduce_module;
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
mca_coll_base_module_t* previous_iallreduce_module;
mca_coll_base_module_reduce_fn_t previous_reduce;
mca_coll_base_module_t* previous_reduce_module;
mca_coll_base_module_ireduce_fn_t previous_ireduce;
mca_coll_base_module_t* previous_ireduce_module;
mca_coll_base_module_barrier_fn_t previous_barrier;
mca_coll_base_module_t* previous_barrier_module;
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
mca_coll_base_module_t* previous_ibarrier_module;
mca_coll_base_module_bcast_fn_t previous_bcast;
mca_coll_base_module_t* previous_bcast_module;
mca_coll_base_module_ibcast_fn_t previous_ibcast;
mca_coll_base_module_t* previous_ibcast_module;
mca_coll_base_module_alltoall_fn_t previous_alltoall;
mca_coll_base_module_t* previous_alltoall_module;
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
mca_coll_base_module_t* previous_ialltoall_module;
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
mca_coll_base_module_t* previous_alltoallv_module;
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
mca_coll_base_module_t* previous_ialltoallv_module;
mca_coll_base_module_allgather_fn_t previous_allgather;
mca_coll_base_module_t* previous_allgather_module;
mca_coll_base_module_iallgather_fn_t previous_iallgather;
mca_coll_base_module_t* previous_iallgather_module;
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
mca_coll_base_module_t* previous_allgatherv_module;
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
mca_coll_base_module_t* previous_iallgatherv_module;
mca_coll_base_module_t super;
ompi_communicator_t* comm;
int rank;
ucc_team_h ucc_team;
mca_coll_base_module_allreduce_fn_t previous_allreduce;
mca_coll_base_module_t* previous_allreduce_module;
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
mca_coll_base_module_t* previous_iallreduce_module;
mca_coll_base_module_reduce_fn_t previous_reduce;
mca_coll_base_module_t* previous_reduce_module;
mca_coll_base_module_ireduce_fn_t previous_ireduce;
mca_coll_base_module_t* previous_ireduce_module;
mca_coll_base_module_barrier_fn_t previous_barrier;
mca_coll_base_module_t* previous_barrier_module;
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
mca_coll_base_module_t* previous_ibarrier_module;
mca_coll_base_module_bcast_fn_t previous_bcast;
mca_coll_base_module_t* previous_bcast_module;
mca_coll_base_module_ibcast_fn_t previous_ibcast;
mca_coll_base_module_t* previous_ibcast_module;
mca_coll_base_module_alltoall_fn_t previous_alltoall;
mca_coll_base_module_t* previous_alltoall_module;
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
mca_coll_base_module_t* previous_ialltoall_module;
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
mca_coll_base_module_t* previous_alltoallv_module;
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
mca_coll_base_module_t* previous_ialltoallv_module;
mca_coll_base_module_allgather_fn_t previous_allgather;
mca_coll_base_module_t* previous_allgather_module;
mca_coll_base_module_iallgather_fn_t previous_iallgather;
mca_coll_base_module_t* previous_iallgather_module;
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
mca_coll_base_module_t* previous_allgatherv_module;
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
mca_coll_base_module_t* previous_iallgatherv_module;
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
mca_coll_base_module_t* previous_reduce_scatter_block_module;
mca_coll_base_module_ireduce_scatter_block_fn_t previous_ireduce_scatter_block;
mca_coll_base_module_t* previous_ireduce_scatter_block_module;
};
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
Expand Down Expand Up @@ -195,5 +200,18 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount, struct ompi_datatype_
ompi_request_t** request,
mca_coll_base_module_t *module);

int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
ompi_request_t** request,
mca_coll_base_module_t *module);

END_C_DECLS
#endif
2 changes: 2 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str)
return UCC_COLL_TYPE_ALLGATHERV;
} else if (0 == strcasecmp(str, "reduce")) {
return UCC_COLL_TYPE_REDUCE;
} else if (0 == strcasecmp(str, "reduce_scatter_block")) {
return UCC_COLL_TYPE_REDUCE_SCATTER;
}
UCC_ERROR("incorrect value for cts: %s, allowed: %s",
str, COLL_UCC_CTS_STR);
Expand Down
59 changes: 34 additions & 25 deletions ompi/mca/coll/ucc/coll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,27 @@ int mca_coll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_thread

static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module)
{
ucc_module->ucc_team = NULL;
ucc_module->previous_allreduce = NULL;
ucc_module->previous_iallreduce = NULL;
ucc_module->previous_barrier = NULL;
ucc_module->previous_ibarrier = NULL;
ucc_module->previous_bcast = NULL;
ucc_module->previous_ibcast = NULL;
ucc_module->previous_alltoall = NULL;
ucc_module->previous_ialltoall = NULL;
ucc_module->previous_alltoallv = NULL;
ucc_module->previous_ialltoallv = NULL;
ucc_module->previous_allgather = NULL;
ucc_module->previous_iallgather = NULL;
ucc_module->previous_allgatherv = NULL;
ucc_module->previous_iallgatherv = NULL;
ucc_module->previous_reduce = NULL;
ucc_module->previous_ireduce = NULL;
ucc_module->ucc_team = NULL;
ucc_module->previous_allreduce = NULL;
ucc_module->previous_iallreduce = NULL;
ucc_module->previous_barrier = NULL;
ucc_module->previous_ibarrier = NULL;
ucc_module->previous_bcast = NULL;
ucc_module->previous_ibcast = NULL;
ucc_module->previous_alltoall = NULL;
ucc_module->previous_ialltoall = NULL;
ucc_module->previous_alltoallv = NULL;
ucc_module->previous_ialltoallv = NULL;
ucc_module->previous_allgather = NULL;
ucc_module->previous_iallgather = NULL;
ucc_module->previous_allgatherv = NULL;
ucc_module->previous_iallgatherv = NULL;
ucc_module->previous_reduce = NULL;
ucc_module->previous_ireduce = NULL;
ucc_module->previous_reduce_scatter_block = NULL;
ucc_module->previous_reduce_scatter_block_module = NULL;
ucc_module->previous_ireduce_scatter_block = NULL;
ucc_module->previous_ireduce_scatter_block_module = NULL;
}

static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module)
Expand Down Expand Up @@ -82,6 +86,8 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module)
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iallgatherv_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_scatter_block_module);
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_block_module);
mca_coll_ucc_module_clear(ucc_module);
}

Expand Down Expand Up @@ -113,6 +119,8 @@ static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
SAVE_PREV_COLL_API(iallgatherv);
SAVE_PREV_COLL_API(reduce);
SAVE_PREV_COLL_API(ireduce);
SAVE_PREV_COLL_API(reduce_scatter_block);
SAVE_PREV_COLL_API(ireduce_scatter_block);
return OMPI_SUCCESS;
}

Expand Down Expand Up @@ -491,14 +499,15 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority)
ucc_module->comm = comm;
ucc_module->super.coll_module_enable = mca_coll_ucc_module_enable;
*priority = cm->ucc_priority;
SET_COLL_PTR(ucc_module, BARRIER, barrier);
SET_COLL_PTR(ucc_module, BCAST, bcast);
SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce);
SET_COLL_PTR(ucc_module, ALLTOALL, alltoall);
SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv);
SET_COLL_PTR(ucc_module, REDUCE, reduce);
SET_COLL_PTR(ucc_module, ALLGATHER, allgather);
SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv);
SET_COLL_PTR(ucc_module, BARRIER, barrier);
SET_COLL_PTR(ucc_module, BCAST, bcast);
SET_COLL_PTR(ucc_module, ALLREDUCE, allreduce);
SET_COLL_PTR(ucc_module, ALLTOALL, alltoall);
SET_COLL_PTR(ucc_module, ALLTOALLV, alltoallv);
SET_COLL_PTR(ucc_module, REDUCE, reduce);
SET_COLL_PTR(ucc_module, ALLGATHER, allgather);
SET_COLL_PTR(ucc_module, ALLGATHERV, allgatherv);
SET_COLL_PTR(ucc_module, REDUCE_SCATTER, reduce_scatter_block);
return &ucc_module->super;
}

Expand Down
117 changes: 117 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/**
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
* Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
*
*/

#include "coll_ucc_common.h"

static inline
ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf,
size_t rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
{
ucc_datatype_t ucc_dt;
ucc_reduction_op_t ucc_op;
int comm_size = ompi_comm_size(ucc_module->comm);

if (MPI_IN_PLACE == sbuf) {
/* TODO: UCC defines inplace differently:
data in rbuf of rank R is shifted by R * rcount */
UCC_VERBOSE(5, "inplace reduce_scatter_block is not supported");
return UCC_ERR_NOT_SUPPORTED;
}
ucc_dt = ompi_dtype_to_ucc_dtype(dtype);
ucc_op = ompi_op_to_ucc_op(op);
if (OPAL_UNLIKELY(COLL_UCC_DT_UNSUPPORTED == ucc_dt)) {
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
dtype->super.name);
goto fallback;
}
if (OPAL_UNLIKELY(COLL_UCC_OP_UNSUPPORTED == ucc_op)) {
UCC_VERBOSE(5, "ompi_op is not supported: op = %s",
op->o_name);
goto fallback;
}
ucc_coll_args_t coll = {
.mask = 0,
.flags = 0,
.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER,
.src.info = {
.buffer = (void*)sbuf,
.count = ((size_t)rcount) * comm_size,
.datatype = ucc_dt,
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
},
.dst.info = {
.buffer = rbuf,
.count = rcount,
.datatype = ucc_dt,
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
},
.op = ucc_op,
};
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
return UCC_ERR_NOT_SUPPORTED;
}

int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
{
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
ucc_coll_req_h req;

UCC_VERBOSE(3, "running ucc reduce scatter block");
COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount,
dtype, op, ucc_module,
&req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback reduce_scatter_block");
return ucc_module->previous_reduce_scatter_block(sbuf, rbuf, rcount, dtype,
op, comm,
ucc_module->previous_reduce_scatter_block_module);
}

int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
ompi_request_t** request,
mca_coll_base_module_t *module)
{
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
ucc_coll_req_h req;
mca_coll_ucc_req_t *coll_req = NULL;

UCC_VERBOSE(3, "running ucc ireduce_scatter_block");
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount,
dtype, op, ucc_module,
&req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
fallback:
UCC_VERBOSE(3, "running fallback ireduce_scatter_block");
if (coll_req) {
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
}
return ucc_module->previous_ireduce_scatter_block(sbuf, rbuf, rcount, dtype,
op, comm, request,
ucc_module->previous_ireduce_scatter_block_module);
}

0 comments on commit 5e05bae

Please sign in to comment.