diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am
index 8a55c5866e..b3c355ad65 100644
--- a/src/components/tl/mlx5/Makefile.am
+++ b/src/components/tl/mlx5/Makefile.am
@@ -4,6 +4,13 @@
 
 if TL_MLX5_ENABLED
 
+mcast = \
+    mcast/tl_mcast_context.c \
+    mcast/tl_mcast.h \
+    mcast/tl_mcast_coll.c \
+    mcast/tl_mcast_coll.h \
+    mcast/tl_mcast_team.c
+
 sources = \
     tl_mlx5.h \
     tl_mlx5.c \
@@ -11,6 +18,8 @@ sources = \
     tl_mlx5_context.c \
     tl_mlx5_team.c \
     tl_mlx5_coll.h \
+    tl_mlx5_coll.c \
+    $(mcast) \
     tl_mlx5_ib.h \
     tl_mlx5_ib.c \
     tl_mlx5_wqe.h \
diff --git a/src/components/tl/mlx5/mcast/tl_mcast.h b/src/components/tl/mlx5/mcast/tl_mcast.h
new file mode 100644
index 0000000000..33ea64fbe6
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mcast.h
@@ -0,0 +1,80 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#ifndef UCC_MCAST_H
+#define UCC_MCAST_H
+
+#include
+#include
+#include
+#include
+#include "utils/ucc_list.h"
+#include "utils/ucc_mpool.h"
+#include "components/tl/ucc_tl.h"
+#include "components/tl/ucc_tl_log.h"
+#include
+#include "utils/ucc_rcache.h"
+
+
+#ifndef UCC_TL_MLX5_MCAST_DEFAULT_SCORE /* may be predefined to override the value below */
+#define UCC_TL_MLX5_MCAST_DEFAULT_SCORE 30 /* added to the base score in UCC_TL_MLX5_DEFAULT_SCORE (tl_mlx5.h) */
+#endif
+
+#define UCC_TL_MLX5_MCAST_SUPPORTED_COLLS (UCC_COLL_TYPE_BCAST) /* OR-ed into UCC_TL_MLX5_SUPPORTED_COLLS (tl_mlx5.h) */
+#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true /* passed as is_blocking to mcast_coll_do_bcast() */
+
+typedef struct mcast_coll_comm_init_spec { /* placeholder: no fields defined yet */
+} mcast_coll_comm_init_spec_t;
+
+typedef struct ucc_tl_mlx5_mcast_lib { /* placeholder: no fields defined yet */
+} ucc_tl_mlx5_mcast_lib_t;
+UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *,
+                  const ucc_base_config_t *);
+
+typedef struct mcast_ctx_params { /* placeholder: no fields defined yet */
+} mcast_ctx_params_t;
+
+typedef struct mcast_coll_context_t { /* placeholder: no fields defined yet */
+} mcast_coll_context_t;
+
+typedef struct ucc_tl_mlx5_mcast_context_t { /* placeholder: no fields defined yet */
+} ucc_tl_mlx5_mcast_context_t;
+
+
+struct mcast_coll_comm { /* per-communicator state; no fields defined yet */
+};
+
+typedef struct mcast_coll_comm mcast_coll_comm_t;
+
+typedef struct ucc_tl_mlx5_mcast_team {
+    void *mcast_comm;
+} ucc_tl_mlx5_mcast_team_t;
+
+typedef struct mcast_coll_req { /* Stuff that has to happen per call */
+} mcast_coll_req_t;
+
+#define TASK_TEAM_MCAST(_task) \
+    (ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t))
+#define TASK_CTX_MCAST(_task) \
+    (ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_mcast_context_t))
+#define TASK_LIB_MCAST(_task) \
+    (ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_mcast_lib_t))
+#define TASK_ARGS_MCAST(_task) (_task)->super.bargs.args
+
+ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
+                                         ucc_tl_mlx5_mcast_team_t **mcast_team,
+                                         ucc_tl_mlx5_mcast_context_t *ctx,
+                                         const ucc_base_team_params_t *params,
+                                         mcast_coll_comm_init_spec_t *mcast_conf);
+
+ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args,
+                                         ucc_base_team_t *team,
+                                         ucc_coll_task_t **task_h);
+
+ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *mcast_ctx,
+                                            mcast_ctx_params_t *mcast_ctx_conf);
+
+#endif
diff --git a/src/components/tl/mlx5/mcast/tl_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mcast_coll.c
new file mode 100644
index 0000000000..d9bc5503b7
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mcast_coll.c
@@ -0,0 +1,101 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "tl_mlx5_coll.h"
+
+/* Stub: testing a request is not implemented yet. */
+ucc_status_t mcast_test(mcast_coll_req_t* req)
+{
+    return UCC_ERR_NOT_SUPPORTED;
+}
+
+/* Stub: the mcast bcast itself is not implemented yet. */
+ucc_status_t mcast_coll_do_bcast(void* buf, int size, int root, void *mr,
+                                 mcast_coll_comm_t *comm, bool is_blocking, void **task_req_handle)
+{
+    return UCC_ERR_NOT_SUPPORTED;
+}
+
+/**
+ * Post routine of the mcast bcast task. Passes the source buffer and
+ * ucc_dt_size(dt) * count as its size to mcast_coll_do_bcast(); on UCC_OK or
+ * UCC_INPROGRESS the task is put on the progress queue, any other status is
+ * logged, stored in the task and the task is completed right away.
+ *
+ * NOTE(review): data_size is a size_t while mcast_coll_do_bcast() takes an
+ * int size, so the value is narrowed at the call - confirm the intended limit.
+ */
+ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_mlx5_task_t       *task      = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    ucc_tl_mlx5_team_t       *mlx5_team = TASK_TEAM(task);
+    ucc_tl_mlx5_mcast_team_t *team      = mlx5_team->mcast;
+    ucc_coll_args_t          *args      = &TASK_ARGS_MCAST(task);
+    ucc_datatype_t            dt        = args->src.info.datatype;
+    size_t                    count     = args->src.info.count;
+    ucc_rank_t                root      = args->root;
+    ucc_status_t              status    = UCC_OK;
+    size_t                    data_size = ucc_dt_size(dt) * count;
+    void                     *buf       = args->src.info.buffer;
+    mcast_coll_comm_t        *comm      = team->mcast_comm;
+    task->req_handle = NULL;
+
+    status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm,
+                                 UCC_TL_MLX5_MCAST_ENABLE_BLOCKING, &task->req_handle);
+    if (UCC_OK != status && UCC_INPROGRESS != status) {
+        tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status);
+        coll_task->status = status;
+        return ucc_task_complete(coll_task);
+    }
+
+    coll_task->status = status;
+
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
+}
+
+/**
+ * Progress routine of mcast tasks: polls the outstanding request, if any.
+ * The task status is left untouched while mcast_test() reports a non-final
+ * (non-negative, not UCC_OK) status. UCC_OK or an error (negative status) is
+ * stored in the task and the request is released, so that a failed request
+ * ends the task instead of being polled forever.
+ */
+void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    mcast_coll_req_t   *req  = task->req_handle;
+    ucc_status_t        status;
+
+    if (req == NULL) {
+        /* nothing outstanding: keep the status set by the post routine */
+        return;
+    }
+
+    status = mcast_test(req);
+    if (UCC_OK != status && status >= 0) {
+        /* not finished yet: poll again on the next progress call */
+        return;
+    }
+
+    if (ucc_unlikely(status < 0)) {
+        tl_error(UCC_TASK_LIB(task), "mcast_test failed:%d", status);
+    }
+
+    /* NOTE(review): assumes a failed request is released the same way as a
+       completed one - confirm once mcast_test() is implemented */
+    coll_task->status = status;
+    ucc_free(req);
+    task->req_handle = NULL;
+}
+
+/* Installs the mcast bcast post and progress callbacks on the task. */
+ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
+{
+    task->super.post     = ucc_tl_mlx5_mcast_bcast_start;
+    task->super.progress = ucc_tl_mlx5_mcast_collective_progress;
+
+    return UCC_OK;
+}
diff --git a/src/components/tl/mlx5/mcast/tl_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mcast_coll.h
new file mode 100644
index 0000000000..cf20f94082
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mcast_coll.h
@@ -0,0 +1,17 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#ifndef UCC_TL_MLX5_MCAST_COLL_H_
+#define UCC_TL_MLX5_MCAST_COLL_H_
+
+#include "tl_mcast.h"
+#include "tl_mlx5_coll.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);
+
+ucc_status_t mcast_test(mcast_coll_req_t* _req);
+
+#endif
diff --git a/src/components/tl/mlx5/mcast/tl_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mcast_context.c
new file mode 100644
index 0000000000..5336d0cf52
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mcast_context.c
@@ -0,0 +1,23 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include
+#include "tl_mcast.h"
+#include "utils/arch/cpu.h"
+#include
+#include "src/core/ucc_service_coll.h"
+#include "tl_mlx5.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_runtime_progress(void *mcast_coll_context)
+{
+    return UCC_ERR_NOT_SUPPORTED; /* stub: not implemented yet */
+}
+
+ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *context,
+                                            mcast_ctx_params_t *mcast_ctx_conf)
+{
+    return UCC_ERR_NOT_SUPPORTED; /* stub: not implemented yet */
+}
diff --git a/src/components/tl/mlx5/mcast/tl_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mcast_team.c
new file mode 100644
index 0000000000..c51d4017e7
--- /dev/null
+++ b/src/components/tl/mlx5/mcast/tl_mcast_team.c
@@ -0,0 +1,20 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+
+#include "tl_mlx5.h"
+#include "tl_mcast_coll.h"
+#include "coll_score/ucc_coll_score.h"
+
+ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
+                                         ucc_tl_mlx5_mcast_team_t **mcast_team,
+                                         ucc_tl_mlx5_mcast_context_t *ctx, const
+                                         ucc_base_team_params_t *params, mcast_coll_comm_init_spec_t
+                                         *mcast_conf)
+{
+    return UCC_ERR_NOT_SUPPORTED; /* stub: mcast team setup is not implemented yet */
+}
+
diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h
index e129f28626..7b328b91e9 100644
--- a/src/components/tl/mlx5/tl_mlx5.h
+++ b/src/components/tl/mlx5/tl_mlx5.h
@@ -14,9 +14,10 @@
 #include
 #include
 #include "utils/arch/cpu.h"
+#include "mcast/tl_mcast.h"
 
 #ifndef UCC_TL_MLX5_DEFAULT_SCORE
-#define UCC_TL_MLX5_DEFAULT_SCORE 1
+#define UCC_TL_MLX5_DEFAULT_SCORE (1 + UCC_TL_MLX5_MCAST_DEFAULT_SCORE) /* previous base score of 1 plus the mcast score */
 #endif
 
 #ifdef HAVE_PROFILING_TL_MLX5
@@ -56,11 +57,13 @@ typedef struct ucc_tl_mlx5_lib_config {
     unsigned long            dm_buf_num;
     int                      dm_host;
     ucc_tl_mlx5_ib_qp_conf_t qp_conf;
+    mcast_coll_comm_init_spec_t mcast_conf; /* handed to ucc_tl_mlx5_mcast_team_init() */
 } ucc_tl_mlx5_lib_config_t;
 
 typedef struct ucc_tl_mlx5_context_config {
     ucc_tl_context_config_t  super;
     ucs_config_names_array_t devices;
+    mcast_ctx_params_t       mcast_ctx_conf; /* handed to ucc_tl_mlx5_mcast_context_init() */
 } ucc_tl_mlx5_context_config_t;
 
 typedef struct ucc_tl_mlx5_lib {
@@ -80,24 +83,26 @@ typedef struct ucc_tl_mlx5_context {
     int                      is_imported;
     int                      ib_port;
     ucc_mpool_t              req_mp;
+    ucc_tl_mlx5_mcast_context_t mcast; /* embedded mcast context, set up in the context init */
 } ucc_tl_mlx5_context_t;
 UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t *,
                   const ucc_base_config_t *);
 
 typedef struct ucc_tl_mlx5_a2a ucc_tl_mlx5_a2a_t;
 typedef struct ucc_tl_mlx5_team {
-    ucc_tl_team_t           super;
-    ucc_service_coll_req_t *scoll_req;
-    void *                  oob_req;
-    ucc_mpool_t             dm_pool;
-    struct ibv_dm *         dm_ptr;
-    struct ibv_mr *         dm_mr;
-    ucc_tl_mlx5_a2a_t *     a2a;
+    ucc_tl_team_t             super;
+    ucc_service_coll_req_t   *scoll_req;
+    void                     *oob_req;
+    ucc_mpool_t               dm_pool;
+    struct ibv_dm            *dm_ptr;
+    struct ibv_mr            *dm_mr;
+    ucc_tl_mlx5_a2a_t        *a2a;
+    ucc_tl_mlx5_mcast_team_t *mcast; /* read by ucc_tl_mlx5_mcast_bcast_start() to reach mcast_comm */
 } ucc_tl_mlx5_team_t;
 UCC_CLASS_DECLARE(ucc_tl_mlx5_team_t, ucc_base_context_t *,
                   const ucc_base_team_params_t *);
 
-#define UCC_TL_MLX5_SUPPORTED_COLLS (UCC_COLL_TYPE_ALLTOALL)
+#define UCC_TL_MLX5_SUPPORTED_COLLS (UCC_COLL_TYPE_ALLTOALL | UCC_TL_MLX5_MCAST_SUPPORTED_COLLS)
 
 #define UCC_TL_MLX5_TEAM_LIB(_team) \
     (ucc_derived_of((_team)->super.super.context->lib, ucc_tl_mlx5_lib_t))
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c
new file mode 100644
index 0000000000..c36d7545c7
--- /dev/null
+++ b/src/components/tl/mlx5/tl_mlx5_coll.c
@@ -0,0 +1,52 @@
+/**
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "tl_mlx5_coll.h"
+#include "mcast/tl_mcast_coll.h"
+
+static ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task) /* installed as task->super.finalize below */
+{
+    ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task);
+    UCC_TL_MLX5_PROFILE_REQUEST_FREE(task);
+    ucc_mpool_put(task); /* hands the task object back via ucc_mpool_put() */
+    return UCC_OK;
+}
+
+ucc_status_t ucc_tl_mlx5_bcast_init(ucc_base_coll_args_t *coll_args, /* called for UCC_COLL_TYPE_BCAST by ucc_tl_mlx5_coll_init() */
+                                    ucc_base_team_t *     team,
+                                    ucc_coll_task_t **    task_h)
+{
+    ucc_status_t        status = UCC_OK;
+    ucc_tl_mlx5_task_t *task   = NULL;
+
+    if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) { /* active-set bcast is rejected */
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
+    task = ucc_tl_mlx5_get_task(coll_args, team);
+
+    if (ucc_unlikely(!task)) {
+        return UCC_ERR_NO_MEMORY;
+    }
+
+    task->super.finalize = ucc_tl_mlx5_task_finalize;
+
+    status = ucc_tl_mlx5_mcast_bcast_init(task); /* sets the post and progress callbacks */
+    if (ucc_unlikely(UCC_OK != status)) {
+        goto free_task;
+    }
+
+    *task_h = &(task->super);
+
+    tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);
+
+    return UCC_OK;
+
+free_task: /* init failed: give the task back and report the error; *task_h is not written */
+    ucc_mpool_put(task);
+    return status;
+}
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h
index 5505ce2024..c93875f0f1 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.h
+++ b/src/components/tl/mlx5/tl_mlx5_coll.h
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
@@ -12,6 +12,7 @@
 
 typedef struct ucc_tl_mlx5_task {
     ucc_coll_task_t super;
+    void           *req_handle; /* filled by mcast_coll_do_bcast(), polled and freed by the mcast progress routine */
 } ucc_tl_mlx5_task_t;
 
 typedef struct ucc_tl_mlx5_schedule {
@@ -23,15 +24,15 @@ typedef struct ucc_tl_mlx5_schedule {
 } ucc_tl_mlx5_schedule_t;
 
 #define TASK_TEAM(_task) \
-    (ucc_derived_of((_task)->super.super.team, ucc_tl_mlx5_team_t))
+    (ucc_derived_of((_task)->super.team, ucc_tl_mlx5_team_t))
 
 #define TASK_CTX(_task) \
-    (ucc_derived_of((_task)->super.super.team->context, ucc_tl_mlx5_context_t))
+    (ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_context_t))
 
 #define TASK_LIB(_task) \
-    (ucc_derived_of((_task)->super.super.team->context->lib, ucc_tl_mlx5_lib_t))
+    (ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_lib_t))
 
-#define TASK_ARGS(_task) (_task)->super.super.bargs.args
+#define TASK_ARGS(_task) (_task)->super.bargs.args
 
 #define TASK_SCHEDULE(_task) \
     (ucc_derived_of((_task)->schedule, ucc_tl_mlx5_schedule_t))
@@ -60,6 +61,9 @@ ucc_tl_mlx5_get_schedule(ucc_tl_mlx5_team_t * team,
 {
     ucc_tl_mlx5_context_t * ctx      = UCC_TL_MLX5_TEAM_CTX(team);
     ucc_tl_mlx5_schedule_t *schedule = ucc_mpool_get(&ctx->req_mp);
+    if (ucc_unlikely(!schedule)) { /* pool returned nothing: report it to the caller as NULL */
+        return NULL;
+    }
 
     UCC_TL_MLX5_PROFILE_REQUEST_NEW(schedule, "tl_mlx5_sched", 0);
     ucc_schedule_init(&schedule->super, coll_args, &team->super.super);
@@ -72,4 +76,8 @@ static inline void ucc_tl_mlx5_put_schedule(ucc_tl_mlx5_schedule_t *schedule)
     ucc_mpool_put(schedule);
 }
 
+ucc_status_t ucc_tl_mlx5_bcast_init(ucc_base_coll_args_t *coll_args,
+                                    ucc_base_team_t *     team,
+                                    ucc_coll_task_t **    task_h);
+
 #endif
diff --git a/src/components/tl/mlx5/tl_mlx5_context.c b/src/components/tl/mlx5/tl_mlx5_context.c
index df33210ea8..d4c9be995e 100644
--- a/src/components/tl/mlx5/tl_mlx5_context.c
+++ b/src/components/tl/mlx5/tl_mlx5_context.c
@@ -33,7 +33,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t,
     status = ucc_mpool_init(
         &self->req_mp, 0,
         ucc_max(sizeof(ucc_tl_mlx5_task_t), sizeof(ucc_tl_mlx5_schedule_t)), 0,
-        UCC_CACHE_LINE_SIZE, 8, UINT_MAX, NULL, params->thread_mode,
+        UCC_CACHE_LINE_SIZE, 8, UINT_MAX, &ucc_coll_task_mpool_ops, params->thread_mode,
         "tl_mlx5_req_mp");
     if (UCC_OK != status) {
         tl_error(self->super.super.lib,
@@ -41,6 +41,13 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_context_t,
         return status;
     }
 
+    status = ucc_tl_mlx5_mcast_context_init(&(self->mcast), &(self->cfg.mcast_ctx_conf));
+    if (UCC_OK != status) { /* NOTE(review): the mcast context init is currently a stub that always returns UCC_ERR_NOT_SUPPORTED, so this branch is always taken - confirm */
+        tl_error(self->super.super.lib,
+                 "failed to initialize mcast context");
+        return status; /* NOTE(review): req_mp initialized above is not cleaned up on this path - confirm */
+    }
+
     tl_debug(self->super.super.lib, "initialized tl context: %p", self);
     return UCC_OK;
 }
diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c
index c1723b52ac..3631b4ed2b 100644
--- a/src/components/tl/mlx5/tl_mlx5_team.c
+++ b/src/components/tl/mlx5/tl_mlx5_team.c
@@ -4,10 +4,12 @@
  * See file LICENSE for terms.
  */
 
+#include "tl_mlx5_coll.h"
 #include "tl_mlx5.h"
 #include "coll_score/ucc_coll_score.h"
 #include "core/ucc_team.h"
 #include
+#include "mcast/tl_mcast.h"
 
 UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context,
                     const ucc_base_team_params_t *params)
@@ -20,6 +22,15 @@ UCC_CLASS_INIT_FUNC(ucc_tl_mlx5_team_t, ucc_base_context_t *tl_context,
 
     self->a2a    = NULL;
     self->dm_ptr = NULL;
+    self->mcast  = NULL;
+
+    status = ucc_tl_mlx5_mcast_team_init(tl_context, &(self->mcast), &(ctx->mcast), params,
+                                         &(UCC_TL_MLX5_TEAM_LIB(self)->cfg.mcast_conf));
+    if (ucc_unlikely(UCC_OK != status)) { /* NOTE(review): the mcast team init is currently a stub that always returns UCC_ERR_NOT_SUPPORTED, so team creation fails here - confirm */
+        return status;
+    }
+
+    tl_debug(self->super.super.context->lib, "initialized tl mlx5 team: %p", self);
     return status;
 }
 
@@ -42,44 +53,60 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *tl_team) /* NOLINT */
     return UCC_OK;
 }
 
+ucc_status_t ucc_tl_mlx5_coll_init(ucc_base_coll_args_t *coll_args, /* init hook registered with the score map in ucc_tl_mlx5_team_get_scores() */
+                                   ucc_base_team_t *     team,
+                                   ucc_coll_task_t **    task_h)
+{
+    ucc_status_t status = UCC_OK;
+
+    switch (coll_args->args.coll_type)
+    {
+    case UCC_COLL_TYPE_BCAST:
+        status = ucc_tl_mlx5_bcast_init(coll_args, team, task_h);
+        break;
+    default: /* NOTE(review): UCC_TL_MLX5_SUPPORTED_COLLS also advertises ALLTOALL, which ends up here - confirm */
+        status = UCC_ERR_NOT_SUPPORTED;
+    }
+
+    return status;
+}
+
 ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t *  tl_team,
                                          ucc_coll_score_t **score_p)
 {
     ucc_tl_mlx5_team_t *team = ucc_derived_of(tl_team, ucc_tl_mlx5_team_t);
     ucc_base_context_t *ctx  = UCC_TL_TEAM_CTX(team);
-    ucc_base_lib_t *    lib  = UCC_TL_TEAM_LIB(team);
-    ucc_memory_type_t   mt   = UCC_MEMORY_TYPE_HOST;
     ucc_coll_score_t *  score;
     ucc_status_t        status;
 
-    status = ucc_coll_score_alloc(&score);
+    /* There can be a different logic for different coll_type/mem_type.
+       Right now just init everything the same way. */
+    status =
+        ucc_coll_score_build_default(tl_team, UCC_TL_MLX5_DEFAULT_SCORE,
+                                     ucc_tl_mlx5_coll_init,
+                                     UCC_TL_MLX5_SUPPORTED_COLLS,
+                                     NULL, 0, &score);
     if (UCC_OK != status) {
-        tl_error(lib, "failed to alloc score_t");
         return status;
     }
 
     if (strlen(ctx->score_str) > 0) {
         status = ucc_coll_score_update_from_str(
-            ctx->score_str, score, UCC_TL_TEAM_SIZE(team), NULL, tl_team,
-            UCC_TL_MLX5_DEFAULT_SCORE, NULL, &mt, 1);
-
+            ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
+            ucc_tl_mlx5_coll_init, &team->super.super,
+            UCC_TL_MLX5_DEFAULT_SCORE, NULL, NULL, 0);
         /* If INVALID_PARAM - User provided incorrect input - try to proceed */
         if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
             (status != UCC_ERR_NOT_SUPPORTED)) {
             goto err;
         }
     }
+
     *score_p = score;
     return UCC_OK;
+
 err:
     ucc_coll_score_free(score);
     *score_p = NULL;
     return status;
 }
-
-ucc_status_t ucc_tl_mlx5_coll_init(ucc_base_coll_args_t *coll_args, /* NOLINT */
-                                   ucc_base_team_t *     team,      /* NOLINT */
-                                   ucc_coll_task_t **    task)      /* NOLINT */
-{
-    return UCC_OK;
-}