Skip to content

Commit

Permalink
CORE: fix score update when only score given
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Jun 1, 2023
1 parent 0d07b2f commit f4d0aae
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 120 deletions.
35 changes: 21 additions & 14 deletions src/coll_score/ucc_coll_score.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -956,7 +956,8 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score,
ucc_memory_type_t *mtypes,
int mt_n)
int mt_n,
uint64_t colls)
{
ucc_status_t status;
int i, j;
Expand All @@ -967,6 +968,9 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
}

for (i = 0; i < UCC_COLL_TYPE_NUM; i++) {
if (!(colls & UCS_BIT(i))) {
continue;
}
for (j = 0; j < mt_n; j++) {
mt = (mtypes == NULL) ? (ucc_memory_type_t)j : mtypes[j];
status = ucc_coll_score_update_one(
Expand All @@ -980,25 +984,28 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
return UCC_OK;
}

ucc_status_t ucc_coll_score_update_from_str(const char * str,
ucc_coll_score_t *score,
ucc_rank_t team_size,
ucc_base_coll_init_fn_t init,
ucc_base_team_t *team,
ucc_score_t def_score,
ucc_alg_id_to_init_fn_t alg_fn,
ucc_memory_type_t *mtypes,
int mt_n)
ucc_status_t
ucc_coll_score_update_from_str(const char *str,
const ucc_coll_score_team_info_t *info,
ucc_base_team_t *team,
ucc_coll_score_t *score)
{
ucc_status_t status;
ucc_coll_score_t *score_str;
status = ucc_coll_score_alloc_from_str(str, &score_str, team_size, init,
team, alg_fn);

status = ucc_coll_score_alloc_from_str(str, &score_str, info->size,
info->init, team, info->alg_fn);
if (UCC_OK != status) {
return status;
}
status = ucc_coll_score_update(score, score_str, def_score, mtypes, mt_n);

status = ucc_coll_score_update(score, score_str,
info->default_score,
info->supported_mem_types,
info->num_mem_types,
info->supported_colls);
ucc_coll_score_free(score_str);

return status;
}

Expand Down
57 changes: 33 additions & 24 deletions src/coll_score/ucc_coll_score.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -19,6 +19,30 @@

#define UCC_MSG_MAX UINT64_MAX

/* Callback that maps alg_id (int or str) to the "init" function.
This callback is provided by the component (CL/TL) that uses
ucc_coll_score_alloc_from_str.
Return values:
UCC_OK - input alg_id can be correctly mapped to the "init" fn
UCC_ERR_NOT_SUPPORTED - CL/TL doesn't allow changing algorithms ids for
the given coll_type, mem_type
UCC_ERR_INVALID_PARAM - incorrect value of alg_id is provided */
typedef ucc_status_t (*ucc_alg_id_to_init_fn_t)(int alg_id,
const char *alg_id_str,
ucc_coll_type_t coll_type,
ucc_memory_type_t mem_type,
ucc_base_coll_init_fn_t *init);

typedef struct ucc_coll_score_team_info {
ucc_score_t default_score;
ucc_rank_t size;
uint64_t supported_colls;
ucc_memory_type_t *supported_mem_types;
int num_mem_types;
ucc_base_coll_init_fn_t init;
ucc_alg_id_to_init_fn_t alg_fn;
} ucc_coll_score_team_info_t;

typedef struct ucc_coll_entry {
ucc_list_link_t list_elem;
ucc_score_t score;
Expand Down Expand Up @@ -67,19 +91,6 @@ ucc_status_t ucc_coll_score_merge(ucc_coll_score_t * score1,
ucc_coll_score_t * score2,
ucc_coll_score_t **rst, int free_inputs);

/* Callback that maps alg_id (int or str) to the "init" function.
This callback is provided by the component (CL/TL) that uses
ucc_coll_score_alloc_from_str.
Return values:
UCC_OK - input alg_id can be correctly mapped to the "init" fn
UCC_ERR_NOT_SUPPORTED - CL/TL does allow changing algorithms ids for
the given coll_type, mem_type
UCC_ERR_INVALID_PARAM - incorrect value of alg_id is provided */
typedef ucc_status_t (*ucc_alg_id_to_init_fn_t)(int alg_id,
const char * alg_id_str,
ucc_coll_type_t coll_type,
ucc_memory_type_t mem_type,
ucc_base_coll_init_fn_t *init);

/* Parses SCORE string (see ucc_base_iface.c for pattern description)
and initializes score data structure. team_size is used to filter
Expand Down Expand Up @@ -112,15 +123,12 @@ ucc_status_t ucc_coll_score_alloc_from_str(const char * str,
"init" is set to NULL. User provided ranges without alg_id will not
modify any existing "init" functions in that case and only change the
score of existing ranges*/
ucc_status_t ucc_coll_score_update_from_str(const char * str,
ucc_coll_score_t *score,
ucc_rank_t team_size,
ucc_base_coll_init_fn_t init,
ucc_base_team_t *team,
ucc_score_t def_score,
ucc_alg_id_to_init_fn_t alg_fn,
ucc_memory_type_t *mtypes,
int mt_n);

ucc_status_t
ucc_coll_score_update_from_str(const char *str,
const ucc_coll_score_team_info_t *info,
ucc_base_team_t *team,
ucc_coll_score_t *score);

ucc_status_t ucc_coll_score_merge_in(ucc_coll_score_t **dst,
ucc_coll_score_t *src);
Expand Down Expand Up @@ -159,6 +167,7 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score,
ucc_memory_type_t *mtypes,
int mt_n);
int mt_n,
uint64_t colls);

#endif
13 changes: 10 additions & 3 deletions src/components/cl/basic/cl_basic_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,24 @@ ucc_status_t ucc_cl_basic_team_get_scores(ucc_base_team_t *cl_team,
ucc_cl_basic_team_t *team = ucc_derived_of(cl_team, ucc_cl_basic_team_t);
ucc_base_context_t *ctx = UCC_CL_TEAM_CTX(team);
ucc_status_t status;
ucc_coll_score_team_info_t team_info;

status = ucc_coll_score_dup(team->score, score);
if (UCC_OK != status) {
return status;
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, *score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
UCC_CL_BASIC_DEFAULT_SCORE, NULL, NULL, 0);
team_info.alg_fn = NULL;
team_info.default_score = UCC_CL_BASIC_DEFAULT_SCORE;
team_info.init = NULL;
team_info.num_mem_types = 0;
team_info.supported_mem_types = NULL; /* all memory types supported*/
team_info.supported_colls = UCC_COLL_TYPE_ALL;
team_info.size = UCC_CL_TEAM_SIZE(team);

status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, *score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
29 changes: 22 additions & 7 deletions src/components/cl/hier/cl_hier_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,15 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_cl_hier_alg_id_to_init;
team_info.default_score = UCC_CL_HIER_DEFAULT_SCORE;
team_info.init = ucc_cl_hier_coll_init;
team_info.num_mem_types = 0;
team_info.supported_mem_types = NULL; /* all memory types supported*/
team_info.supported_colls = UCC_COLL_TYPE_ALL;
team_info.size = UCC_CL_TEAM_SIZE(team);

status = ucc_coll_score_alloc(&score);
if (UCC_OK != status) {
Expand Down Expand Up @@ -353,10 +362,14 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
}

for (i = 0; i < UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR; i++) {
// status = ucc_coll_score_update_from_str(
// ucc_cl_hier_default_alg_select_str[i], score,
// UCC_TL_TEAM_SIZE(team), ucc_cl_hier_coll_init, &team->super.super,
// UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0,
// UCC_COLL_TYPE_ALL);
status = ucc_coll_score_update_from_str(
ucc_cl_hier_default_alg_select_str[i], score,
UCC_TL_TEAM_SIZE(team), ucc_cl_hier_coll_init, &team->super.super,
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0);
ucc_cl_hier_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (UCC_OK != status) {
cl_error(lib, "failed to apply default coll select setting: %s",
ucc_cl_hier_default_alg_select_str[i]);
Expand All @@ -365,10 +378,12 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0);

// status = ucc_coll_score_update_from_str(
// ctx->score_str, score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
// UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0,
// UCC_COLL_TYPE_ALL);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* if INVALID_PARAM - user provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
20 changes: 13 additions & 7 deletions src/components/tl/cuda/tl_cuda_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_tl_cuda_alg_id_to_init;
team_info.default_score = UCC_TL_CUDA_DEFAULT_SCORE;
team_info.init = ucc_tl_cuda_coll_init;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_CUDA_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);

status =
ucc_coll_score_build_default(tl_team, UCC_TL_CUDA_DEFAULT_SCORE,
Expand All @@ -339,9 +348,8 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,

for (i = 0; i < UCC_TL_CUDA_N_DEFAULT_ALG_SELECT_STR; i++) {
status = ucc_coll_score_update_from_str(
ucc_tl_cuda_default_alg_select_str[i], score,
UCC_TL_TEAM_SIZE(team), ucc_tl_cuda_coll_init, &team->super.super,
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1);
ucc_tl_cuda_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (UCC_OK != status) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -351,10 +359,8 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_cuda_coll_init, &team->super.super,
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
goto err;
Expand Down
15 changes: 11 additions & 4 deletions src/components/tl/mlx5/tl_mlx5_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
ucc_memory_type_t mt = UCC_MEMORY_TYPE_HOST;
ucc_coll_score_t * score;
ucc_status_t status;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = NULL;
team_info.default_score = UCC_TL_MLX5_DEFAULT_SCORE;
team_info.init = NULL;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_MLX5_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);

status = ucc_coll_score_alloc(&score);
if (UCC_OK != status) {
Expand All @@ -59,10 +68,8 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team), NULL, tl_team,
UCC_TL_MLX5_DEFAULT_SCORE, NULL, &mt, 1);

status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
19 changes: 12 additions & 7 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,15 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_tl_nccl_alg_id_to_init;
team_info.default_score = UCC_TL_NCCL_DEFAULT_SCORE;
team_info.init = ucc_tl_nccl_coll_init;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_NCCL_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);
/* There can be a different logic for different coll_type/mem_type.
Right now just init everything the same way. */
status =
Expand All @@ -220,9 +228,8 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,

for (i = 0; i < UCC_TL_NCCL_N_DEFAULT_ALG_SELECT_STR; i++) {
status = ucc_coll_score_update_from_str(
ucc_tl_nccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team),
ucc_tl_nccl_coll_init, &team->super.super, UCC_TL_NCCL_DEFAULT_SCORE,
ucc_tl_nccl_alg_id_to_init, &mt, 1);
ucc_tl_nccl_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -241,10 +248,8 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_nccl_coll_init, &team->super.super,
UCC_TL_NCCL_DEFAULT_SCORE, ucc_tl_nccl_alg_id_to_init, &mt, 1);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
19 changes: 12 additions & 7 deletions src/components/tl/rccl/tl_rccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,15 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_tl_rccl_alg_id_to_init;
team_info.default_score = UCC_TL_RCCL_DEFAULT_SCORE;
team_info.init = ucc_tl_rccl_coll_init;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_RCCL_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);
/* There can be a different logic for different coll_type/mem_type.
Right now just init everything the same way. */
status =
Expand All @@ -214,9 +222,8 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,

for (i = 0; i < UCC_TL_RCCL_N_DEFAULT_ALG_SELECT_STR; i++) {
status = ucc_coll_score_update_from_str(
ucc_tl_rccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team),
ucc_tl_rccl_coll_init, &team->super.super, UCC_TL_RCCL_DEFAULT_SCORE,
ucc_tl_rccl_alg_id_to_init, &mt, 1);
ucc_tl_rccl_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -235,10 +242,8 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_rccl_coll_init, &team->super.super,
UCC_TL_RCCL_DEFAULT_SCORE, ucc_tl_rccl_alg_id_to_init, &mt, 1);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
Loading

0 comments on commit f4d0aae

Please sign in to comment.