Skip to content

Commit

Permalink
TL/NCCL: add cuda managed to score
Browse files Browse the repository at this point in the history
  • Loading branch information
shimmybalsam committed Jun 12, 2023
1 parent fac619e commit a03f2d9
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,10 @@ ucc_status_t ucc_tl_nccl_coll_init(ucc_base_coll_args_t *coll_args,
ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t **score_p)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_base_context_t *ctx = UCC_TL_TEAM_CTX(team);
ucc_memory_type_t mt = UCC_MEMORY_TYPE_CUDA;
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_base_context_t *ctx = UCC_TL_TEAM_CTX(team);
ucc_memory_type_t mts[2] = {UCC_MEMORY_TYPE_CUDA,
UCC_MEMORY_TYPE_CUDA_MANAGED};
ucc_coll_score_t *score;
ucc_status_t status;
int i;
Expand All @@ -212,16 +213,16 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
team_info.alg_fn = ucc_tl_nccl_alg_id_to_init;
team_info.default_score = UCC_TL_NCCL_DEFAULT_SCORE;
team_info.init = ucc_tl_nccl_coll_init;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.num_mem_types = 2;
team_info.supported_mem_types = mts;
team_info.supported_colls = UCC_TL_NCCL_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);
/* There can be a different logic for different coll_type/mem_type.
Right now just init everything the same way. */
status =
ucc_coll_score_build_default(tl_team, UCC_TL_NCCL_DEFAULT_SCORE,
ucc_tl_nccl_coll_init, UCC_TL_NCCL_SUPPORTED_COLLS,
&mt, 1, &score);
mts, 2, &score);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
Expand Down

0 comments on commit a03f2d9

Please sign in to comment.