diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c index b2be47c074..ea6040a146 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -201,9 +201,10 @@ ucc_status_t ucc_tl_nccl_coll_init(ucc_base_coll_args_t *coll_args, ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team, ucc_coll_score_t **score_p) { - ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t); - ucc_base_context_t *ctx = UCC_TL_TEAM_CTX(team); - ucc_memory_type_t mt = UCC_MEMORY_TYPE_CUDA; + ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t); + ucc_base_context_t *ctx = UCC_TL_TEAM_CTX(team); + ucc_memory_type_t mts[2] = {UCC_MEMORY_TYPE_CUDA, + UCC_MEMORY_TYPE_CUDA_MANAGED}; ucc_coll_score_t *score; ucc_status_t status; int i; @@ -212,8 +213,8 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team, team_info.alg_fn = ucc_tl_nccl_alg_id_to_init; team_info.default_score = UCC_TL_NCCL_DEFAULT_SCORE; team_info.init = ucc_tl_nccl_coll_init; - team_info.num_mem_types = 1; - team_info.supported_mem_types = &mt; + team_info.num_mem_types = 2; + team_info.supported_mem_types = mts; team_info.supported_colls = UCC_TL_NCCL_SUPPORTED_COLLS; team_info.size = UCC_TL_TEAM_SIZE(team); /* There can be a different logic for different coll_type/mem_type. @@ -221,7 +222,7 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team, status = ucc_coll_score_build_default(tl_team, UCC_TL_NCCL_DEFAULT_SCORE, ucc_tl_nccl_coll_init, UCC_TL_NCCL_SUPPORTED_COLLS, - &mt, 1, &score); + mts, 2, &score); if (ucc_unlikely(UCC_OK != status)) { return status; }