CL/HIER: fix allreduce rab pipeline #759

Merged
54 changes: 36 additions & 18 deletions src/components/cl/hier/allreduce/allreduce_rab.c
@@ -78,7 +78,7 @@ ucc_cl_hier_allreduce_rab_frag_setup(ucc_schedule_pipelined_t *schedule_p,

 static ucc_status_t
 ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
-                                        ucc_base_team_t * team,
+                                        ucc_base_team_t *team,
                                         ucc_schedule_t **sched_p, int n_frags)
 {
     ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
@@ -99,10 +99,13 @@ ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
     UCC_CHECK_GOTO(ucc_schedule_init(schedule, &args, team), out, status);
 
     if (n_frags > 1) {
-        args.max_frag_count =
-            ucc_buffer_block_count(args.args.dst.info.count, n_frags, 0);
-        args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
+        args.max_frag_count = ucc_buffer_block_count(args.args.dst.info.count,
+                                                     n_frags, 0);
+        args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
     }
+    ucc_assert(SBGP_ENABLED(cl_team, NODE) ||
+               SBGP_ENABLED(cl_team, NODE_LEADERS));
 
     if (SBGP_ENABLED(cl_team, NODE)) {
         ucc_assert(n_tasks == 0);
         if (cl_team->top_sbgp == UCC_HIER_SBGP_NODE) {
@@ -143,18 +146,33 @@ ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
         n_tasks++;
     }
 
-    UCC_CHECK_GOTO(ucc_event_manager_subscribe(
-                       &schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
-                       ucc_task_start_handler),
-                   out, status);
+    /* subscription logic is different depending on top level schedule type
+     * being used
+     */
     UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[0]), out, status);
-    for (i = 1; i < n_tasks; i++) {
-        UCC_CHECK_GOTO(
-            ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED,
-                                        tasks[i], ucc_task_start_handler),
-            out, status);
-        UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out, status);
-    }
+    if (n_frags > 1) {
[Review comment on "if (n_frags > 1)"]

Collaborator: can this if statement be combined with the same if in line 101?

+        UCC_CHECK_GOTO(ucc_task_subscribe_dep(&schedule->super, tasks[0],
+                                              UCC_EVENT_SCHEDULE_STARTED),
+                       out, status);
+        for (i = 1; i < n_tasks; i++) {
+            UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out, status);
+            UCC_CHECK_GOTO(ucc_task_subscribe_dep(tasks[i-1], tasks[i],
+                                                  UCC_EVENT_COMPLETED),
+                           out, status);
+        }
+    } else {
+        UCC_CHECK_GOTO(ucc_event_manager_subscribe(
+                           &schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
+                           ucc_task_start_handler),
+                       out, status);
+        for (i = 1; i < n_tasks; i++) {
+            UCC_CHECK_GOTO(
+                ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED,
+                                            tasks[i], ucc_task_start_handler),
+                out, status);
+            UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out,
+                           status);
+        }
+    }

[Review thread on the ucc_task_subscribe_dep calls above]

Collaborator: can't we just use ucc_task_subscribe_dep for both cases? IIRC, subscribe_dep is the same ucc_event_manager_subscribe plus n_deps initialization.

Contributor Author: that's what I did initially, but it doesn't work for persistent collectives because of the deps counter.

Collaborator: we start the collective using schedule_pipelined_post, right? Doesn't it reset deps_satisfied to 0? Why exactly is the dep counter broken for persistent collectives?

Contributor Author: it does, but we use schedule_pipelined_init only for the pipelined schedule; otherwise it will be ucc_schedule_start.

Collaborator: Yes, but this is what we call: ucc_cl_hier_rab_allreduce_start -> schedule_pipelined_post. Am I missing something? Also, even if we call ucc_schedule_start, I think it would be cleaner to add n_deps_satisfied = 0 to ucc_schedule_start and then still use subscribe_dep alone.

Contributor Author: as discussed over the phone, we can't do that because it would break pipelined schedules. Action items for next PRs:

  1. check if bcast 2step has a similar bug
  2. check that using only the pipelined schedule will not harm performance

(An illustrative toy model of this dependency-counter behavior follows this file's diff.)

     schedule->super.post = ucc_cl_hier_allreduce_rab_start;
@@ -206,9 +224,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allreduce_rab_init,
                          ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
                          ucc_coll_task_t **task)
 {
-    ucc_cl_hier_team_t *      cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
-    ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg;
-    ucc_cl_hier_schedule_t *  schedule;
+    ucc_cl_hier_team_t       *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
+    ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg;
+    ucc_cl_hier_schedule_t   *schedule;
     int n_frags, pipeline_depth;
     ucc_status_t status;
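For context on the thread above: as the ucc_schedule.h hunk below shows, ucc_task_subscribe_dep is ucc_event_manager_subscribe with ucc_dependency_handler plus a subscriber->n_deps++, and a dependency-gated task fires only when its satisfied counter matches n_deps. Below is a minimal, standalone toy model (editor's sketch, not UCC code) of why a persistent collective, initialized once but started many times, gets stuck if the satisfied counter is never reset:

    /* Toy model of the dependency counter discussed in the review thread.
     * A task with n_deps dependencies starts only when n_deps_satisfied
     * reaches n_deps. n_deps is set once at init (subscribe_dep) time; if
     * a persistent schedule is started again without resetting
     * n_deps_satisfied, the equality check never fires again. */
    #include <stdio.h>
    #include <stdint.h>

    typedef struct toy_task {
        uint8_t n_deps;           /* registered once, at init time */
        uint8_t n_deps_satisfied; /* bumped by the dependency handler */
    } toy_task_t;

    /* rough analogue of ucc_dependency_handler */
    static void dep_satisfied(toy_task_t *t)
    {
        t->n_deps_satisfied++;
        if (t->n_deps_satisfied == t->n_deps) {
            printf("task starts (%u/%u)\n", (unsigned)t->n_deps_satisfied,
                   (unsigned)t->n_deps);
        } else {
            printf("task stuck (%u/%u)\n", (unsigned)t->n_deps_satisfied,
                   (unsigned)t->n_deps);
        }
    }

    int main(void)
    {
        toy_task_t task = {.n_deps = 1, .n_deps_satisfied = 0};

        dep_satisfied(&task); /* first start: 1/1 -> task starts */
        /* second start of a persistent collective without a reset:
         * counter is now 2, never equal to 1 again -> task never starts */
        dep_satisfied(&task);
        return 0;
    }

The pipelined post path resets the counter between starts, which is consistent with the PR using subscribe_dep only on the pipelined (n_frags > 1) branch.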
3 changes: 2 additions & 1 deletion src/components/cl/hier/bcast/bcast_2step.c
@@ -164,7 +164,8 @@ ucc_cl_hier_bcast_2step_init_schedule(ucc_base_coll_args_t *coll_args,
                                UCC_EVENT_SCHEDULE_STARTED);
     } else {
         ucc_task_subscribe_dep(tasks[first_task],
-                               tasks[(first_task + 1) % 2], UCC_EVENT_COMPLETED);
+                               tasks[(first_task + 1) % 2],
+                               UCC_EVENT_COMPLETED);
     }
     ucc_schedule_add_task(schedule, tasks[(first_task + 1) % 2]);
 }
38 changes: 20 additions & 18 deletions src/schedule/ucc_schedule.h
@@ -61,17 +61,20 @@ typedef struct ucc_event_manager {
 } ucc_event_manager_t;
 
 enum {
-    UCC_COLL_TASK_FLAG_CB               = UCC_BIT(0),
+    UCC_COLL_TASK_FLAG_CB                    = UCC_BIT(0),
     /* executor is required for collective */
-    UCC_COLL_TASK_FLAG_EXECUTOR         = UCC_BIT(1),
+    UCC_COLL_TASK_FLAG_EXECUTOR              = UCC_BIT(1),
     /* user visible task */
-    UCC_COLL_TASK_FLAG_TOP_LEVEL        = UCC_BIT(2),
+    UCC_COLL_TASK_FLAG_TOP_LEVEL             = UCC_BIT(2),
     /* stop executor in task complete */
-    UCC_COLL_TASK_FLAG_EXECUTOR_STOP    = UCC_BIT(3),
+    UCC_COLL_TASK_FLAG_EXECUTOR_STOP         = UCC_BIT(3),
     /* destroy executor in task complete */
-    UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY = UCC_BIT(4),
+    UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY      = UCC_BIT(4),
     /* if set task can be cast to schedule */
-    UCC_COLL_TASK_FLAG_IS_SCHEDULE      = UCC_BIT(5),
+    UCC_COLL_TASK_FLAG_IS_SCHEDULE           = UCC_BIT(5),
+    /* if set task can be cast to pipelined schedule */
+    UCC_COLL_TASK_FLAG_IS_PIPELINED_SCHEDULE = UCC_BIT(6),
 };

typedef struct ucc_coll_task {
@@ -99,16 +102,16 @@ typedef struct ucc_coll_task {
     ucc_ee_executor_t *executor;
     union {
         /* used for st & locked mt progress queue */
-        ucc_list_link_t     list_elem;
+        ucc_list_link_t list_elem;
         /* used for lf mt progress queue */
-        ucc_lf_queue_elem_t lf_elem;
+        ucc_lf_queue_elem_t lf_elem;
     };
-    uint8_t             n_deps;
-    uint8_t             n_deps_satisfied;
-    uint8_t             n_deps_base;
-    double              start_time; /* timestamp of the start time:
-                                       either post or triggered_post */
-    uint32_t            seq_num;
+    uint8_t  n_deps;
+    uint8_t  n_deps_satisfied;
+    uint8_t  n_deps_base;
+    /* timestamp of the start time: either post or triggered_post */
+    double   start_time;
+    uint32_t seq_num;
 } ucc_coll_task_t;

extern struct ucc_mpool_ops ucc_coll_task_mpool_ops;
@@ -156,7 +159,7 @@ ucc_status_t ucc_task_start_handler(ucc_coll_task_t *parent,
                                     ucc_coll_task_t *task);
 ucc_status_t ucc_schedule_finalize(ucc_coll_task_t *task);
 
-ucc_status_t ucc_dependency_handler(ucc_coll_task_t *parent, /* NOLINT */
+ucc_status_t ucc_dependency_handler(ucc_coll_task_t *parent,
                                     ucc_coll_task_t *task);
 
 ucc_status_t ucc_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,
@@ -227,13 +230,12 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
 }
 
 static inline ucc_status_t ucc_task_subscribe_dep(ucc_coll_task_t *target,
-                                                  ucc_coll_task_t *subscriber,
-                                                  ucc_event_t event)
+                                                  ucc_coll_task_t *subscriber,
+                                                  ucc_event_t      event)
 {
     ucc_status_t status =
         ucc_event_manager_subscribe(target, event, subscriber,
                                     ucc_dependency_handler);
-
     subscriber->n_deps++;
     return status;
 }
79 changes: 55 additions & 24 deletions src/schedule/ucc_schedule_pipelined.c
@@ -1,7 +1,9 @@
 /**
- * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
 
+#include "ucc_schedule.h"
 #include "ucc_schedule_pipelined.h"
 #include "coll_score/ucc_coll_score.h"
@@ -17,23 +19,23 @@ const char* ucc_pipeline_order_names[] = {
 static ucc_status_t ucc_frag_start_handler(ucc_coll_task_t *parent,
                                            ucc_coll_task_t *task)
 {
-    ucc_schedule_pipelined_t *schedule =
-        ucc_derived_of(parent, ucc_schedule_pipelined_t);
-    ucc_schedule_t *frag = ucc_derived_of(task, ucc_schedule_t);
-    ucc_status_t status;
+    ucc_schedule_pipelined_t *schedule = ucc_derived_of(parent,
+                                                        ucc_schedule_pipelined_t);
+    ucc_schedule_t *frag = ucc_derived_of(task, ucc_schedule_t);
+    ucc_status_t st;
 
     task->start_time = parent->start_time;
     if (schedule->frag_setup) {
-        status =
-            schedule->frag_setup(schedule, frag, schedule->n_frags_started);
-        if (UCC_OK != status) {
+        st = schedule->frag_setup(schedule, frag, schedule->n_frags_started);
+        if (ucc_unlikely(UCC_OK != st)) {
             ucc_error("failed to setup fragment %d of pipelined schedule",
                       schedule->n_frags_started);
-            return status;
+            return st;
         }
     }
-    schedule->next_frag_to_post =
-        (schedule->next_frag_to_post + 1) % schedule->n_frags;
-
+    schedule->next_frag_to_post = (schedule->next_frag_to_post + 1) %
+                                  schedule->n_frags;
     ucc_trace_req("sched %p started frag %p frag_num %d next_to_post %d",
                   schedule, frag, schedule->n_frags_started,
                   schedule->next_frag_to_post);
@@ -106,7 +108,11 @@ ucc_status_t ucc_schedule_pipelined_finalize(ucc_coll_task_t *task)
     for (i = 0; i < schedule_p->n_frags; i++) {
         schedule_p->frags[i]->super.finalize(&frags[i]->super);
     }
-    ucc_recursive_spinlock_destroy(&schedule_p->lock);
+
+    if (UCC_TASK_THREAD_MODE(task) == UCC_THREAD_MULTIPLE) {
+        ucc_recursive_spinlock_destroy(&schedule_p->lock);
+    }
+
     return UCC_OK;
 }
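The guarded destroy above pairs with a guarded init added further down in ucc_schedule_pipelined_init: the recursive spinlock now exists only when the task's thread mode is UCC_THREAD_MULTIPLE, so creation and destruction must be gated by the same predicate. A small pthreads sketch of that pattern (editor's illustration, not UCC code):

    /* Lock lifetime guarded by a runtime predicate: destroy must test the
     * same condition as init, or it operates on an uninitialized lock. */
    #include <pthread.h>
    #include <stdbool.h>

    typedef struct ctx {
        bool            multi_threaded;
        pthread_mutex_t lock; /* initialized only when multi_threaded */
    } ctx_t;

    static void ctx_init(ctx_t *c, bool multi_threaded)
    {
        c->multi_threaded = multi_threaded;
        if (c->multi_threaded) {
            pthread_mutex_init(&c->lock, NULL);
        }
    }

    static void ctx_destroy(ctx_t *c)
    {
        if (c->multi_threaded) { /* same predicate as ctx_init */
            pthread_mutex_destroy(&c->lock);
        }
    }

    int main(void)
    {
        ctx_t c;
        ctx_init(&c, true);
        ctx_destroy(&c);
        return 0;
    }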

@@ -140,12 +146,15 @@ ucc_status_t ucc_schedule_pipelined_post(ucc_coll_task_t *task)
     return ucc_schedule_start(task);
 }
 
-ucc_status_t ucc_schedule_pipelined_init(
-    ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
-    ucc_schedule_frag_init_fn_t frag_init,
-    ucc_schedule_frag_setup_fn_t frag_setup, int n_frags, int n_frags_total,
-    ucc_pipeline_order_t order, ucc_schedule_pipelined_t *schedule)
+ucc_status_t ucc_schedule_pipelined_init(ucc_base_coll_args_t *coll_args,
+                                         ucc_base_team_t *team,
+                                         ucc_schedule_frag_init_fn_t frag_init,
+                                         ucc_schedule_frag_setup_fn_t frag_setup,
+                                         int n_frags, int n_frags_total,
+                                         ucc_pipeline_order_t order,
+                                         ucc_schedule_pipelined_t *schedule)
 {
+    ucc_event_t task_dependency_event = UCC_EVENT_LAST;
     int i, j;
     ucc_status_t status;
     ucc_schedule_t **frags;
@@ -156,14 +165,37 @@ ucc_status_t ucc_schedule_pipelined_init(ucc_base_coll_args_t *coll_args,
         return UCC_ERR_INVALID_PARAM;
     }
 
+    if (n_frags > 1) {
+        /* determine dependency between frags */
+        switch (order) {
+        case UCC_PIPELINE_PARALLEL:
+            /* no dependency between tasks of different fragments */
+            task_dependency_event = UCC_EVENT_LAST;
+            break;
+        case UCC_PIPELINE_ORDERED:
+            /* next fragment starts when previous has started */
+            task_dependency_event = UCC_EVENT_TASK_STARTED;
+            break;
+        case UCC_PIPELINE_SEQUENTIAL:
+            /* next fragment starts when previous has completed */
+            task_dependency_event = UCC_EVENT_COMPLETED;
+            break;
+        default:
+            return UCC_ERR_INVALID_PARAM;
+        }
+    }

     status = ucc_schedule_init(&schedule->super, coll_args, team);
     if (ucc_unlikely(status != UCC_OK)) {
         ucc_error("failed to init pipelined schedule");
         return status;
     }
 
-    ucc_recursive_spinlock_init(&schedule->lock, 0);
+    if (UCC_TASK_THREAD_MODE(&schedule->super.super) == UCC_THREAD_MULTIPLE) {
+        ucc_recursive_spinlock_init(&schedule->lock, 0);
+    }
+
+    schedule->super.super.flags |= UCC_COLL_TASK_FLAG_IS_PIPELINED_SCHEDULE;
     schedule->super.n_tasks = n_frags_total;
     schedule->n_frags = n_frags;
     schedule->order = order;
@@ -175,7 +207,7 @@ ucc_status_t ucc_schedule_pipelined_init(
     frags = schedule->frags;
     for (i = 0; i < n_frags; i++) {
         status = frag_init(coll_args, schedule, team, &frags[i]);
-        if (UCC_OK != status) {
+        if (ucc_unlikely(UCC_OK != status)) {
             ucc_error("failed to initialize fragment for pipeline");
             goto err;
         }
@@ -186,16 +218,15 @@ ucc_status_t ucc_schedule_pipelined_init(
         frags[i]->super.status = UCC_OPERATION_INITIALIZED;
         frags[i]->super.super.status = UCC_OPERATION_INITIALIZED;
     }
 
     for (i = 0; i < n_frags; i++) {
         for (j = 0; j < frags[i]->n_tasks; j++) {
             frags[i]->tasks[j]->n_deps_base = frags[i]->tasks[j]->n_deps;
-            if (n_frags > 1 && UCC_PIPELINE_PARALLEL != order) {
+            if (task_dependency_event != UCC_EVENT_LAST) {
                 UCC_CHECK_GOTO(
                     ucc_event_manager_subscribe(
-                        frags[(i > 0) ? (i - 1) : (n_frags - 1)]->tasks[j],
-                        UCC_PIPELINE_ORDERED == order
-                            ? UCC_EVENT_TASK_STARTED
-                            : UCC_EVENT_COMPLETED, frags[i]->tasks[j],
+                        frags[(i + n_frags - 1) % n_frags]->tasks[j],
+                        task_dependency_event, frags[i]->tasks[j],
                         ucc_dependency_handler),
                     err, status);
                 frags[i]->tasks[j]->n_deps_base++;
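One detail in the hunk above worth spelling out: frags[(i > 0) ? (i - 1) : (n_frags - 1)] became the equivalent frags[(i + n_frags - 1) % n_frags], the index of the previous fragment in the circular pool of n_frags reusable fragments, so fragment 0 waits on the last one. A tiny standalone illustration (editor's sketch, not UCC code):

    /* Each fragment's task j subscribes to task j of the previous fragment
     * in the ring of n_frags reusable fragment schedules. */
    #include <stdio.h>

    int main(void)
    {
        const int n_frags = 3;

        for (int i = 0; i < n_frags; i++) {
            printf("frag %d depends on frag %d\n",
                   i, (i + n_frags - 1) % n_frags);
        }
        /* prints:
         *   frag 0 depends on frag 2
         *   frag 1 depends on frag 0
         *   frag 2 depends on frag 1 */
        return 0;
    }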
5 changes: 4 additions & 1 deletion src/schedule/ucc_schedule_pipelined.h
@@ -1,16 +1,19 @@
 /**
  * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See file LICENSE for terms.
  */
 
 #ifndef UCC_SCHEDULE_PIPELINED_H_
 #define UCC_SCHEDULE_PIPELINED_H_
 
 #include "components/base/ucc_base_iface.h"
 
 #define UCC_SCHEDULE_FRAG_MAX_TASKS 8
+#define UCC_SCHEDULE_PIPELINED_MAX_FRAGS 4
 
 typedef struct ucc_schedule_pipelined ucc_schedule_pipelined_t;
 
-#define UCC_SCHEDULE_PIPELINED_MAX_FRAGS 4
-
 /* frag_init is the callback provided by the user of the pipelined
    framework (e.g., a TL that needs to build a pipeline) that is responsible
11 changes: 10 additions & 1 deletion src/utils/ucc_coll_utils.c
@@ -7,6 +7,8 @@
#include "ucc_coll_utils.h"
#include "components/base/ucc_base_iface.h"
#include "core/ucc_team.h"
#include "schedule/ucc_schedule_pipelined.h"

 #define STR_TYPE_CHECK(_str, _p, _prefix) \
     do { \
         if ((0 == strcasecmp(_UCC_PP_MAKE_STRING(_p), _str))) { \
@@ -511,9 +513,16 @@ void ucc_coll_task_components_str(const ucc_coll_task_t *task, char *str,
                                   size_t *len)
 {
     ucc_schedule_t           *schedule;
+    ucc_schedule_pipelined_t *schedule_pipelined;
     int i;
 
-    if (task->flags & UCC_COLL_TASK_FLAG_IS_SCHEDULE) {
+    if (task->flags & UCC_COLL_TASK_FLAG_IS_PIPELINED_SCHEDULE) {
+        schedule_pipelined = ucc_derived_of(task, ucc_schedule_pipelined_t);
+        for (i = 0; i < schedule_pipelined->n_frags; i++) {
+            ucc_coll_task_components_str(&schedule_pipelined->frags[i]->super,
+                                         str, len);
+        }
+    } else if (task->flags & UCC_COLL_TASK_FLAG_IS_SCHEDULE) {
         schedule = ucc_derived_of(task, ucc_schedule_t);
         for (i = 0; i < schedule->n_tasks; i++) {
             ucc_coll_task_components_str(schedule->tasks[i], str, len);