-
Notifications
You must be signed in to change notification settings - Fork 102
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
350 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/** | ||
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See file LICENSE for terms. | ||
*/ | ||
|
||
#ifndef UCC_MCAST_H | ||
#define UCC_MCAST_H | ||
|
||
#include <infiniband/ib.h> | ||
#include <infiniband/umad.h> | ||
#include <infiniband/verbs.h> | ||
#include <rdma/rdma_verbs.h> | ||
#include "utils/ucc_list.h" | ||
#include "utils/ucc_mpool.h" | ||
#include "components/tl/ucc_tl.h" | ||
#include "components/tl/ucc_tl_log.h" | ||
#include "utils/ucc_rcache.h" | ||
|
||
|
||
#define UCC_TL_MLX5_MCAST_ENABLE_BLOCKING true | ||
|
||
typedef struct mcast_coll_comm_init_spec { | ||
} mcast_coll_comm_init_spec_t; | ||
|
||
typedef struct ucc_tl_mlx5_mcast_lib { | ||
} ucc_tl_mlx5_mcast_lib_t; | ||
UCC_CLASS_DECLARE(ucc_tl_mlx5_mcast_lib_t, const ucc_base_lib_params_t *, | ||
const ucc_base_config_t *); | ||
|
||
typedef struct ucc_tl_mlx5_mcast_ctx_params { | ||
} ucc_tl_mlx5_mcast_ctx_params_t; | ||
|
||
typedef struct mcast_coll_context_t { | ||
} mcast_coll_context_t; | ||
|
||
typedef struct ucc_tl_mlx5_mcast_context_t { | ||
} ucc_tl_mlx5_mcast_context_t; | ||
|
||
|
||
typedef struct mcast_coll_comm { /* Stuff at a per-communicator sort of level */ | ||
} mcast_coll_comm_t; | ||
|
||
typedef struct ucc_tl_mlx5_mcast_team { | ||
void *mcast_comm; | ||
} ucc_tl_mlx5_mcast_team_t; | ||
|
||
typedef struct ucc_tl_mlx5_mcast_coll_req { /* Stuff that has to happen per call */ | ||
} ucc_tl_mlx5_mcast_coll_req_t; | ||
|
||
#define TASK_TEAM_MCAST(_task) \ | ||
(ucc_derived_of((_task)->super.team, ucc_tl_mlx5_mcast_team_t)) | ||
#define TASK_CTX_MCAST(_task) \ | ||
(ucc_derived_of((_task)->super.team->context, ucc_tl_mlx5_mcast_context_t)) | ||
#define TASK_LIB_MCAST(_task) \ | ||
(ucc_derived_of((_task)->super.team->context->lib, ucc_tl_mlx5_mcast_lib_t)) | ||
#define TASK_ARGS_MCAST(_task) (_task)->super.bargs.args | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, | ||
ucc_tl_mlx5_mcast_team_t **mcast_team, | ||
ucc_tl_mlx5_mcast_context_t *ctx, | ||
const ucc_base_team_params_t *params, | ||
mcast_coll_comm_init_spec_t *mcast_conf); | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_coll_init(ucc_base_coll_args_t *coll_args, | ||
ucc_base_team_t *team, | ||
ucc_coll_task_t **task_h); | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *mcast_ctx, | ||
ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf); | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/** | ||
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See file LICENSE for terms. | ||
*/ | ||
|
||
#include "tl_mlx5_coll.h" | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req /* NOLINT */) | ||
{ | ||
return UCC_ERR_NOT_SUPPORTED; | ||
} | ||
|
||
ucc_status_t mcast_coll_do_bcast(void* buf, int size, int root, void *mr, /* NOLINT */ | ||
mcast_coll_comm_t *comm, /* NOLINT */ | ||
int is_blocking, /* NOLINT */ | ||
ucc_tl_mlx5_mcast_coll_req_t **task_req_handle /* NOLINT */) | ||
{ | ||
return UCC_ERR_NOT_SUPPORTED; | ||
} | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) | ||
{ | ||
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); | ||
ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task); | ||
ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast; | ||
ucc_coll_args_t *args = &TASK_ARGS_MCAST(task); | ||
ucc_datatype_t dt = args->src.info.datatype; | ||
size_t count = args->src.info.count; | ||
ucc_rank_t root = args->root; | ||
ucc_status_t status = UCC_OK; | ||
size_t data_size = ucc_dt_size(dt) * count; | ||
void *buf = args->src.info.buffer; | ||
mcast_coll_comm_t *comm = team->mcast_comm; | ||
|
||
task->bcast_mcast.req_handle = NULL; | ||
|
||
status = mcast_coll_do_bcast(buf, data_size, root, NULL, comm, | ||
UCC_TL_MLX5_MCAST_ENABLE_BLOCKING, &task->bcast_mcast.req_handle); | ||
if (UCC_OK != status && UCC_INPROGRESS != status) { | ||
tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status); | ||
coll_task->status = status; | ||
return ucc_task_complete(coll_task); | ||
} | ||
|
||
coll_task->status = status; | ||
|
||
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super); | ||
} | ||
|
||
void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task) | ||
{ | ||
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); | ||
ucc_status_t status = UCC_OK; | ||
ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; | ||
|
||
if (req != NULL) { | ||
status = ucc_tl_mlx5_mcast_test(req); | ||
if (UCC_OK == status) { | ||
coll_task->status = UCC_OK; | ||
ucc_free(req); | ||
task->bcast_mcast.req_handle = NULL; | ||
} | ||
} | ||
} | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) | ||
{ | ||
|
||
task->super.post = ucc_tl_mlx5_mcast_bcast_start; | ||
task->super.progress = ucc_tl_mlx5_mcast_collective_progress; | ||
|
||
return UCC_OK; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
/** | ||
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See file LICENSE for terms. | ||
*/ | ||
|
||
#ifndef UCC_TL_MLX5_MCAST_COLL_H_ | ||
#define UCC_TL_MLX5_MCAST_COLL_H_ | ||
|
||
#include "tl_mlx5_mcast.h" | ||
#include "tl_mlx5_coll.h" | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task); | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req); | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/** | ||
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See file LICENSE for terms. | ||
*/ | ||
|
||
#include <inttypes.h> | ||
#include "tl_mlx5_mcast.h" | ||
#include "utils/arch/cpu.h" | ||
#include <ucs/sys/string.h> | ||
#include "src/core/ucc_service_coll.h" | ||
#include "tl_mlx5.h" | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *context, /* NOLINT */ | ||
ucc_tl_mlx5_mcast_ctx_params_t *mcast_ctx_conf /* NOLINT */) | ||
{ | ||
return UCC_OK; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/** | ||
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See file LICENSE for terms. | ||
*/ | ||
|
||
|
||
#include "tl_mlx5.h" | ||
#include "tl_mlx5_mcast_coll.h" | ||
#include "coll_score/ucc_coll_score.h" | ||
|
||
ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, /* NOLINT */ | ||
ucc_tl_mlx5_mcast_team_t **mcast_team, /* NOLINT */ | ||
ucc_tl_mlx5_mcast_context_t *ctx, /* NOLINT */ | ||
const ucc_base_team_params_t *params, /* NOLINT */ | ||
mcast_coll_comm_init_spec_t *mcast_conf /* NOLINT */) | ||
{ | ||
return UCC_OK; | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/** | ||
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* | ||
* See file LICENSE for terms. | ||
*/ | ||
|
||
#include "tl_mlx5_coll.h" | ||
#include "mcast/tl_mlx5_mcast_coll.h" | ||
|
||
static ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task) | ||
{ | ||
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); | ||
tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task); | ||
UCC_TL_MLX5_PROFILE_REQUEST_FREE(task); | ||
ucc_mpool_put(task); | ||
return UCC_OK; | ||
} | ||
|
||
ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, | ||
ucc_base_team_t * team, | ||
ucc_coll_task_t ** task_h) | ||
{ | ||
ucc_status_t status = UCC_OK; | ||
ucc_tl_mlx5_task_t *task = NULL; | ||
|
||
if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) { | ||
return UCC_ERR_NOT_SUPPORTED; | ||
} | ||
|
||
task = ucc_tl_mlx5_get_task(coll_args, team); | ||
|
||
if (ucc_unlikely(!task)) { | ||
return UCC_ERR_NO_MEMORY; | ||
} | ||
|
||
task->super.finalize = ucc_tl_mlx5_task_finalize; | ||
|
||
status = ucc_tl_mlx5_mcast_bcast_init(task); | ||
if (ucc_unlikely(UCC_OK != status)) { | ||
goto free_task; | ||
} | ||
|
||
*task_h = &(task->super); | ||
|
||
tl_debug(UCC_TASK_LIB(task), "init coll task %p", task); | ||
|
||
return UCC_OK; | ||
|
||
free_task: | ||
ucc_mpool_put(task); | ||
return status; | ||
} |
Oops, something went wrong.