TL/UCP: initial allgather neighbor exchange implementation
ikryukov committed Aug 16, 2023
1 parent daeceb2 commit 2a24add
Showing 6 changed files with 253 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/components/tl/ucp/Makefile.am
@@ -9,10 +9,11 @@ SUBDIRS = .
include makefile.coll_plugins.am
endif

-allgather = \
-allgather/allgather.h \
-allgather/allgather.c \
-allgather/allgather_ring.c \
+allgather = \
+allgather/allgather.h \
+allgather/allgather.c \
+allgather/allgather_ring.c \
+allgather/allgather_neighbor.c \
allgather/allgather_knomial.c

allgatherv = \
4 changes: 4 additions & 0 deletions src/components/tl/ucp/allgather/allgather.c
@@ -17,6 +17,10 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_ALLGATHER_ALG_RING,
.name = "ring",
.desc = "O(N) Ring"},
[UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR] =
{.id = UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
.name = "neighbor",
.desc = "O(N) Neighbor Exchange N/2 steps"},
[UCC_TL_UCP_ALLGATHER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

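The neighbor-exchange schedule pairs each rank with its two ring neighbors and alternates between them: step 0 trades a single block with one neighbor, and each of the remaining N/2 - 1 steps forwards the two most recently obtained blocks to the other neighbor. The total volume moved matches the ring algorithm, but the collective completes in N/2 rounds instead of N - 1; a runnable sketch of the full schedule follows the new allgather_neighbor.c file below.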
13 changes: 13 additions & 0 deletions src/components/tl/ucp/allgather/allgather.h
@@ -11,6 +11,7 @@
enum {
UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL,
UCC_TL_UCP_ALLGATHER_ALG_RING,
UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
UCC_TL_UCP_ALLGATHER_ALG_LAST
};

@@ -33,6 +34,7 @@ static inline int ucc_tl_ucp_allgather_alg_from_str(const char *str)

ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task);

/* Ring */
ucc_status_t ucc_tl_ucp_allgather_ring_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
@@ -43,6 +45,17 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *task);

/* Neighbor Exchange */
ucc_status_t ucc_tl_ucp_allgather_neighbor_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);

ucc_status_t ucc_tl_ucp_allgather_neighbor_init_common(ucc_tl_ucp_task_t *task);

void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *task);

/* Uses allgather_kn_radix from config */
ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
211 changes: 211 additions & 0 deletions src/components/tl/ucp/allgather/allgather_neighbor.c
@@ -0,0 +1,211 @@
#include "config.h"
#include "tl_ucp.h"
#include "allgather.h"
#include "core/ucc_progress_queue.h"
#include "tl_ucp_sendrecv.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include "components/mc/ucc_mc.h"

static ucc_rank_t
ucc_tl_ucp_allgather_neighbor_get_send_block(ucc_subset_t *subset,
                                             ucc_rank_t    trank,
                                             ucc_rank_t    tsize,
                                             int           step)
{
    return ucc_ep_map_eval(subset->map, (trank - step + tsize) % tsize);
}

static ucc_rank_t
ucc_tl_ucp_allgather_neighbor_get_recv_block(ucc_subset_t *subset,
                                             ucc_rank_t    trank,
                                             ucc_rank_t    tsize,
                                             int           step)
{
    return ucc_ep_map_eval(subset->map, (trank - step - 1 + tsize) % tsize);
}

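/*
 * Closed form of the per-step receive offset. Worked example for tsize = 6,
 * rank = 2: step 0 (posted from start()) exchanges block 2 for block 3 with
 * rank 3; get_recv_from_rank(2, 6, 1) == 0, so step 1 receives blocks {0,1}
 * from rank 1 while sending {2,3}; get_recv_from_rank(2, 6, 2) == 4, so
 * step 2 receives blocks {4,5} from rank 3 while sending {0,1}. The result
 * is always even, so a two-block transfer never wraps the buffer.
 */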
static ucc_rank_t get_recv_from_rank(ucc_rank_t rank, ucc_rank_t size, int i)
{
    int neighbor[2], offset_at_step[2], recv_data_from[2];
    int even_rank;

    even_rank = !(rank % 2);
    if (even_rank) {
        neighbor[0]       = (rank + 1) % size;
        neighbor[1]       = (rank - 1 + size) % size;
        recv_data_from[0] = rank;
        recv_data_from[1] = rank;
        offset_at_step[0] = (+2);
        offset_at_step[1] = (-2);
    } else {
        neighbor[0]       = (rank - 1 + size) % size;
        neighbor[1]       = (rank + 1) % size;
        recv_data_from[0] = neighbor[0];
        recv_data_from[1] = neighbor[0];
        offset_at_step[0] = (-2);
        offset_at_step[1] = (+2);
    }
    const int i_parity = i % 2;

    return (recv_data_from[i_parity] +
            offset_at_step[i_parity] * ((i + 1) / 2) + size) % size;
}

ucc_status_t ucc_tl_ucp_allgather_neighbor_init(ucc_base_coll_args_t *coll_args,
                                                ucc_base_team_t      *team,
                                                ucc_coll_task_t     **task_h)
{
    ucc_tl_ucp_task_t *task;
    ucc_status_t       status;

    task   = ucc_tl_ucp_init_task(coll_args, team);
    status = ucc_tl_ucp_allgather_neighbor_init_common(task);
    if (ucc_unlikely(status != UCC_OK)) {
        /* don't hand back a task that was just returned to the pool */
        ucc_tl_ucp_put_task(task);
        return status;
    }
    *task_h = &task->super;
    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_allgather_neighbor_init_common(ucc_tl_ucp_task_t *task)
{
    ucc_tl_ucp_team_t *team = TASK_TEAM(task);
    ucc_sbgp_t        *sbgp;

    if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
        tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
        return UCC_ERR_NOT_SUPPORTED;
    }

    if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET)) {
        if (team->cfg.use_reordering) {
            sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
            task->subset.myrank = sbgp->group_rank;
            task->subset.map    = sbgp->map;
        }
    }

    task->allgather_neighbor.get_send_block =
        ucc_tl_ucp_allgather_neighbor_get_send_block;
    task->allgather_neighbor.get_recv_block =
        ucc_tl_ucp_allgather_neighbor_get_recv_block;
    task->super.post     = ucc_tl_ucp_allgather_neighbor_start;
    task->super.progress = ucc_tl_ucp_allgather_neighbor_progress;

    return UCC_OK;
}

/* Original implementation: https://github.com/open-mpi/ompi/blob/main/ompi/mca/coll/base/coll_base_allgather.c */
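/*
 * Note: like the Open MPI original, the schedule assumes an even team size
 * (even ranks pair to the right, odd ranks to the left) and finishes in
 * tsize / 2 rounds; an odd tsize would need a fallback such as ring.
 */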
void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task  = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team  = TASK_TEAM(task);
    ucc_rank_t         trank = task->subset.myrank;
    ucc_rank_t         tsize = (ucc_rank_t)task->subset.map.ep_num;
    void              *rbuf  = TASK_ARGS(task).dst.info.buffer;
    ucc_memory_type_t  rmem  = TASK_ARGS(task).dst.info.mem_type;
    size_t             count = TASK_ARGS(task).dst.info.count;
    ucc_datatype_t     dt    = TASK_ARGS(task).dst.info.datatype;
    size_t             data_size = (count / tsize) * ucc_dt_size(dt);
    ucc_rank_t         neighbor[2];
    ucc_rank_t         i;
    int                even_rank;
    void              *tmprecv, *tmpsend;

    if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
        return;
    }

    even_rank = !(trank % 2);
    if (even_rank) {
        neighbor[0] = (trank + 1) % tsize;
        neighbor[1] = (trank - 1 + tsize) % tsize;
    } else {
        neighbor[0] = (trank - 1 + tsize) % tsize;
        neighbor[1] = (trank + 1) % tsize;
    }

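    /*
     * Every iteration posts exactly one send and one recv, so
     * task->tagged.send_posted doubles as the step counter: when progress is
     * re-entered after UCC_INPROGRESS, the schedule resumes where it left
     * off (step 0 was already posted by start()).
     */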
    while (task->tagged.send_posted < (tsize / 2)) {
        i = task->tagged.send_posted;
        const int i_parity = i % 2;

        tmprecv = PTR_OFFSET(rbuf,
                             get_recv_from_rank(trank, tsize, i) * data_size);
        tmpsend = PTR_OFFSET(rbuf,
                             get_recv_from_rank(trank, tsize, i - 1) * data_size);

        /* Sendreceive */
        UCPCHECK_GOTO(ucc_tl_ucp_send_nb(tmpsend, 2 * data_size, rmem,
                                         neighbor[i_parity], team, task),
                      task, out);
        UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(tmprecv, 2 * data_size, rmem,
                                         neighbor[i_parity], team, task),
                      task, out);

        if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
            return;
        }
    }

    ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
    task->super.status = UCC_OK;

out:
    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_neighbor_done", 0);
}

ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task  = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
    ucc_tl_ucp_team_t *team  = TASK_TEAM(task);
    size_t             count = TASK_ARGS(task).dst.info.count;
    void              *sbuf  = TASK_ARGS(task).src.info.buffer;
    void              *rbuf  = TASK_ARGS(task).dst.info.buffer;
    ucc_memory_type_t  smem  = TASK_ARGS(task).src.info.mem_type;
    ucc_memory_type_t  rmem  = TASK_ARGS(task).dst.info.mem_type;
    ucc_datatype_t     dt    = TASK_ARGS(task).dst.info.datatype;
    ucc_rank_t         trank = task->subset.myrank;
    ucc_rank_t         tsize = (ucc_rank_t)task->subset.map.ep_num;
    size_t             data_size = (count / tsize) * ucc_dt_size(dt);
    ucc_status_t       status;
    ucc_rank_t         block;
    int                even_rank;
    ucc_rank_t         neighbor[2];
    void              *tmprecv, *tmpsend;

    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_neighbor_start",
                                     0);
    ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

    if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
        block  = task->allgather_neighbor.get_send_block(&task->subset, trank,
                                                         tsize, 0);
        status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block), sbuf,
                               data_size, rmem, smem);
        if (ucc_unlikely(UCC_OK != status)) {
            return status;
        }
    }

    even_rank = !(trank % 2);
    if (even_rank) {
        neighbor[0] = (trank + 1) % tsize;
        neighbor[1] = (trank - 1 + tsize) % tsize;
    } else {
        neighbor[0] = (trank - 1 + tsize) % tsize;
        neighbor[1] = (trank + 1) % tsize;
    }

    tmprecv = PTR_OFFSET(rbuf, neighbor[0] * data_size);
    tmpsend = PTR_OFFSET(rbuf, trank * data_size);

    /* Sendreceive */
    UCPCHECK_GOTO(
        ucc_tl_ucp_send_nb(tmpsend, data_size, rmem, neighbor[0], team, task),
        task, out);
    UCPCHECK_GOTO(
        ucc_tl_ucp_recv_nb(tmprecv, data_size, rmem, neighbor[0], team, task),
        task, out);
out:
    UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_neighbor_start",
                                     0);
    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}
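The progress loop together with get_recv_from_rank() fully determines the communication pattern. Below is a standalone host-side sketch (not part of the commit; recv_from_rank and the SIZE constant are illustrative names) that replays the schedule for an even team size and asserts that every rank ends up holding every block:

#include <assert.h>
#include <stdio.h>

/* Mirror of get_recv_from_rank() above, collapsed into one expression path. */
static int recv_from_rank(int rank, int size, int i)
{
    int even   = (rank % 2 == 0);
    int base   = even ? rank : (rank - 1 + size) % size;
    int offset = (even == (i % 2 == 0)) ? 2 : -2;

    return (base + offset * ((i + 1) / 2) + size) % size;
}

int main(void)
{
    enum { SIZE = 8 };            /* team size; must be even               */
    int have[SIZE][SIZE] = {{0}}; /* have[r][b] != 0: rank r holds block b */
    int r, i, b;

    for (r = 0; r < SIZE; r++) {
        have[r][r] = 1;           /* every rank starts with its own block  */
    }
    /* step 0: exchange a single block with neighbor[0] */
    for (r = 0; r < SIZE; r++) {
        int n0 = (r % 2 == 0) ? (r + 1) % SIZE : (r - 1 + SIZE) % SIZE;
        have[r][n0] = 1;
    }
    /* steps 1 .. SIZE/2 - 1: two contiguous blocks, alternating neighbors */
    for (i = 1; i < SIZE / 2; i++) {
        for (r = 0; r < SIZE; r++) {
            int blk = recv_from_rank(r, SIZE, i);
            have[r][blk]     = 1;
            have[r][blk + 1] = 1; /* blk is always even, so no wraparound */
        }
    }
    for (r = 0; r < SIZE; r++) {
        for (b = 0; b < SIZE; b++) {
            assert(have[r][b]);
        }
    }
    printf("all %d ranks hold all %d blocks after %d rounds\n",
           SIZE, SIZE, SIZE / 2);
    return 0;
}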
3 changes: 3 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.c
@@ -252,6 +252,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
case UCC_TL_UCP_ALLGATHER_ALG_RING:
*init = ucc_tl_ucp_allgather_ring_init;
break;
case UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR:
*init = ucc_tl_ucp_allgather_neighbor_init;
break;
default:
status = UCC_ERR_INVALID_PARAM;
break;
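With the dispatch entry in place the new path is selectable at runtime. Assuming the usual TL/UCP tuning syntax (algorithm by name, or by @<id> from the enum in allgather.h), something like UCC_TL_UCP_TUNE=allgather:@neighbor should route allgather onto the neighbor-exchange implementation while leaving the default selection untouched.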
17 changes: 17 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
@@ -179,6 +179,23 @@ typedef struct ucc_tl_ucp_task {
ucc_rank_t tsize,
int step);
} allgather_ring;
struct {
/*
* Get the send/recv block; the mapping depends on the subset type in use.
* For service allgather we need to get context endpoints but keep
* subset numbering.
* For regular allgather with rank reordering, both endpoint and block
* permutations are necessary.
*/
ucc_rank_t (*get_send_block)(ucc_subset_t *subset,
ucc_rank_t trank,
ucc_rank_t tsize,
int step);
ucc_rank_t (*get_recv_block)(ucc_subset_t *subset,
ucc_rank_t trank,
ucc_rank_t tsize,
int step);
} allgather_neighbor;
struct {
ucc_rank_t dist;
uint32_t radix;
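The getter pair mirrors the allgather_ring struct above, so service and user-facing allgathers can share one progress engine and differ only in how blocks are permuted. As an illustration of the comment, a hypothetical service-style getter that keeps subset numbering (no endpoint-map permutation of blocks) could look like:

static ucc_rank_t allgather_neighbor_get_send_block_service(ucc_subset_t *subset,
                                                            ucc_rank_t    trank,
                                                            ucc_rank_t    tsize,
                                                            int           step)
{
    /* hypothetical sketch: blocks keep subset numbering, so the subset map
     * is not consulted here */
    (void)subset;
    return (trank - step + tsize) % tsize;
}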
