CL/HIER: fix allreduce rab pipeline #759
```diff
@@ -78,7 +78,7 @@ ucc_cl_hier_allreduce_rab_frag_setup(ucc_schedule_pipelined_t *schedule_p,
 
 static ucc_status_t
 ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
-                                        ucc_base_team_t * team,
+                                        ucc_base_team_t *team,
                                         ucc_schedule_t **sched_p, int n_frags)
 {
     ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
```
```diff
@@ -99,10 +99,13 @@ ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
     UCC_CHECK_GOTO(ucc_schedule_init(schedule, &args, team), out, status);
 
     if (n_frags > 1) {
-        args.max_frag_count =
-            ucc_buffer_block_count(args.args.dst.info.count, n_frags, 0);
-        args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
+        args.max_frag_count = ucc_buffer_block_count(args.args.dst.info.count,
+                                                     n_frags, 0);
+        args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
     }
+    ucc_assert(SBGP_ENABLED(cl_team, NODE) ||
+               SBGP_ENABLED(cl_team, NODE_LEADERS));
 
     if (SBGP_ENABLED(cl_team, NODE)) {
         ucc_assert(n_tasks == 0);
         if (cl_team->top_sbgp == UCC_HIER_SBGP_NODE) {
```
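Context for the `max_frag_count` hunk above: `ucc_buffer_block_count(count, n_frags, 0)` returns how many elements land in fragment 0 when the destination count is split across `n_frags` pipeline fragments; querying block 0 for the maximum suggests a balanced split that gives any remainder to the lowest-indexed blocks. A minimal sketch of such a split, with an invented name rather than the real implementation:

```c
#include <stddef.h>

/* Illustrative balanced partition (not the UCC implementation): spread
 * `total` elements over `n_blocks` as evenly as possible, giving the
 * remainder to the lowest-indexed blocks, so block 0 holds the maximum. */
static size_t toy_block_count(size_t total, int n_blocks, int block)
{
    size_t base = total / n_blocks;
    size_t rem  = total % n_blocks;
    return base + ((size_t)block < rem ? 1 : 0);
}
/* e.g. toy_block_count(10, 3, 0) == 4; blocks 1 and 2 get 3 each */
```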
```diff
@@ -143,18 +146,33 @@ ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
         n_tasks++;
     }
 
-    UCC_CHECK_GOTO(ucc_event_manager_subscribe(
-                       &schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
-                       ucc_task_start_handler),
-                   out, status);
+    /* subscription logic is different depending on top level schedule type
+     * being used
+     */
     UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[0]), out, status);
-    for (i = 1; i < n_tasks; i++) {
-        UCC_CHECK_GOTO(
-            ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED,
-                                        tasks[i], ucc_task_start_handler),
-            out, status);
-        UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out, status);
+    if (n_frags > 1) {
+        UCC_CHECK_GOTO(ucc_task_subscribe_dep(&schedule->super, tasks[0],
+                                              UCC_EVENT_SCHEDULE_STARTED),
+                       out, status);
+        for (i = 1; i < n_tasks; i++) {
+            UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out,
+                           status);
+            UCC_CHECK_GOTO(ucc_task_subscribe_dep(tasks[i-1], tasks[i],
+                                                  UCC_EVENT_COMPLETED),
+                           out, status);
+        }
+    } else {
+        UCC_CHECK_GOTO(ucc_event_manager_subscribe(
+                           &schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
+                           ucc_task_start_handler),
+                       out, status);
+        for (i = 1; i < n_tasks; i++) {
+            UCC_CHECK_GOTO(
+                ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED,
+                                            tasks[i], ucc_task_start_handler),
+                out, status);
+            UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out,
+                           status);
+        }
     }
 
     schedule->super.post = ucc_cl_hier_allreduce_rab_start;
```

Review discussion on the `ucc_task_subscribe_dep` call above:

**Reviewer:** Can't we just leave `task_subscribe_dep` for both cases? IIRC, subscribe dep is the same `ucc_event_manager_subscribe` plus `n_deps` initialization.

**Author:** That's what I did initially, but it doesn't work for persistent collectives because of the deps counter.

**Reviewer:** We start the collective using `schedule_pipelined_post`, right? Doesn't it reset `deps_satisfied` to 0? Why exactly is the dep counter broken for persistent?

**Author:** It does, but we use the schedule pipelined init only for the pipelined schedule; otherwise it will be `ucc_schedule_start`.

**Reviewer:** Yes, but this is what we call: `ucc_cl_hier_rab_allreduce_start` -> `schedule_pipelined_post`. Am I missing something?

**Author:** As discussed over the phone, we can't do it because it will break pipelined schedules. Action items for next PRs would be to …
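The thread above turns on the dependency-counter semantics: per the reviewer, `ucc_task_subscribe_dep` is `ucc_event_manager_subscribe` plus `n_deps` bookkeeping, so a task fires once enough dependency events arrive; per the author, only the pipelined post path resets `deps_satisfied` between posts. The stand-alone toy model below (all names invented; not UCC code) shows why an un-reset counter breaks the second post of a persistent collective:

```c
#include <stdio.h>

/* Toy model of dependency-counted task start (illustrative only). */
typedef struct {
    int n_deps;         /* dependencies registered when the schedule is built */
    int deps_satisfied; /* bumped by the dependency handler on each event     */
} toy_task_t;

/* Mimics a dependency handler: start the task once all deps have fired. */
static void toy_dep_event(toy_task_t *t)
{
    if (++t->deps_satisfied == t->n_deps) {
        printf("task started\n");
    }
}

int main(void)
{
    toy_task_t t = {.n_deps = 1, .deps_satisfied = 0};

    toy_dep_event(&t); /* 1st post: 1 == 1 -> "task started" */

    /* 2nd post of a persistent collective: if the start path never resets
     * deps_satisfied (as, per the thread, ucc_schedule_start would not),
     * the counter overshoots and the task never starts again. */
    toy_dep_event(&t); /* 2 != 1 -> silence */
    return 0;
}
```

Which is consistent with the patch keeping the plain `ucc_task_start_handler` subscription in the `else` branch for the single-fragment case.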
```diff
@@ -206,9 +224,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allreduce_rab_init,
                          ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
                          ucc_coll_task_t **task)
 {
-    ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
-    ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg;
-    ucc_cl_hier_schedule_t *  schedule;
+    ucc_cl_hier_team_t       *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
+    ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg;
+    ucc_cl_hier_schedule_t   *schedule;
     int n_frags, pipeline_depth;
     ucc_status_t status;
```
Review comment on this hunk: can this `if` statement be combined with the same `if` in line 101?