Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CORE: fix score update when only score given #779

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions src/coll_score/ucc_coll_score.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -956,7 +956,8 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score,
ucc_memory_type_t *mtypes,
int mt_n)
int mt_n,
uint64_t colls)
{
ucc_status_t status;
int i, j;
Expand All @@ -967,6 +968,9 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
}

for (i = 0; i < UCC_COLL_TYPE_NUM; i++) {
if (!(colls & UCS_BIT(i))) {
continue;
}
for (j = 0; j < mt_n; j++) {
mt = (mtypes == NULL) ? (ucc_memory_type_t)j : mtypes[j];
status = ucc_coll_score_update_one(
Expand All @@ -980,25 +984,28 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
return UCC_OK;
}

ucc_status_t ucc_coll_score_update_from_str(const char * str,
ucc_coll_score_t *score,
ucc_rank_t team_size,
ucc_base_coll_init_fn_t init,
ucc_base_team_t *team,
ucc_score_t def_score,
ucc_alg_id_to_init_fn_t alg_fn,
ucc_memory_type_t *mtypes,
int mt_n)
ucc_status_t
ucc_coll_score_update_from_str(const char *str,
const ucc_coll_score_team_info_t *info,
ucc_base_team_t *team,
ucc_coll_score_t *score)
{
ucc_status_t status;
ucc_coll_score_t *score_str;
status = ucc_coll_score_alloc_from_str(str, &score_str, team_size, init,
team, alg_fn);

status = ucc_coll_score_alloc_from_str(str, &score_str, info->size,
info->init, team, info->alg_fn);
if (UCC_OK != status) {
return status;
}
status = ucc_coll_score_update(score, score_str, def_score, mtypes, mt_n);

status = ucc_coll_score_update(score, score_str,
info->default_score,
info->supported_mem_types,
info->num_mem_types,
info->supported_colls);
ucc_coll_score_free(score_str);

return status;
}

Expand Down
57 changes: 33 additions & 24 deletions src/coll_score/ucc_coll_score.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -19,6 +19,30 @@

#define UCC_MSG_MAX UINT64_MAX

/* Callback that maps alg_id (int or str) to the "init" function.
This callback is provided by the component (CL/TL) that uses
ucc_coll_score_alloc_from_str.
Return values:
UCC_OK - input alg_id can be correctly mapped to the "init" fn
UCC_ERR_NOT_SUPPORTED - CL/TL doesn't allow changing algorithms ids for
the given coll_type, mem_type
UCC_ERR_INVALID_PARAM - incorrect value of alg_id is provided */
typedef ucc_status_t (*ucc_alg_id_to_init_fn_t)(int alg_id,
const char *alg_id_str,
ucc_coll_type_t coll_type,
ucc_memory_type_t mem_type,
ucc_base_coll_init_fn_t *init);

typedef struct ucc_coll_score_team_info {
ucc_score_t default_score;
ucc_rank_t size;
uint64_t supported_colls;
ucc_memory_type_t *supported_mem_types;
int num_mem_types;
ucc_base_coll_init_fn_t init;
ucc_alg_id_to_init_fn_t alg_fn;
} ucc_coll_score_team_info_t;

typedef struct ucc_coll_entry {
ucc_list_link_t list_elem;
ucc_score_t score;
Expand Down Expand Up @@ -67,19 +91,6 @@ ucc_status_t ucc_coll_score_merge(ucc_coll_score_t * score1,
ucc_coll_score_t * score2,
ucc_coll_score_t **rst, int free_inputs);

/* Callback that maps alg_id (int or str) to the "init" function.
This callback is provided by the component (CL/TL) that uses
ucc_coll_score_alloc_from_str.
Return values:
UCC_OK - input alg_id can be correctly mapped to the "init" fn
UCC_ERR_NOT_SUPPORTED - CL/TL does allow changing algorithms ids for
the given coll_type, mem_type
UCC_ERR_INVALID_PARAM - incorrect value of alg_id is provided */
typedef ucc_status_t (*ucc_alg_id_to_init_fn_t)(int alg_id,
const char * alg_id_str,
ucc_coll_type_t coll_type,
ucc_memory_type_t mem_type,
ucc_base_coll_init_fn_t *init);

/* Parses SCORE string (see ucc_base_iface.c for pattern description)
and initializes score data structure. team_size is used to filter
Expand Down Expand Up @@ -112,15 +123,12 @@ ucc_status_t ucc_coll_score_alloc_from_str(const char * str,
"init" is set to NULL. User provided ranges without alg_id will not
modify any existing "init" functions in that case and only change the
score of existing ranges*/
ucc_status_t ucc_coll_score_update_from_str(const char * str,
ucc_coll_score_t *score,
ucc_rank_t team_size,
ucc_base_coll_init_fn_t init,
ucc_base_team_t *team,
ucc_score_t def_score,
ucc_alg_id_to_init_fn_t alg_fn,
ucc_memory_type_t *mtypes,
int mt_n);

ucc_status_t
ucc_coll_score_update_from_str(const char *str,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once merged, needs to be updated in TL/SHM as well.

const ucc_coll_score_team_info_t *info,
ucc_base_team_t *team,
ucc_coll_score_t *score);

ucc_status_t ucc_coll_score_merge_in(ucc_coll_score_t **dst,
ucc_coll_score_t *src);
Expand Down Expand Up @@ -159,6 +167,7 @@ ucc_status_t ucc_coll_score_update(ucc_coll_score_t *score,
ucc_coll_score_t *update,
ucc_score_t default_score,
ucc_memory_type_t *mtypes,
int mt_n);
int mt_n,
uint64_t colls);

#endif
13 changes: 10 additions & 3 deletions src/components/cl/basic/cl_basic_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,24 @@ ucc_status_t ucc_cl_basic_team_get_scores(ucc_base_team_t *cl_team,
ucc_cl_basic_team_t *team = ucc_derived_of(cl_team, ucc_cl_basic_team_t);
ucc_base_context_t *ctx = UCC_CL_TEAM_CTX(team);
ucc_status_t status;
ucc_coll_score_team_info_t team_info;

status = ucc_coll_score_dup(team->score, score);
if (UCC_OK != status) {
return status;
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, *score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
UCC_CL_BASIC_DEFAULT_SCORE, NULL, NULL, 0);
team_info.alg_fn = NULL;
team_info.default_score = UCC_CL_BASIC_DEFAULT_SCORE;
team_info.init = NULL;
team_info.num_mem_types = 0;
team_info.supported_mem_types = NULL; /* all memory types supported*/
team_info.supported_colls = UCC_COLL_TYPE_ALL;
team_info.size = UCC_CL_TEAM_SIZE(team);

status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, *score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
29 changes: 22 additions & 7 deletions src/components/cl/hier/cl_hier_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,15 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_cl_hier_alg_id_to_init;
team_info.default_score = UCC_CL_HIER_DEFAULT_SCORE;
team_info.init = ucc_cl_hier_coll_init;
team_info.num_mem_types = 0;
team_info.supported_mem_types = NULL; /* all memory types supported*/
team_info.supported_colls = UCC_COLL_TYPE_ALL;
team_info.size = UCC_CL_TEAM_SIZE(team);

status = ucc_coll_score_alloc(&score);
if (UCC_OK != status) {
Expand Down Expand Up @@ -353,10 +362,14 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
}

for (i = 0; i < UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR; i++) {
// status = ucc_coll_score_update_from_str(
// ucc_cl_hier_default_alg_select_str[i], score,
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
// UCC_TL_TEAM_SIZE(team), ucc_cl_hier_coll_init, &team->super.super,
// UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0,
// UCC_COLL_TYPE_ALL);
status = ucc_coll_score_update_from_str(
ucc_cl_hier_default_alg_select_str[i], score,
UCC_TL_TEAM_SIZE(team), ucc_cl_hier_coll_init, &team->super.super,
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0);
ucc_cl_hier_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (UCC_OK != status) {
cl_error(lib, "failed to apply default coll select setting: %s",
ucc_cl_hier_default_alg_select_str[i]);
Expand All @@ -365,10 +378,12 @@ ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0);

// status = ucc_coll_score_update_from_str(
// ctx->score_str, score, UCC_CL_TEAM_SIZE(team), NULL, cl_team,
// UCC_CL_HIER_DEFAULT_SCORE, ucc_cl_hier_alg_id_to_init, NULL, 0,
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
// UCC_COLL_TYPE_ALL);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* if INVALID_PARAM - user provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
20 changes: 13 additions & 7 deletions src/components/tl/cuda/tl_cuda_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_tl_cuda_alg_id_to_init;
team_info.default_score = UCC_TL_CUDA_DEFAULT_SCORE;
team_info.init = ucc_tl_cuda_coll_init;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_CUDA_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);

status =
ucc_coll_score_build_default(tl_team, UCC_TL_CUDA_DEFAULT_SCORE,
Expand All @@ -339,9 +348,8 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,

for (i = 0; i < UCC_TL_CUDA_N_DEFAULT_ALG_SELECT_STR; i++) {
status = ucc_coll_score_update_from_str(
ucc_tl_cuda_default_alg_select_str[i], score,
UCC_TL_TEAM_SIZE(team), ucc_tl_cuda_coll_init, &team->super.super,
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1);
ucc_tl_cuda_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (UCC_OK != status) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -351,10 +359,8 @@ ucc_status_t ucc_tl_cuda_team_get_scores(ucc_base_team_t *tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_cuda_coll_init, &team->super.super,
UCC_TL_CUDA_DEFAULT_SCORE, ucc_tl_cuda_alg_id_to_init, &mt, 1);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
goto err;
Expand Down
15 changes: 11 additions & 4 deletions src/components/tl/mlx5/tl_mlx5_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
ucc_memory_type_t mt = UCC_MEMORY_TYPE_HOST;
ucc_coll_score_t * score;
ucc_status_t status;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = NULL;
team_info.default_score = UCC_TL_MLX5_DEFAULT_SCORE;
team_info.init = NULL;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_MLX5_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);

status = ucc_coll_score_alloc(&score);
if (UCC_OK != status) {
Expand All @@ -154,10 +163,8 @@ ucc_status_t ucc_tl_mlx5_team_get_scores(ucc_base_team_t * tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team), NULL, tl_team,
UCC_TL_MLX5_DEFAULT_SCORE, NULL, &mt, 1);

status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
19 changes: 12 additions & 7 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,15 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_tl_nccl_alg_id_to_init;
team_info.default_score = UCC_TL_NCCL_DEFAULT_SCORE;
team_info.init = ucc_tl_nccl_coll_init;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_NCCL_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);
/* There can be a different logic for different coll_type/mem_type.
Right now just init everything the same way. */
status =
Expand All @@ -220,9 +228,8 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,

for (i = 0; i < UCC_TL_NCCL_N_DEFAULT_ALG_SELECT_STR; i++) {
status = ucc_coll_score_update_from_str(
ucc_tl_nccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team),
ucc_tl_nccl_coll_init, &team->super.super, UCC_TL_NCCL_DEFAULT_SCORE,
ucc_tl_nccl_alg_id_to_init, &mt, 1);
ucc_tl_nccl_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -241,10 +248,8 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_nccl_coll_init, &team->super.super,
UCC_TL_NCCL_DEFAULT_SCORE, ucc_tl_nccl_alg_id_to_init, &mt, 1);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
19 changes: 12 additions & 7 deletions src/components/tl/rccl/tl_rccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,15 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,
ucc_coll_score_t *score;
ucc_status_t status;
int i;
ucc_coll_score_team_info_t team_info;

team_info.alg_fn = ucc_tl_rccl_alg_id_to_init;
team_info.default_score = UCC_TL_RCCL_DEFAULT_SCORE;
team_info.init = ucc_tl_rccl_coll_init;
team_info.num_mem_types = 1;
team_info.supported_mem_types = &mt;
team_info.supported_colls = UCC_TL_RCCL_SUPPORTED_COLLS;
team_info.size = UCC_TL_TEAM_SIZE(team);
/* There can be a different logic for different coll_type/mem_type.
Right now just init everything the same way. */
status =
Expand All @@ -214,9 +222,8 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,

for (i = 0; i < UCC_TL_RCCL_N_DEFAULT_ALG_SELECT_STR; i++) {
status = ucc_coll_score_update_from_str(
ucc_tl_rccl_default_alg_select_str[i], score, UCC_TL_TEAM_SIZE(team),
ucc_tl_rccl_coll_init, &team->super.super, UCC_TL_RCCL_DEFAULT_SCORE,
ucc_tl_rccl_alg_id_to_init, &mt, 1);
ucc_tl_rccl_default_alg_select_str[i], &team_info,
&team->super.super, score);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(tl_team->context->lib,
"failed to apply default coll select setting: %s",
Expand All @@ -235,10 +242,8 @@ ucc_status_t ucc_tl_rccl_team_get_scores(ucc_base_team_t *tl_team,
}

if (strlen(ctx->score_str) > 0) {
status = ucc_coll_score_update_from_str(
ctx->score_str, score, UCC_TL_TEAM_SIZE(team),
ucc_tl_rccl_coll_init, &team->super.super,
UCC_TL_RCCL_DEFAULT_SCORE, ucc_tl_rccl_alg_id_to_init, &mt, 1);
status = ucc_coll_score_update_from_str(ctx->score_str, &team_info,
&team->super.super, score);
/* If INVALID_PARAM - User provided incorrect input - try to proceed */
if ((status < 0) && (status != UCC_ERR_INVALID_PARAM) &&
(status != UCC_ERR_NOT_SUPPORTED)) {
Expand Down
Loading