diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am
index 45819da272..70108a54af 100644
--- a/src/components/tl/mlx5/Makefile.am
+++ b/src/components/tl/mlx5/Makefile.am
@@ -4,6 +4,13 @@
 
 if TL_MLX5_ENABLED
 
+mcast =                           \
+	mcast/tl_mlx5_mcast_context.c \
+	mcast/tl_mlx5_mcast.h         \
+	mcast/tl_mlx5_mcast_coll.c    \
+	mcast/tl_mlx5_mcast_coll.h    \
+	mcast/tl_mlx5_mcast_team.c
+
 sources =          \
 	tl_mlx5.h      \
 	tl_mlx5.c      \
@@ -11,6 +18,8 @@ sources =          \
 	tl_mlx5_context.c  \
 	tl_mlx5_team.c     \
 	tl_mlx5_coll.h     \
+	tl_mlx5_coll.c     \
+	$(mcast)           \
 	tl_mlx5_ib.h       \
 	tl_mlx5_ib.c       \
 	tl_mlx5_wqe.h      \
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
new file mode 100644
index 0000000000..fa666612d1
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#ifndef UCC_MCAST_H
+#define UCC_MCAST_H
+
+#include <infiniband/ib.h>
+#include <infiniband/umad.h>
+#include <infiniband/verbs.h>
+#include <rdma/rdma_cma.h>
+#include "utils/ucc_list.h"
+#include "utils/ucc_mpool.h"
+#include "components/tl/ucc_tl.h"
+#include "components/tl/ucc_tl_log.h"
+#include "utils/ucc_rcache.h"
+
+
+#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true
+
+typedef struct mcast_coll_comm_init_spec {
+} mcast_coll_comm_init_spec_t;
+
+typedef struct ucc_tl_mlx5_mcast_lib {
+} ucc_tl_mlx5_mcast_lib_t;
+UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
+                  const ucc_base_config_t *);
+
+typedef struct ucc_tl_mlx5_mcast_ctx_params {
+} ucc_tl_mlx5_mcast_ctx_params_t;
+
+typedef struct mcast_coll_context_t {
+} mcast_coll_context_t;
+
+typedef struct ucc_tl_mlx5_mcast_context_t {
+} ucc_tl_mlx5_mcast_context_t;
+
+
+typedef struct mcast_coll_comm { /* Stuff at a per-communicator sort of level */
+} mcast_coll_comm_t;
+
+typedef struct ucc_tl_mlx5_mcast_team {
+    void *mcast_comm;
+} ucc_tl_mlx5_mcast_team_t;
+
+typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */
+} ucc_tl_mlx5_mcast_coll_req_t;
+
+#define TASK_TEAM_MCAST(_task)                                          \
+    (ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t))
+#define TASK_CTX_MCAST(_task)                                           \
+    (ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_mcast_context_t))
+#define TASK_LIB_MCAST(_task)                                           \
+    (ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_mcast_lib_t))
+#define TASK_ARGS_MCAST(_task) (_task)->super.bargs.args
+
+ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
+                                         ucc_tl_mlx5_mcast_team_t **mcast_team,
+                                         ucc_tl_mlx5_mcast_context_t *ctx,
+                                         const ucc_base_team_params_t *params,
+                                         mcast_coll_comm_init_spec_t *mcast_conf);
+
+ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args,
+                                         ucc_base_team_t *team,
+                                         ucc_coll_task_t **task_h);
+
+ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *mcast_ctx,
+                                            ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf);
+
+#endif
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
new file mode 100644
index 0000000000..54cdfb267d
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -0,0 +1,74 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "tl_mlx5_coll.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req /* NOLINT */)
+{
+    return UCC_ERR_NOT_SUPPORTED;
+}
+
+ucc_status_t mcast_coll_do_bcast(void* buf, int size, int root, void *mr, /* NOLINT */
+                                 mcast_coll_comm_t *comm, /* NOLINT */
+                                 int is_blocking, /* NOLINT */
+                                 ucc_tl_mlx5_mcast_coll_req_t **task_req_handle /* NOLINT */)
+{
+    return UCC_ERR_NOT_SUPPORTED;
+}
+
+ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_mlx5_task_t       *task      = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    ucc_tl_mlx5_team_t       *mlx5_team = TASK_TEAM(task);
+    ucc_tl_mlx5_mcast_team_t *team      = mlx5_team->mcast;
+    ucc_coll_args_t          *args      = &TASK_ARGS_MCAST(task);
+    ucc_datatype_t            dt        = args->src.info.datatype;
+    size_t                    count     = args->src.info.count;
+    ucc_rank_t                root      = args->root;
+    ucc_status_t              status    = UCC_OK;
+    size_t                    data_size = ucc_dt_size(dt) * count;
+    void                     *buf       = args->src.info.buffer;
+    mcast_coll_comm_t        *comm      = team->mcast_comm;
+
+    task->bcast_mcast.req_handle = NULL;
+
+    status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
+                                 UCC_TL_MLX5_MCAST_ENABLE_BLOCKING, &task->bcast_mcast.req_handle);
+    if (UCC_OK != status && UCC_INPROGRESS != status) {
+        tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status);
+        coll_task->status = status;
+        return ucc_task_complete(coll_task);
+    }
+
+    coll_task->status = status;
+
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
+}
+
+void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_mlx5_task_t           *task   = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    ucc_status_t                  status = UCC_OK;
+    ucc_tl_mlx5_mcast_coll_req_t *req    = task->bcast_mcast.req_handle;
+
+    if (req != NULL) {
+        status = ucc_tl_mlx5_mcast_test(req);
+        if (UCC_OK == status) {
+            coll_task->status = UCC_OK;
+            ucc_free(req);
+            task->bcast_mcast.req_handle = NULL;
+        }
+    }
+}
+
+ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
+{
+
+    task->super.post     = ucc_tl_mlx5_mcast_bcast_start;
+    task->super.progress = ucc_tl_mlx5_mcast_collective_progress;
+
+    return UCC_OK;
+}
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
new file mode 100644
index 0000000000..47ddc301aa
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
@@ -0,0 +1,17 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#ifndef UCC_TL_MLX5_MCAST_COLL_H_
+#define UCC_TL_MLX5_MCAST_COLL_H_
+
+#include "tl_mlx5_mcast.h"
+#include "tl_mlx5_coll.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);
+
+ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);
+
+#endif
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
new file mode 100644
index 0000000000..fba3d2c0ab
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c
@@ -0,0 +1,18 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include <inttypes.h>
+#include "tl_mlx5_mcast.h"
+#include "utils/arch/cpu.h"
+#include <ucs/sys/string.h>
+#include "src/core/ucc_service_coll.h"
+#include "tl_mlx5.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *context, /* NOLINT */
+                                            ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf /* NOLINT */)
+{
+    return UCC_ERR_NOT_SUPPORTED;
+}
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
new file mode 100644
index 0000000000..f9b9cdafc3
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
@@ -0,0 +1,20 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+
+#include "tl_mlx5.h"
+#include "tl_mlx5_mcast_coll.h"
+#include "coll_score/ucc_coll_score.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, /* NOLINT */
+                                         ucc_tl_mlx5_mcast_team_t **mcast_team, /* NOLINT */
+                                         ucc_tl_mlx5_mcast_context_t *ctx, /* NOLINT */
+                                         const ucc_base_team_params_t *params, /* NOLINT */
+                                         mcast_coll_comm_init_spec_t *mcast_conf /* NOLINT */)
+{
+    return UCC_ERR_NOT_SUPPORTED;
+}
+
diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h
index 8c51990289..16803ec8da 100644
--- a/src/components/tl/mlx5/tl_mlx5.h
+++ b/src/components/tl/mlx5/tl_mlx5.h
@@ -14,6 +14,7 @@
 #include <infiniband/verbs.h>
 #include <infiniband/mlx5dv.h>
 #include "utils/arch/cpu.h"
+#include "mcast/tl_mlx5_mcast.h"
 
 #ifndef UCC_TL_MLX5_DEFAULT_SCORE
 #define UCC_TL_MLX5_DEFAULT_SCORE 1
@@ -48,19 +49,21 @@ typedef struct ucc_tl_mlx5_ib_qp_conf {
 } ucc_tl_mlx5_ib_qp_conf_t;
 
 typedef struct ucc_tl_mlx5_lib_config {
-    ucc_tl_lib_config_t      super;
-    int                      block_size;
-    int                      num_dci_qps;
-    int                      dc_threshold;
-    size_t                   dm_buf_size;
-    unsigned long            dm_buf_num;
-    int                      dm_host;
-    ucc_tl_mlx5_ib_qp_conf_t qp_conf;
+    ucc_tl_lib_config_t         super;
+    int                         block_size;
+    int                         num_dci_qps;
+    int                         dc_threshold;
+    size_t                      dm_buf_size;
+    unsigned long               dm_buf_num;
+    int                         dm_host;
+    ucc_tl_mlx5_ib_qp_conf_t    qp_conf;
+    mcast_coll_comm_init_spec_t mcast_conf;
 } ucc_tl_mlx5_lib_config_t;
 
 typedef struct ucc_tl_mlx5_context_config {
-    ucc_tl_context_config_t  super;
-    ucs_config_names_array_t devices;
+    ucc_tl_context_config_t        super;
+    ucs_config_names_array_t       devices;
+    ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf;
 } ucc_tl_mlx5_context_config_t;
 
 typedef struct ucc_tl_mlx5_lib {
@@ -80,6 +83,7 @@ typedef struct ucc_tl_mlx5_context {
     int                         is_imported;
     int                         ib_port;
     ucc_mpool_t                 req_mp;
+    ucc_tl_mlx5_mcast_context_t mcast;
 } ucc_tl_mlx5_context_t;
 UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t *,
                   const ucc_base_config_t *);
@@ -100,17 +104,18 @@ typedef enum
 } ucc_tl_mlx5_team_state_t;
 
 typedef struct ucc_tl_mlx5_team {
-    ucc_tl_team_t            super;
-    ucc_status_t             status[2];
-    ucc_service_coll_req_t  *scoll_req;
-    ucc_tl_mlx5_team_state_t state;
-    void                    *dm_offset;
-    ucc_mpool_t              dm_pool;
-    struct ibv_dm           *dm_ptr;
-    struct ibv_mr           *dm_mr;
-    ucc_tl_mlx5_a2a_t       *a2a;
-    ucc_topo_t              *topo;
-    ucc_ep_map_t             ctx_map;
+    ucc_tl_team_t             super;
+    ucc_status_t              status[2];
+    ucc_service_coll_req_t   *scoll_req;
+    ucc_tl_mlx5_team_state_t  state;
+    void                     *dm_offset;
+    ucc_mpool_t               dm_pool;
+    struct ibv_dm            *dm_ptr;
+    struct ibv_mr            *dm_mr;
+    ucc_tl_mlx5_a2a_t        *a2a;
+    ucc_topo_t               *topo;
+    ucc_ep_map_t              ctx_map;
+    ucc_tl_mlx5_mcast_team_t *mcast;
 } ucc_tl_mlx5_team_t;
 UCC_CLASS_DECLARE(ucc_tl_mlx5_team_t, ucc_base_context_t *,
                   const ucc_base_team_params_t *);
@@ -126,7 +131,7 @@ typedef struct ucc_tl_mlx5_rcache_region {
     ucc_tl_mlx5_reg_t reg;
 } ucc_tl_mlx5_rcache_region_t;
 
-#define UCC_TL_MLX5_SUPPORTED_COLLS (UCC_COLL_TYPE_ALLTOALL)
+#define UCC_TL_MLX5_SUPPORTED_COLLS (UCC_COLL_TYPE_ALLTOALL | UCC_COLL_TYPE_BCAST)
 
 #define UCC_TL_MLX5_TEAM_LIB(_team) \
     (ucc_derived_of((_team)->super.super.context->lib, ucc_tl_mlx5_lib_t))
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c
new file mode 100644
index 0000000000..5d14d31d86
--- /dev/null
+++ b/src/components/tl/mlx5/tl_mlx5_coll.c
@@ -0,0 +1,52 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "tl_mlx5_coll.h"
+#include "mcast/tl_mlx5_mcast_coll.h"
+
+static ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task);
+    UCC_TL_MLX5_PROFILE_REQUEST_FREE(task);
+    ucc_mpool_put(task);
+    return UCC_OK;
+}
+
+ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
+                                          ucc_base_team_t * team,
+                                          ucc_coll_task_t ** task_h)
+{
+    ucc_status_t        status = UCC_OK;
+    ucc_tl_mlx5_task_t *task   = NULL;
+
+    if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    task = ucc_tl_mlx5_get_task(coll_args, team);
+
+    if (ucc_unlikely(!task)) {
+        return UCC_ERR_NO_MEMORY;
+    }
+
+    task->super.finalize = ucc_tl_mlx5_task_finalize;
+
+    status = ucc_tl_mlx5_mcast_bcast_init(task);
+    if (ucc_unlikely(UCC_OK != status)) {
+        goto free_task;
+    }
+
+    *task_h = &(task->super);
+
+    tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);
+
+    return UCC_OK;
+
+free_task:
+    ucc_mpool_put(task);
+    return status;
+}
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h
index 5505ce2024..f0170156b7 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.h
+++ b/src/components/tl/mlx5/tl_mlx5_coll.h
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See file LICENSE for terms.
 */
@@ -12,6 +12,11 @@
 
 typedef struct ucc_tl_mlx5_task {
     ucc_coll_task_t super;
+    union {
+        struct {
+            ucc_tl_mlx5_mcast_coll_req_t *req_handle;
+        } bcast_mcast;
+    };
 } ucc_tl_mlx5_task_t;
 
 typedef struct ucc_tl_mlx5_schedule {
@@ -23,15 +28,15 @@ typedef struct ucc_tl_mlx5_schedule {
 } ucc_tl_mlx5_schedule_t;
 
 #define TASK_TEAM(_task) \
-    (ucc_derived_of((_task)->super.super.team, ucc_tl_mlx5_team_t))
+    (ucc_derived_of((_task)->super.team, ucc_tl_mlx5_team_t))
 
 #define TASK_CTX(_task) \
-    (ucc_derived_of((_task)->super.super.team->context, ucc_tl_mlx5_context_t))
+    (ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_context_t))
 
 #define TASK_LIB(_task) \
-    (ucc_derived_of((_task)->super.super.team->context->lib, ucc_tl_mlx5_lib_t))
+    (ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_lib_t))
 
-#define TASK_ARGS(_task) (_task)->super.super.bargs.args
+#define TASK_ARGS(_task) (_task)->super.bargs.args
 
 #define TASK_SCHEDULE(_task) \
     (ucc_derived_of((_task)->schedule, ucc_tl_mlx5_schedule_t))
@@ -61,6 +66,10 @@ ucc_tl_mlx5_get_schedule(ucc_tl_mlx5_team_t * team,
     ucc_tl_mlx5_context_t  *ctx      = UCC_TL_MLX5_TEAM_CTX(team);
     ucc_tl_mlx5_schedule_t *schedule = ucc_mpool_get(&ctx->req_mp);
 
+    if (ucc_unlikely(!schedule)) {
+        return NULL;
+    }
+
     UCC_TL_MLX5_PROFILE_REQUEST_NEW(schedule, "tl_mlx5_sched", 0);
     ucc_schedule_init(&schedule->super, coll_args, &team->super.super);
     return schedule;
@@ -72,4 +81,8 @@ static inline void ucc_tl_mlx5_put_schedule(ucc_tl_mlx5_schedule_t *schedule)
     ucc_mpool_put(schedule);
 }
 
+ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
+                                          ucc_base_team_t * team,
+                                          ucc_coll_task_t ** task_h);
+
 #endif
diff --git a/src/components/tl/mlx5/tl_mlx5_context.c b/src/components/tl/mlx5/tl_mlx5_context.c
index f9cabb99ba..09c51946d6 100644
--- a/src/components/tl/mlx5/tl_mlx5_context.c
+++ b/src/components/tl/mlx5/tl_mlx5_context.c
@@ -41,6 +41,13 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t,
         return status;
     }
 
+    status = ucc_tl_mlx5_mcast_context_init(&(self->mcast), &(self->cfg.mcast_ctx_conf));
+    if (UCC_OK != status) {
+        tl_error(self->super.super.lib,
+                 "failed to initialize mcast context");
+        return status;
+    }
+
     tl_debug(self->super.super.lib, "initialized tl context: %p", self);
     return UCC_OK;
 }
diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c
index bb109b45f2..f9556dd989 100644
--- a/src/components/tl/mlx5/tl_mlx5_team.c
+++ b/src/components/tl/mlx5/tl_mlx5_team.c
@@ -4,11 +4,13 @@
  * See file LICENSE for terms.
  */
 
+#include "tl_mlx5_coll.h"
 #include "tl_mlx5.h"
 #include "tl_mlx5_dm.h"
 #include "coll_score/ucc_coll_score.h"
 #include "core/ucc_team.h"
 #include <inttypes.h>
+#include "mcast/tl_mlx5_mcast.h"
 
 static ucc_status_t ucc_tl_mlx5_topo_init(ucc_tl_mlx5_team_t *team)
 {
@@ -68,6 +70,13 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context,
         }
     }
 
+    self->mcast = NULL;
+    status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast), params,
+                                         &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf));
+    if (ucc_unlikely(UCC_OK != status)) {
+        return status;
+    }
+
     self->status[0] = status;
     self->state     = TL_MLX5_TEAM_STATE_INIT;
 
@@ -137,12 +146,29 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team)
     return UCC_OK;
 }
 
+ucc_status_t ucc_tl_mlx5_coll_init(ucc_base_coll_args_t *coll_args,
+                                   ucc_base_team_t * team,
+                                   ucc_coll_task_t ** task_h)
+{
+    ucc_status_t status = UCC_OK;
+
+    switch (coll_args->args.coll_type)
+    {
+    case UCC_COLL_TYPE_BCAST:
+        status = ucc_tl_mlx5_bcast_mcast_init(coll_args, team, task_h);
+        break;
+    default:
+        status = UCC_ERR_NOT_SUPPORTED;
+    }
+
+    return status;
+}
+
 ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
                                          ucc_coll_score_t **score_p)
 {
     ucc_tl_mlx5_team_t *team = ucc_derived_of(tl_team, ucc_tl_mlx5_team_t);
     ucc_base_context_t *ctx  = UCC_TL_TEAM_CTX(team);
-    ucc_base_lib_t *    lib  = UCC_TL_TEAM_LIB(team);
     ucc_memory_type_t   mt   = UCC_MEMORY_TYPE_HOST;
     ucc_coll_score_t *  score;
     ucc_status_t        status;
@@ -156,9 +182,14 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
     team_info.supported_colls = UCC_TL_MLX5_SUPPORTED_COLLS;
     team_info.size            = UCC_TL_TEAM_SIZE(team);
 
-    status = ucc_coll_score_alloc(&score);
+    /* There can be a different logic for different coll_type/mem_type.
+       Right now just init everything the same way. */
+    status =
+        ucc_coll_score_build_default(tl_team, UCC_TL_MLX5_DEFAULT_SCORE,
+                                     ucc_tl_mlx5_coll_init,
+                                     UCC_TL_MLX5_SUPPORTED_COLLS,
+                                     NULL, 0, &score);
     if (UCC_OK != status) {
-        tl_error(lib, "failed to alloc score_t");
         return status;
     }
 
@@ -171,17 +202,12 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
             goto err;
         }
     }
+
     *score_p = score;
     return UCC_OK;
+
 err:
     ucc_coll_score_free(score);
     *score_p = NULL;
     return status;
 }
-
-ucc_status_t ucc_tl_mlx5_coll_init(ucc_base_coll_args_t *coll_args, /* NOLINT */
-                                   ucc_base_team_t *     team,      /* NOLINT */
-                                   ucc_coll_task_t **    task)      /* NOLINT */
-{
-    return UCC_OK;
-}