CL/HIER: fix allreduce rab pipeline
Sergei-Lebedev committed Mar 29, 2023
1 parent db5124c commit 8a98742
Showing 7 changed files with 163 additions and 64 deletions.
56 changes: 37 additions & 19 deletions src/components/cl/hier/allreduce/allreduce_rab.c
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
@@ -78,7 +78,7 @@ ucc_cl_hier_allreduce_rab_frag_setup(ucc_schedule_pipelined_t *schedule_p,

static ucc_status_t
ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_base_team_t *team,
ucc_schedule_t **sched_p, int n_frags)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
@@ -99,10 +99,13 @@ ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
UCC_CHECK_GOTO(ucc_schedule_init(schedule, &args, team), out, status);

if (n_frags > 1) {
args.max_frag_count =
ucc_buffer_block_count(args.args.dst.info.count, n_frags, 0);
args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
args.max_frag_count = ucc_buffer_block_count(args.args.dst.info.count,
n_frags, 0);
args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT;
}
ucc_assert(SBGP_ENABLED(cl_team, NODE) ||
SBGP_ENABLED(cl_team, NODE_LEADERS));

if (SBGP_ENABLED(cl_team, NODE)) {
ucc_assert(n_tasks == 0);
if (cl_team->top_sbgp == UCC_HIER_SBGP_NODE) {
@@ -143,18 +146,33 @@ ucc_cl_hier_allreduce_rab_init_schedule(ucc_base_coll_args_t *coll_args,
n_tasks++;
}

UCC_CHECK_GOTO(ucc_event_manager_subscribe(
&schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
ucc_task_start_handler),
out, status);

/* subscription logic differs depending on the type of the top-level
* schedule being used
*/
UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[0]), out, status);
for (i = 1; i < n_tasks; i++) {
UCC_CHECK_GOTO(
ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED,
tasks[i], ucc_task_start_handler),
out, status);
UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out, status);
if (n_frags > 1) {
UCC_CHECK_GOTO(ucc_task_subscribe_dep(&schedule->super, tasks[0],
UCC_EVENT_SCHEDULE_STARTED),
out, status);
for (i = 1; i < n_tasks; i++) {
UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out, status);
UCC_CHECK_GOTO(ucc_task_subscribe_dep(tasks[i-1], tasks[i],
UCC_EVENT_COMPLETED),
out, status);
}
} else {
UCC_CHECK_GOTO(ucc_event_manager_subscribe(
&schedule->super, UCC_EVENT_SCHEDULE_STARTED, tasks[0],
ucc_task_start_handler),
out, status);
for (i = 1; i < n_tasks; i++) {
UCC_CHECK_GOTO(
ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED,
tasks[i], ucc_task_start_handler),
out, status);
UCC_CHECK_GOTO(ucc_schedule_add_task(schedule, tasks[i]), out,
status);
}
}

schedule->super.post = ucc_cl_hier_allreduce_rab_start;
@@ -207,9 +225,9 @@ UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_allreduce_rab_init,
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task)
{
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg;
ucc_cl_hier_schedule_t * schedule;
ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t);
ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg;
ucc_cl_hier_schedule_t *schedule;
int n_frags, pipeline_depth;
ucc_status_t status;

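For orientation: the fix above splits the task wiring by schedule type. When the RAB allreduce is fragmented (n_frags > 1), consecutive tasks are chained through dependency counters via ucc_task_subscribe_dep, which the pipelined scheduler relies on when it re-arms fragments (see the n_deps_base bookkeeping in ucc_schedule_pipelined.c below); the single-fragment path keeps the direct start-handler subscription. A minimal sketch of the two launch paths, using only helpers that appear in this diff (the loop context is illustrative, not part of the commit):

/* Single-fragment path: the COMPLETED event of the previous task
 * directly posts the next one. */
ucc_event_manager_subscribe(tasks[i - 1], UCC_EVENT_COMPLETED, tasks[i],
                            ucc_task_start_handler);

/* Pipelined path: the same edge is recorded as a dependency instead.
 * ucc_task_subscribe_dep registers ucc_dependency_handler and increments
 * tasks[i]->n_deps, so tasks[i] starts only after all its deps fired. */
ucc_task_subscribe_dep(tasks[i - 1], tasks[i], UCC_EVENT_COMPLETED);
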
3 changes: 2 additions & 1 deletion src/components/cl/hier/bcast/bcast_2step.c
@@ -164,7 +164,8 @@ ucc_cl_hier_bcast_2step_init_schedule(ucc_base_coll_args_t *coll_args,
UCC_EVENT_SCHEDULE_STARTED);
} else {
ucc_task_subscribe_dep(tasks[first_task],
tasks[(first_task + 1) % 2], UCC_EVENT_COMPLETED);
tasks[(first_task + 1) % 2],
UCC_EVENT_COMPLETED);
}
ucc_schedule_add_task(schedule, tasks[(first_task + 1) % 2]);
}
38 changes: 20 additions & 18 deletions src/schedule/ucc_schedule.h
@@ -61,17 +61,20 @@ typedef struct ucc_event_manager {
} ucc_event_manager_t;

enum {
UCC_COLL_TASK_FLAG_CB = UCC_BIT(0),
UCC_COLL_TASK_FLAG_CB = UCC_BIT(0),
/* executor is required for collective */
UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1),
UCC_COLL_TASK_FLAG_EXECUTOR = UCC_BIT(1),
/* user visible task */
UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2),
UCC_COLL_TASK_FLAG_TOP_LEVEL = UCC_BIT(2),
/* stop executor when the task completes */
UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3),
UCC_COLL_TASK_FLAG_EXECUTOR_STOP = UCC_BIT(3),
/* destroy executor when the task completes */
UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY = UCC_BIT(4),
UCC_COLL_TASK_FLAG_EXECUTOR_DESTROY = UCC_BIT(4),
/* if set, task can be cast to a schedule */
UCC_COLL_TASK_FLAG_IS_SCHEDULE = UCC_BIT(5),
UCC_COLL_TASK_FLAG_IS_SCHEDULE = UCC_BIT(5),
/* if set, task can be cast to a pipelined schedule */
UCC_COLL_TASK_FLAG_IS_PIPELINED_SCHEDULE = UCC_BIT(6),

};

typedef struct ucc_coll_task {
@@ -99,16 +102,16 @@ typedef struct ucc_coll_task {
ucc_ee_executor_t *executor;
union {
/* used for st & locked mt progress queue */
ucc_list_link_t list_elem;
ucc_list_link_t list_elem;
/* used for lf mt progress queue */
ucc_lf_queue_elem_t lf_elem;
ucc_lf_queue_elem_t lf_elem;
};
uint8_t n_deps;
uint8_t n_deps_satisfied;
uint8_t n_deps_base;
double start_time; /* timestamp of the start time:
either post or triggered_post */
uint32_t seq_num;
uint8_t n_deps;
uint8_t n_deps_satisfied;
uint8_t n_deps_base;
/* timestamp of the start time: either post or triggered_post */
double start_time;
uint32_t seq_num;
} ucc_coll_task_t;

extern struct ucc_mpool_ops ucc_coll_task_mpool_ops;
@@ -156,7 +159,7 @@ ucc_status_t ucc_task_start_handler(ucc_coll_task_t *parent,
ucc_coll_task_t *task);
ucc_status_t ucc_schedule_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_dependency_handler(ucc_coll_task_t *parent, /* NOLINT */
ucc_status_t ucc_dependency_handler(ucc_coll_task_t *parent,
ucc_coll_task_t *task);

ucc_status_t ucc_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,
@@ -227,13 +230,12 @@ static inline ucc_status_t ucc_task_complete(ucc_coll_task_t *task)
}

static inline ucc_status_t ucc_task_subscribe_dep(ucc_coll_task_t *target,
ucc_coll_task_t *subscriber,
ucc_event_t event)
ucc_coll_task_t *subscriber,
ucc_event_t event)
{
ucc_status_t status =
ucc_event_manager_subscribe(target, event, subscriber,
ucc_dependency_handler);

subscriber->n_deps++;
return status;
}
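
The inline helper above makes the dependency contract explicit: subscribing registers ucc_dependency_handler and increments subscriber->n_deps. A hypothetical sketch of the counting rule implied by the n_deps/n_deps_satisfied fields (the real ucc_dependency_handler is defined elsewhere in the scheduler and may differ, e.g. in atomicity and error handling):

/* Hypothetical sketch of dependency counting; illustrative only. */
static ucc_status_t dependency_handler_sketch(ucc_coll_task_t *parent,
                                              ucc_coll_task_t *task)
{
    /* one more prerequisite event has fired for this task */
    task->n_deps_satisfied++;
    if (task->n_deps_satisfied == task->n_deps) {
        /* every registered dependency is satisfied: launch the task */
        return ucc_task_start_handler(parent, task);
    }
    return UCC_OK;
}
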
79 changes: 55 additions & 24 deletions src/schedule/ucc_schedule_pipelined.c
@@ -1,7 +1,9 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "ucc_schedule.h"
#include "ucc_schedule_pipelined.h"
#include "coll_score/ucc_coll_score.h"
@@ -17,23 +19,23 @@ const char* ucc_pipeline_order_names[] = {
static ucc_status_t ucc_frag_start_handler(ucc_coll_task_t *parent,
ucc_coll_task_t *task)
{
ucc_schedule_pipelined_t *schedule =
ucc_derived_of(parent, ucc_schedule_pipelined_t);
ucc_schedule_t *frag = ucc_derived_of(task, ucc_schedule_t);
ucc_status_t status;
ucc_schedule_pipelined_t *schedule = ucc_derived_of(parent,
ucc_schedule_pipelined_t);
ucc_schedule_t *frag = ucc_derived_of(task, ucc_schedule_t);
ucc_status_t st;

task->start_time = parent->start_time;
if (schedule->frag_setup) {
status =
schedule->frag_setup(schedule, frag, schedule->n_frags_started);
if (UCC_OK != status) {
st = schedule->frag_setup(schedule, frag, schedule->n_frags_started);
if (ucc_unlikely(UCC_OK != st)) {
ucc_error("failed to setup fragment %d of pipelined schedule",
schedule->n_frags_started);
return status;
return st;
}
}
schedule->next_frag_to_post =
(schedule->next_frag_to_post + 1) % schedule->n_frags;

schedule->next_frag_to_post = (schedule->next_frag_to_post + 1) %
schedule->n_frags;
ucc_trace_req("sched %p started frag %p frag_num %d next_to_post %d",
schedule, frag, schedule->n_frags_started,
schedule->next_frag_to_post);
@@ -106,7 +108,11 @@ ucc_status_t ucc_schedule_pipelined_finalize(ucc_coll_task_t *task)
for (i = 0; i < schedule_p->n_frags; i++) {
schedule_p->frags[i]->super.finalize(&frags[i]->super);
}
ucc_recursive_spinlock_destroy(&schedule_p->lock);

if (UCC_TASK_THREAD_MODE(task) == UCC_THREAD_MULTIPLE) {
ucc_recursive_spinlock_destroy(&schedule_p->lock);
}

return UCC_OK;
}

@@ -140,12 +146,15 @@ ucc_status_t ucc_schedule_pipelined_post(ucc_coll_task_t *task)
return ucc_schedule_start(task);
}

ucc_status_t ucc_schedule_pipelined_init(
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_schedule_frag_init_fn_t frag_init,
ucc_schedule_frag_setup_fn_t frag_setup, int n_frags, int n_frags_total,
ucc_pipeline_order_t order, ucc_schedule_pipelined_t *schedule)
ucc_status_t ucc_schedule_pipelined_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_schedule_frag_init_fn_t frag_init,
ucc_schedule_frag_setup_fn_t frag_setup,
int n_frags, int n_frags_total,
ucc_pipeline_order_t order,
ucc_schedule_pipelined_t *schedule)
{
ucc_event_t task_dependency_event = UCC_EVENT_LAST;
int i, j;
ucc_status_t status;
ucc_schedule_t **frags;
@@ -156,14 +165,37 @@ ucc_status_t ucc_schedule_pipelined_init(
return UCC_ERR_INVALID_PARAM;
}

if (n_frags > 1) {
/* determine dependency between frags */
switch (order) {
case UCC_PIPELINE_PARALLEL:
/* no dependency between tasks of different fragments */
task_dependency_event = UCC_EVENT_LAST;
break;
case UCC_PIPELINE_ORDERED:
/* next fragment starts when previous has started */
task_dependency_event = UCC_EVENT_TASK_STARTED;
break;
case UCC_PIPELINE_SEQUENTIAL:
/* next fragment starts when previous has completed */
task_dependency_event = UCC_EVENT_COMPLETED;
break;
default:
return UCC_ERR_INVALID_PARAM;
}
}

status = ucc_schedule_init(&schedule->super, coll_args, team);
if (ucc_unlikely(status != UCC_OK)) {
ucc_error("failed to init pipelined schedule");
return status;
}

ucc_recursive_spinlock_init(&schedule->lock, 0);
if (UCC_TASK_THREAD_MODE(&schedule->super.super) == UCC_THREAD_MULTIPLE) {
ucc_recursive_spinlock_init(&schedule->lock, 0);
}

schedule->super.super.flags |= UCC_COLL_TASK_FLAG_IS_PIPELINED_SCHEDULE;
schedule->super.n_tasks = n_frags_total;
schedule->n_frags = n_frags;
schedule->order = order;
@@ -175,7 +207,7 @@
frags = schedule->frags;
for (i = 0; i < n_frags; i++) {
status = frag_init(coll_args, schedule, team, &frags[i]);
if (UCC_OK != status) {
if (ucc_unlikely(UCC_OK != status)) {
ucc_error("failed to initialize fragment for pipeline");
goto err;
}
@@ -186,16 +218,15 @@
frags[i]->super.status = UCC_OPERATION_INITIALIZED;
frags[i]->super.super.status = UCC_OPERATION_INITIALIZED;
}

for (i = 0; i < n_frags; i++) {
for (j = 0; j < frags[i]->n_tasks; j++) {
frags[i]->tasks[j]->n_deps_base = frags[i]->tasks[j]->n_deps;
if (n_frags > 1 && UCC_PIPELINE_PARALLEL != order) {
if (task_dependency_event != UCC_EVENT_LAST) {
UCC_CHECK_GOTO(
ucc_event_manager_subscribe(
frags[(i > 0) ? (i - 1) : (n_frags - 1)]->tasks[j],
UCC_PIPELINE_ORDERED == order
? UCC_EVENT_TASK_STARTED
: UCC_EVENT_COMPLETED, frags[i]->tasks[j],
frags[(i + n_frags - 1) % n_frags]->tasks[j],
task_dependency_event, frags[i]->tasks[j],
ucc_dependency_handler),
err, status);
frags[i]->tasks[j]->n_deps_base++;
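
Two things are worth noting in the rewrite above: the per-order branching is folded into a single task_dependency_event chosen up front, and the recursive spinlock is now created and destroyed only for UCC_THREAD_MULTIPLE teams. As a usage sketch (the caller, the my_frag_init/my_frag_setup names, and the 2/4 fragment counts are assumptions, not part of this commit), a sequential pipeline with two fragments in flight out of four total would be set up as:

/* Hypothetical caller; my_frag_init/my_frag_setup stand in for the
 * CL/TL-provided callbacks described in ucc_schedule_pipelined.h. */
status = ucc_schedule_pipelined_init(coll_args, team,
                                     my_frag_init,  /* builds one fragment */
                                     my_frag_setup, /* re-targets a fragment
                                                       on each iteration */
                                     2,  /* n_frags kept in flight */
                                     4,  /* n_frags_total to execute */
                                     UCC_PIPELINE_SEQUENTIAL, schedule);

UCC_PIPELINE_SEQUENTIAL maps to UCC_EVENT_COMPLETED dependencies between matching tasks of consecutive fragments; UCC_PIPELINE_ORDERED relaxes that to UCC_EVENT_TASK_STARTED, and UCC_PIPELINE_PARALLEL adds no cross-fragment dependency at all.
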
5 changes: 4 additions & 1 deletion src/schedule/ucc_schedule_pipelined.h
@@ -1,16 +1,19 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#ifndef UCC_SCHEDULE_PIPELINED_H_
#define UCC_SCHEDULE_PIPELINED_H_

#include "components/base/ucc_base_iface.h"

#define UCC_SCHEDULE_FRAG_MAX_TASKS 8
#define UCC_SCHEDULE_PIPELINED_MAX_FRAGS 4

typedef struct ucc_schedule_pipelined ucc_schedule_pipelined_t;

#define UCC_SCHEDULE_PIPELINED_MAX_FRAGS 4

/* frag_init is the callback provided by the user of the pipelined
framework (e.g., a TL that needs to build a pipeline) that is responsible
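
To illustrate the frag_init contract described above, a hypothetical skeleton (the body and the my_frag_init/build_one_fragment names are invented; the signature mirrors the frag_init(coll_args, schedule, team, &frags[i]) call site in ucc_schedule_pipelined_init):

/* Hypothetical frag_init callback: build the schedule for one fragment. */
static ucc_status_t my_frag_init(ucc_base_coll_args_t     *coll_args,
                                 ucc_schedule_pipelined_t *sp,
                                 ucc_base_team_t          *team,
                                 ucc_schedule_t          **frag_p)
{
    /* CL/TL-specific construction of the tasks that make up one slice
     * of the collective; build_one_fragment is a placeholder */
    return build_one_fragment(coll_args, team, frag_p);
}
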
11 changes: 10 additions & 1 deletion src/utils/ucc_coll_utils.c
@@ -7,6 +7,8 @@
#include "ucc_coll_utils.h"
#include "components/base/ucc_base_iface.h"
#include "core/ucc_team.h"
#include "schedule/ucc_schedule_pipelined.h"

#define STR_TYPE_CHECK(_str, _p, _prefix) \
do { \
if ((0 == strcasecmp(_UCC_PP_MAKE_STRING(_p), _str))) { \
@@ -511,9 +513,16 @@ void ucc_coll_task_components_str(const ucc_coll_task_t *task, char *str,
size_t *len)
{
ucc_schedule_t *schedule;
ucc_schedule_pipelined_t *schedule_pipelined;
int i;

if (task->flags & UCC_COLL_TASK_FLAG_IS_SCHEDULE) {
if (task->flags & UCC_COLL_TASK_FLAG_IS_PIPELINED_SCHEDULE) {
schedule_pipelined = ucc_derived_of(task, ucc_schedule_pipelined_t);
for (i = 0; i < schedule_pipelined->n_frags; i++) {
ucc_coll_task_components_str(&schedule_pipelined->frags[i]->super,
str, len);
}
} else if (task->flags & UCC_COLL_TASK_FLAG_IS_SCHEDULE) {
schedule = ucc_derived_of(task, ucc_schedule_t);
for (i = 0; i < schedule->n_tasks; i++) {
ucc_coll_task_components_str(schedule->tasks[i], str, len);
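
With the new branch above, ucc_coll_task_components_str recurses into every fragment of a pipelined schedule instead of relying on the generic schedule path. A small usage sketch (variable names and buffer size are illustrative):

/* sp is a ucc_schedule_pipelined_t *; its base task is sp->super.super */
char   str[256] = "";
size_t len      = 0;

ucc_coll_task_components_str(&sp->super.super, str, &len);
/* str now aggregates component names across all fragments' tasks */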