From d06f9d65ebe00e9e16157eff50dc594910b3f408 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Wed, 17 May 2023 11:13:30 +0000 Subject: [PATCH] CORE: fix score update when only score given --- src/coll_score/ucc_coll_score.c | 35 +++++++------ src/coll_score/ucc_coll_score.h | 57 +++++++++++++--------- src/components/cl/basic/cl_basic_team.c | 13 +++-- src/components/cl/hier/cl_hier_team.c | 29 ++++++++--- src/components/tl/cuda/tl_cuda_team.c | 20 +++++--- src/components/tl/mlx5/tl_mlx5_team.c | 15 ++++-- src/components/tl/nccl/tl_nccl_team.c | 19 +++++--- src/components/tl/rccl/tl_rccl_team.c | 19 +++++--- src/components/tl/self/tl_self_team.c | 15 ++++-- src/components/tl/sharp/tl_sharp_team.c | 16 ++++-- src/components/tl/ucp/tl_ucp_team.c | 43 ++++++++-------- test/gtest/coll_score/test_score_update.cc | 45 +++++++++++------ 12 files changed, 206 insertions(+), 120 deletions(-) diff --git a/src/coll_score/ucc_coll_score.c b/src/coll_score/ucc_coll_score.c index 5c9761eeba..7cc4f90af3 100644 --- a/src/coll_score/ucc_coll_score.c +++ b/src/coll_score/ucc_coll_score.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -956,7 +956,8 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score, ucc_coll_score_t *update, ucc_score_t default_score, ucc_memory_type_t *mtypes, - int mt_n) + int mt_n, + uint64_t colls) { ucc_status_t status; int i, j; @@ -967,6 +968,9 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score, } for (i = 0; i < UCC_COLL_TYPE_NUM; i++) { + if (!(colls & UCS_BIT(i))) { + continue; + } for (j = 0; j < mt_n; j++) { mt = (mtypes == NULL) ? 
(ucc_memory_type_t)j : mtypes[j]; status = ucc_coll_score_update_one( @@ -980,25 +984,28 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score, return UCC_OK; } -ucc_status_t ucc_coll_score_update_from_str(const char * str, - ucc_coll_score_t *score, - ucc_rank_t team_size, - ucc_base_coll_init_fn_t init, - ucc_base_team_t *team, - ucc_score_t def_score, - ucc_alg_id_to_init_fn_t alg_fn, - ucc_memory_type_t *mtypes, - int mt_n) +ucc_status_t +ucc_coll_score_update_from_str(const char *str, + const ucc_coll_score_team_info_t *info, + ucc_base_team_t *team, + ucc_coll_score_t *score) { ucc_status_t status; ucc_coll_score_t *score_str; - status = ucc_coll_score_alloc_from_str(str, &score_str, team_size, init, - team, alg_fn); + + status = ucc_coll_score_alloc_from_str(str, &score_str, info->size, + info->init, team, info->alg_fn); if (UCC_OK != status) { return status; } - status = ucc_coll_score_update(score, score_str, def_score, mtypes, mt_n); + + status = ucc_coll_score_update(score, score_str, + info->default_score, + info->supported_mem_types, + info->num_mem_types, + info->supported_colls); ucc_coll_score_free(score_str); + return status; } diff --git a/src/coll_score/ucc_coll_score.h b/src/coll_score/ucc_coll_score.h index 672e873196..16f0ba0b74 100644 --- a/src/coll_score/ucc_coll_score.h +++ b/src/coll_score/ucc_coll_score.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -19,6 +19,30 @@ #define UCC_MSG_MAX UINT64_MAX +/* Callback that maps alg_id (int or str) to the "init" function. + This callback is provided by the component (CL/TL) that uses + ucc_coll_score_alloc_from_str. 
+ Return values: + UCC_OK - input alg_id can be correctly mapped to the "init" fn + UCC_ERR_NOT_SUPPORTED - CL/TL doesn't allow changing algorithms ids for + the given coll_type, mem_type + UCC_ERR_INVALID_PARAM - incorrect value of alg_id is provided */ +typedef ucc_status_t (*ucc_alg_id_to_init_fn_t)(int alg_id, + const char *alg_id_str, + ucc_coll_type_t coll_type, + ucc_memory_type_t mem_type, + ucc_base_coll_init_fn_t *init); + +typedef struct ucc_coll_score_team_info { + ucc_score_t default_score; + ucc_rank_t size; + uint64_t supported_colls; + ucc_memory_type_t *supported_mem_types; + int num_mem_types; + ucc_base_coll_init_fn_t init; + ucc_alg_id_to_init_fn_t alg_fn; +} ucc_coll_score_team_info_t; + typedef struct ucc_coll_entry { ucc_list_link_t list_elem; ucc_score_t score; @@ -67,19 +91,6 @@ ucc_status_t ucc_coll_score_merge(ucc_coll_score_t * score1, ucc_coll_score_t * score2, ucc_coll_score_t **rst, int free_inputs); -/* Callback that maps alg_id (int or str) to the "init" function. - This callback is provided by the component (CL/TL) that uses - ucc_coll_score_alloc_from_str. - Return values: - UCC_OK - input alg_id can be correctly mapped to the "init" fn - UCC_ERR_NOT_SUPPORTED - CL/TL does allow changing algorithms ids for - the given coll_type, mem_type - UCC_ERR_INVALID_PARAM - incorrect value of alg_id is provided */ -typedef ucc_status_t (*ucc_alg_id_to_init_fn_t)(int alg_id, - const char * alg_id_str, - ucc_coll_type_t coll_type, - ucc_memory_type_t mem_type, - ucc_base_coll_init_fn_t *init); /* Parses SCORE string (see ucc_base_iface.c for pattern description) and initializes score data structure. team_size is used to filter @@ -112,15 +123,12 @@ ucc_status_t ucc_coll_score_alloc_from_str(const char * str, "init" is set to NULL. 
User provided ranges without alg_id will not modify any existing "init" functions in that case and only change the score of existing ranges*/ -ucc_status_t ucc_coll_score_update_from_str(const char * str, - ucc_coll_score_t *score, - ucc_rank_t team_size, - ucc_base_coll_init_fn_t init, - ucc_base_team_t *team, - ucc_score_t def_score, - ucc_alg_id_to_init_fn_t alg_fn, - ucc_memory_type_t *mtypes, - int mt_n); + +ucc_status_t +ucc_coll_score_update_from_str(const char *str, + const ucc_coll_score_team_info_t *info, + ucc_base_team_t *team, + ucc_coll_score_t *score); ucc_status_t ucc_coll_score_merge_in(ucc_coll_score_t **dst, ucc_coll_score_t *src); @@ -159,6 +167,7 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score, ucc_coll_score_t *update, ucc_score_t default_score, ucc_memory_type_t *mtypes, - int mt_n); + int mt_n, + uint64_t colls); #endif diff --git a/src/components/cl/basic/cl_basic_team.c b/src/components/cl/basic/cl_basic_team.c index 9cad0956c7..04ca156c73 100644 --- a/src/components/cl/basic/cl_basic_team.c +++ b/src/components/cl/basic/cl_basic_team.c @@ -173,6 +173,7 @@ ucc_status_t ucc_cl_basic_team_get_scores(ucc_base_team_t *cl_team, ucc_cl_basic_team_t *team = ucc_derived_of(cl_team, ucc_cl_basic_team_t); ucc_base_context_t *ctx = UCC_CL_TEAM_CTX(team); ucc_status_t status; + ucc_coll_score_team_info_t team_info; status = ucc_coll_score_dup(team->score, score); if (UCC_OK != status) { @@ -180,10 +181,16 @@ ucc_status_t ucc_cl_basic_team_get_scores(ucc_base_team_t *cl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, *score, UCC_CL_TEAM_SIZE(team), NULL, cl_team, - UCC_CL_BASIC_DEFAULT_SCORE, NULL, NULL, 0); + team_info.alg_fn = NULL; + team_info.default_score = UCC_CL_BASIC_DEFAULT_SCORE; + team_info.init = NULL; + team_info.num_mem_types = 0; + team_info.supported_mem_types = NULL; /* all memory types supported*/ + team_info.supported_colls = UCC_COLL_TYPE_ALL; + team_info.size = 
UCC_CL_TEAM_SIZE(team); + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, *score); /* If INVALID_PARAM - User provided incorrect input - try to proceed */ if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { diff --git a/src/components/cl/hier/cl_hier_team.c b/src/components/cl/hier/cl_hier_team.c index c76dfebfb8..5abfebe489 100644 --- a/src/components/cl/hier/cl_hier_team.c +++ b/src/components/cl/hier/cl_hier_team.c @@ -315,6 +315,15 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team, ucc_coll_score_t *score; ucc_status_t status; int i; + ucc_coll_score_team_info_t team_info; + + team_info.alg_fn = ucc_cl_hier_alg_id_to_init; + team_info.default_score = UCC_CL_HIER_DEFAULT_SCORE; + team_info.init = ucc_cl_hier_coll_init; + team_info.num_mem_types = 0; + team_info.supported_mem_types = NULL; /* all memory types supported*/ + team_info.supported_colls = UCC_COLL_TYPE_ALL; + team_info.size = UCC_CL_TEAM_SIZE(team); status = ucc_coll_score_alloc(&score); if (UCC_OK != status) { @@ -353,10 +362,14 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team, } for (i = 0; i < UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR; i++) { + // status = ucc_coll_score_update_from_str( + // ucc_cl_hier_default_alg_select_str[i], score, + // UCC_TL_TEAM_SIZE(team), ucc_cl_hier_coll_init, &team->super.super, + // UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0, + // UCC_COLL_TYPE_ALL); status = ucc_coll_score_update_from_str( - ucc_cl_hier_default_alg_select_str[i], score, - UCC_TL_TEAM_SIZE(team), ucc_cl_hier_coll_init, &team->super.super, - UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0); + ucc_cl_hier_default_alg_select_str[i], &team_info, + &team->super.super, score); if (UCC_OK != status) { cl_error(lib, "failed to apply default coll select setting: %s", ucc_cl_hier_default_alg_select_str[i]); @@ -365,10 +378,12 @@ ucc_status_t 
ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_CL_TEAM_SIZE(team), NULL, cl_team, - UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0); - + // status = ucc_coll_score_update_from_str( + // ctx->score_str, score, UCC_CL_TEAM_SIZE(team), NULL, cl_team, + // UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0, + // UCC_COLL_TYPE_ALL); + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); /* if INVALID_PARAM - user provided incorrect input - try to proceed */ if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { diff --git a/src/components/tl/cuda/tl_cuda_team.c b/src/components/tl/cuda/tl_cuda_team.c index 4f0363f5f8..faa2ad89cf 100644 --- a/src/components/tl/cuda/tl_cuda_team.c +++ b/src/components/tl/cuda/tl_cuda_team.c @@ -327,6 +327,15 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team, ucc_coll_score_t *score; ucc_status_t status; int i; + ucc_coll_score_team_info_t team_info; + + team_info.alg_fn = ucc_tl_cuda_alg_id_to_init; + team_info.default_score = UCC_TL_CUDA_DEFAULT_SCORE; + team_info.init = ucc_tl_cuda_coll_init; + team_info.num_mem_types = 1; + team_info.supported_mem_types = &mt; + team_info.supported_colls = UCC_TL_CUDA_SUPPORTED_COLLS; + team_info.size = UCC_TL_TEAM_SIZE(team); status = ucc_coll_score_build_default(tl_team, UCC_TL_CUDA_DEFAULT_SCORE, @@ -339,9 +348,8 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team, for (i = 0; i < UCC_TL_CUDA_N_DEFAULT_ALG_SELECT_STR; i++) { status = ucc_coll_score_update_from_str( - ucc_tl_cuda_default_alg_select_str[i], score, - UCC_TL_TEAM_SIZE(team), ucc_tl_cuda_coll_init, &team->super.super, - UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1); + ucc_tl_cuda_default_alg_select_str[i], &team_info, + &team->super.super, score); if (UCC_OK != 
status) { tl_error(tl_team->context->lib, "failed to apply default coll select setting: %s", @@ -351,10 +359,8 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_TL_TEAM_SIZE(team), - ucc_tl_cuda_coll_init, &team->super.super, - UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1); + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { goto err; diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 8c1c1dc439..bb109b45f2 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -146,6 +146,15 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team, ucc_memory_type_t mt = UCC_MEMORY_TYPE_HOST; ucc_coll_score_t * score; ucc_status_t status; + ucc_coll_score_team_info_t team_info; + + team_info.alg_fn = NULL; + team_info.default_score = UCC_TL_MLX5_DEFAULT_SCORE; + team_info.init = NULL; + team_info.num_mem_types = 1; + team_info.supported_mem_types = &mt; + team_info.supported_colls = UCC_TL_MLX5_SUPPORTED_COLLS; + team_info.size = UCC_TL_TEAM_SIZE(team); status = ucc_coll_score_alloc(&score); if (UCC_OK != status) { @@ -154,10 +163,8 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_TL_TEAM_SIZE(team), NULL, tl_team, - UCC_TL_MLX5_DEFAULT_SCORE, NULL, &mt, 1); - + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); /* If INVALID_PARAM - User provided incorrect input - try to proceed */ if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { diff --git a/src/components/tl/nccl/tl_nccl_team.c 
b/src/components/tl/nccl/tl_nccl_team.c index 686ec54c40..b2be47c074 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -207,7 +207,15 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team, ucc_coll_score_t *score; ucc_status_t status; int i; + ucc_coll_score_team_info_t team_info; + team_info.alg_fn = ucc_tl_nccl_alg_id_to_init; + team_info.default_score = UCC_TL_NCCL_DEFAULT_SCORE; + team_info.init = ucc_tl_nccl_coll_init; + team_info.num_mem_types = 1; + team_info.supported_mem_types = &mt; + team_info.supported_colls = UCC_TL_NCCL_SUPPORTED_COLLS; + team_info.size = UCC_TL_TEAM_SIZE(team); /* There can be a different logic for different coll_type/mem_type. Right now just init everything the same way. */ status = @@ -220,9 +228,8 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team, for (i = 0; i < UCC_TL_NCCL_N_DEFAULT_ALG_SELECT_STR; i++) { status = ucc_coll_score_update_from_str( - ucc_tl_nccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team), - ucc_tl_nccl_coll_init, &team->super.super, UCC_TL_NCCL_DEFAULT_SCORE, - ucc_tl_nccl_alg_id_to_init, &mt, 1); + ucc_tl_nccl_default_alg_select_str[i], &team_info, + &team->super.super, score); if (ucc_unlikely(UCC_OK != status)) { tl_error(tl_team->context->lib, "failed to apply default coll select setting: %s", @@ -241,10 +248,8 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_TL_TEAM_SIZE(team), - ucc_tl_nccl_coll_init, &team->super.super, - UCC_TL_NCCL_DEFAULT_SCORE, ucc_tl_nccl_alg_id_to_init, &mt, 1); + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); /* If INVALID_PARAM - User provided incorrect input - try to proceed */ if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { diff --git 
a/src/components/tl/rccl/tl_rccl_team.c b/src/components/tl/rccl/tl_rccl_team.c index f8df307157..bfdb994d5e 100644 --- a/src/components/tl/rccl/tl_rccl_team.c +++ b/src/components/tl/rccl/tl_rccl_team.c @@ -201,7 +201,15 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team, ucc_coll_score_t *score; ucc_status_t status; int i; + ucc_coll_score_team_info_t team_info; + team_info.alg_fn = ucc_tl_rccl_alg_id_to_init; + team_info.default_score = UCC_TL_RCCL_DEFAULT_SCORE; + team_info.init = ucc_tl_rccl_coll_init; + team_info.num_mem_types = 1; + team_info.supported_mem_types = &mt; + team_info.supported_colls = UCC_TL_RCCL_SUPPORTED_COLLS; + team_info.size = UCC_TL_TEAM_SIZE(team); /* There can be a different logic for different coll_type/mem_type. Right now just init everything the same way. */ status = @@ -214,9 +222,8 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team, for (i = 0; i < UCC_TL_RCCL_N_DEFAULT_ALG_SELECT_STR; i++) { status = ucc_coll_score_update_from_str( - ucc_tl_rccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team), - ucc_tl_rccl_coll_init, &team->super.super, UCC_TL_RCCL_DEFAULT_SCORE, - ucc_tl_rccl_alg_id_to_init, &mt, 1); + ucc_tl_rccl_default_alg_select_str[i], &team_info, + &team->super.super, score); if (ucc_unlikely(UCC_OK != status)) { tl_error(tl_team->context->lib, "failed to apply default coll select setting: %s", @@ -235,10 +242,8 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_TL_TEAM_SIZE(team), - ucc_tl_rccl_coll_init, &team->super.super, - UCC_TL_RCCL_DEFAULT_SCORE, ucc_tl_rccl_alg_id_to_init, &mt, 1); + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); /* If INVALID_PARAM - User provided incorrect input - try to proceed */ if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { 
diff --git a/src/components/tl/self/tl_self_team.c b/src/components/tl/self/tl_self_team.c index e384675b66..b29e19c811 100644 --- a/src/components/tl/self/tl_self_team.c +++ b/src/components/tl/self/tl_self_team.c @@ -52,11 +52,20 @@ ucc_status_t ucc_tl_self_team_get_scores(ucc_base_team_t *tl_team, ucc_memory_type_t mem_types[UCC_MEMORY_TYPE_LAST]; ucc_coll_score_t *score; ucc_status_t status; + ucc_coll_score_team_info_t team_info; for (i = 0; i < UCC_MEMORY_TYPE_LAST; i++) { mem_types[mt_n++] = (ucc_memory_type_t)i; } + team_info.alg_fn = NULL; + team_info.default_score = UCC_TL_SELF_DEFAULT_SCORE; + team_info.init = ucc_tl_self_coll_init; + team_info.num_mem_types = mt_n; + team_info.supported_mem_types = mem_types; /* all memory types supported*/ + team_info.supported_colls = UCC_TL_SELF_SUPPORTED_COLLS; + team_info.size = UCC_TL_TEAM_SIZE(team); + status = ucc_coll_score_build_default( tl_team, UCC_TL_SELF_DEFAULT_SCORE, ucc_tl_self_coll_init, UCC_TL_SELF_SUPPORTED_COLLS, mem_types, mt_n, &score); @@ -66,10 +75,8 @@ ucc_status_t ucc_tl_self_team_get_scores(ucc_base_team_t *tl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_TL_TEAM_SIZE(team), - ucc_tl_self_coll_init, &team->super.super, - UCC_TL_SELF_DEFAULT_SCORE, NULL, NULL, 0); + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { goto err; diff --git a/src/components/tl/sharp/tl_sharp_team.c b/src/components/tl/sharp/tl_sharp_team.c index 3d1190e62a..fe4a5875fb 100644 --- a/src/components/tl/sharp/tl_sharp_team.c +++ b/src/components/tl/sharp/tl_sharp_team.c @@ -260,7 +260,15 @@ ucc_status_t ucc_tl_sharp_team_get_scores(ucc_base_team_t *tl_team, ucc_base_context_t *ctx = UCC_TL_TEAM_CTX(team); ucc_coll_score_t *score; ucc_status_t status; - + ucc_coll_score_team_info_t team_info; + + team_info.alg_fn = 
NULL; + team_info.default_score = UCC_TL_SHARP_DEFAULT_SCORE; + team_info.init = ucc_tl_sharp_coll_init; + team_info.num_mem_types = 0; + team_info.supported_mem_types = NULL; /* all memory types supported*/ + team_info.supported_colls = UCC_TL_SHARP_SUPPORTED_COLLS; + team_info.size = UCC_TL_TEAM_SIZE(team); /* There can be a different logic for different coll_type/mem_type. Right now just init everything the same way. */ status = @@ -273,10 +281,8 @@ ucc_status_t ucc_tl_sharp_team_get_scores(ucc_base_team_t *tl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_TL_TEAM_SIZE(team), - ucc_tl_sharp_coll_init, &team->super.super, - UCC_TL_SHARP_DEFAULT_SCORE, NULL, NULL, 0); + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); /* If INVALID_PARAM - User provided incorrect input - try to proceed */ if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && (status != UCC_ERR_NOT_SUPPORTED)) { diff --git a/src/components/tl/ucp/tl_ucp_team.c b/src/components/tl/ucp/tl_ucp_team.c index 97e9ad4da3..012fea5012 100644 --- a/src/components/tl/ucp/tl_ucp_team.c +++ b/src/components/tl/ucp/tl_ucp_team.c @@ -207,6 +207,7 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team, unsigned i; char *ucc_tl_ucp_default_alg_select_str [UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR]; + ucc_coll_score_team_info_t team_info; for (i = 0; i < UCC_MEMORY_TYPE_LAST; i++) { if (tl_ctx->ucp_memory_types & UCC_BIT(ucc_memtype_to_ucs[i])) { @@ -217,6 +218,14 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team, } } + team_info.alg_fn = ucc_tl_ucp_alg_id_to_init; + team_info.default_score = UCC_TL_UCP_DEFAULT_SCORE; + team_info.init = ucc_tl_ucp_coll_init; + team_info.num_mem_types = mt_n; + team_info.supported_mem_types = mem_types; /* only UCP-enabled memory types */ + team_info.supported_colls = UCC_TL_UCP_SUPPORTED_COLLS; + team_info.size = UCC_TL_TEAM_SIZE(team); + /*
There can be a different logic for different coll_type/mem_type. Right now just init everything the same way. */ status = ucc_coll_score_build_default(tl_team, UCC_TL_UCP_DEFAULT_SCORE, @@ -232,9 +241,8 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team, } for (i = 0; i < UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR; i++) { status = ucc_coll_score_update_from_str( - ucc_tl_ucp_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team), - ucc_tl_ucp_coll_init, &team->super.super, UCC_TL_UCP_DEFAULT_SCORE, - ucc_tl_ucp_alg_id_to_init, mem_types, mt_n); + ucc_tl_ucp_default_alg_select_str[i], &team_info, + &team->super.super, score); if (UCC_OK != status) { tl_error(tl_team->context->lib, "failed to apply default coll select setting: %s", @@ -244,27 +252,16 @@ ucc_status_t ucc_tl_ucp_team_get_scores(ucc_base_team_t *tl_team, } if (strlen(ctx->score_str) > 0) { - status = ucc_coll_score_update_from_str( - ctx->score_str, score, UCC_TL_TEAM_SIZE(team), NULL, - &team->super.super, UCC_TL_UCP_DEFAULT_SCORE, - ucc_tl_ucp_alg_id_to_init, mem_types, mt_n); - - /* If INVALID_PARAM - User provided incorrect input - try to proceed */ - if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && - (status != UCC_ERR_NOT_SUPPORTED)) { - goto err; - } + status = ucc_coll_score_update_from_str(ctx->score_str, &team_info, + &team->super.super, score); } else if (strlen(team->tuning_str) > 0) { - status = ucc_coll_score_update_from_str( - team->tuning_str, score, UCC_TL_TEAM_SIZE(team), NULL, - &team->super.super, UCC_TL_UCP_DEFAULT_SCORE, - ucc_tl_ucp_alg_id_to_init, mem_types, mt_n); - - /* If INVALID_PARAM - User provided incorrect input - try to proceed */ - if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) && - (status != UCC_ERR_NOT_SUPPORTED)) { - goto err; - } + status = ucc_coll_score_update_from_str(team->tuning_str, &team_info, + &team->super.super, score); + } + /* If INVALID_PARAM - User provided incorrect input - try to proceed */ + if ((status < 0) && (status != 
UCC_ERR_INVALID_PARAM) && + (status != UCC_ERR_NOT_SUPPORTED)) { + goto err; } for (i = 0; i < plugins->n_components; i++) { diff --git a/test/gtest/coll_score/test_score_update.cc b/test/gtest/coll_score/test_score_update.cc index ed0cd64728..f82a1e19cb 100644 --- a/test/gtest/coll_score/test_score_update.cc +++ b/test/gtest/coll_score/test_score_update.cc @@ -34,7 +34,8 @@ UCC_TEST_F(test_score_update, non_overlap) UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(10, 20, 100), RANGE(30, 35, 1)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 10, 10), RANGE(40, 50, 5)}))); @@ -44,7 +45,8 @@ UCC_TEST_F(test_score_update, overlap_single) { init_score(score, RLIST({RANGE(0, 100, 10)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(50, 150, 100)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 50, 10), RANGE(50, 100, 100)}))); @@ -52,7 +54,8 @@ UCC_TEST_F(test_score_update, overlap_single) init_score(score, RLIST({RANGE(0, 100, 100)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(50, 150, 10)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 50, 100), RANGE(50, 100, 10)}))); @@ -62,7 +65,8 @@ UCC_TEST_F(test_score_update, inclusive) { init_score(score, RLIST({RANGE(0, 90, 100)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(30, 
60, 10)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 30, 100), RANGE(30, 60, 10), @@ -71,7 +75,8 @@ UCC_TEST_F(test_score_update, inclusive) init_score(score, RLIST({RANGE(0, 90, 10)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(30, 60, 100)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 30, 10), RANGE(30, 60, 100), @@ -82,7 +87,8 @@ UCC_TEST_F(test_score_update, same_start) { init_score(score, RLIST({RANGE(0, 100, 100)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(0, 50, 10)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 50, 10), RANGE(50, 100, 100)}))); @@ -90,7 +96,8 @@ UCC_TEST_F(test_score_update, same_start) init_score(score, RLIST({RANGE(0, 100, 10)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(0, 50, 100)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 50, 100), RANGE(50, 100, 10)}))); @@ -98,7 +105,8 @@ UCC_TEST_F(test_score_update, same_start) init_score(score, RLIST({RANGE(1, 100, 10)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(1, 50, 100)}), 
UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(1, 50, 100), RANGE(50, 100, 10)}))); @@ -110,7 +118,8 @@ UCC_TEST_F(test_score_update, 1_overlaps_many) init_score(update, RLIST({RANGE(10, 20, 10), RANGE(30, 40, 20), RANGE(60, 70, 30)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 10, 100), RANGE(10, 20, 10), @@ -124,7 +133,8 @@ UCC_TEST_F(test_score_update, 1_overlaps_many) update, RLIST({RANGE(10, 20, 100), RANGE(30, 40, 100), RANGE(60, 70, 5)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 10, 10), RANGE(10, 20, 100), @@ -137,7 +147,8 @@ UCC_TEST_F(test_score_update, same_score) { init_score(score, RLIST({RANGE(0, 100, 100)}), UCC_COLL_TYPE_BARRIER); init_score(update, RLIST({RANGE(100, 200, 100)}), UCC_COLL_TYPE_BARRIER); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 100, 100)}))); @@ -149,7 +160,8 @@ UCC_TEST_F(test_score_update, non_overlap_2) 0x1); init_score(update, RLIST({RANGE(300, 400, 100)}), UCC_COLL_TYPE_BARRIER, 0x2, 0x2); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, 
ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 100, 100), RANGE(300, 400, 100)}))); @@ -161,7 +173,8 @@ UCC_TEST_F(test_score_update, init_reset) 0x1); init_score(update, RLIST({RANGE(0, 100, 100)}), UCC_COLL_TYPE_BARRIER, 0x2, 0x2); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 100, 100)}))); @@ -173,7 +186,8 @@ UCC_TEST_F(test_score_update, init_reset) 0x1); init_score(update, RLIST({RANGE(50, 150, 50)}), UCC_COLL_TYPE_BARRIER, 0x2, 0x2); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 50, 100), RANGE(50, 100, 50), @@ -185,7 +199,8 @@ UCC_TEST_F(test_score_update, init_reset) 0x1); init_score(update, RLIST({RANGE(0, 100, 50)}), UCC_COLL_TYPE_BARRIER, 0x2, 0x2); - EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0)); + EXPECT_EQ(UCC_OK, ucc_coll_score_update(score, update, 0, NULL, 0, + UCC_COLL_TYPE_ALL)); EXPECT_EQ(UCC_OK, check_range(score, UCC_COLL_TYPE_BARRIER, UCC_MEMORY_TYPE_HOST, RLIST({RANGE(0, 50, 50), RANGE(50, 100, 50),