Skip to content

Commit

Permalink
TL/NCCL: make team init non blocking (openucx#772)
Browse files Browse the repository at this point in the history
* TL/NCCL: make team init non blocking

* TL/NCCL: support by version and nb finalize

* REVIEW: code review fixes
  • Loading branch information
shimmybalsam authored and janjust committed Jan 31, 2024
1 parent a2efc83 commit 4de13ec
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 18 deletions.
88 changes: 75 additions & 13 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "coll_score/ucc_coll_score.h"
#include "utils/arch/cuda_def.h"

/* First NCCL release for which this file uses the non-blocking communicator
 * API (ncclCommInitRankConfig with blocking = 0, ncclCommFinalize). */
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
/* Evaluates to non-zero when the NCCL headers being compiled against are new
 * enough for the non-blocking path. The expansion is parenthesized so that
 * the macro behaves as a single operand in any expression, e.g.
 * `#if !NCCL_USE_NON_BLOCKING` or `#if A && NCCL_USE_NON_BLOCKING`. */
#define NCCL_USE_NON_BLOCKING (NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB)

UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
{
Expand Down Expand Up @@ -57,23 +60,53 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
/*
 * Class cleanup hook for ucc_tl_nccl_team_t.
 *
 * Emits a debug log line and, when an NCCL communicator handle is set,
 * releases the communicator and then the team's CUDA stream.
 *
 * NOTE(review): this listing is a rendered diff without +/- markers, and
 * ucc_tl_nccl_team_destroy() below releases nccl_comm and stream itself
 * before invoking the class delete function. The release statements in this
 * body are therefore presumably the lines the commit removed from the
 * cleanup hook -- confirm against the real file; if both bodies were live
 * the communicator and stream would be released twice.
 */
UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t)
{
tl_debug(self->super.super.context->lib, "finalizing tl team: %p", self);
/* Nothing to release when no communicator handle is set. */
if (self->nccl_comm) {
if (self->comm_state != UCC_OK) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
ncclCommAbort(self->nccl_comm);
} else {
ncclCommDestroy(self->nccl_comm);
}
/* The stream is destroyed only on the path where a communicator exists. */
cudaStreamDestroy(self->stream);
}
}

UCC_CLASS_DEFINE_DELETE_FUNC(ucc_tl_nccl_team_t, ucc_base_team_t);
UCC_CLASS_DEFINE(ucc_tl_nccl_team_t, ucc_tl_team_t);

/**
 * Releases the NCCL communicator and CUDA stream owned by a TL/NCCL team and
 * then deletes the team object.
 *
 * When built against an NCCL that supports non-blocking communicators
 * (NCCL_USE_NON_BLOCKING), the communicator is finalized asynchronously and
 * this function must be called repeatedly while it returns UCC_INPROGRESS.
 * team->comm_state == UCC_INPROGRESS is used as the "poll only" marker, so
 * ncclCommFinalize is not issued again on re-entry.
 *
 * NOTE(review): ucc_tl_nccl_team_create_test also stores UCC_INPROGRESS in
 * comm_state while the non-blocking init is pending; if that value is still
 * set when destroy is first called, ncclCommFinalize is skipped here and the
 * communicator goes straight to the async-error query -- confirm that
 * comm_state is reset once init completes.
 *
 * @param tl_team  Base team handle; must be a ucc_tl_nccl_team_t.
 *
 * @return UCC_OK              resources released and the team object deleted.
 * @return UCC_INPROGRESS      finalize has not completed yet; call again.
 * @return UCC_ERR_NO_MESSAGE  NCCL reported an error during finalize. The
 *                             communicator has been aborted, the stream
 *                             destroyed and team->nccl_comm cleared; the team
 *                             object itself is NOT deleted (a subsequent call
 *                             deletes it and returns UCC_OK).
 */
ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
{
    ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
#if NCCL_USE_NON_BLOCKING
    /* Initialized so that no path can read an indeterminate value. */
    ncclResult_t        nccl_status = ncclSuccess;
    ncclResult_t        st          = ncclSuccess;
#endif

    if (team->nccl_comm) {
        if (team->comm_state != UCC_OK && team->comm_state != UCC_INPROGRESS) {
            /* if communication error was detected ncclCommAbort should be used
               since ncclCommDestroy could block */
            ncclCommAbort(team->nccl_comm);
        } else {
#if NCCL_USE_NON_BLOCKING
            if (team->comm_state != UCC_INPROGRESS) {
                /* First call: kick off the finalize. A non-blocking
                   communicator may legitimately report ncclInProgress. */
                st = ncclCommFinalize(team->nccl_comm);
            }
            if (st == ncclSuccess || st == ncclInProgress) {
                st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
            }
            if (st != ncclSuccess) {
                /* The call itself failed: report its code rather than a
                   possibly unwritten async state. */
                nccl_status = st;
            }
            if (nccl_status == ncclInProgress) {
                team->comm_state = UCC_INPROGRESS;
                return UCC_INPROGRESS;
            }
            if (nccl_status != ncclSuccess) {
                tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
                         ncclGetErrorString(nccl_status));
                ncclCommAbort(team->nccl_comm);
                /* Release the stream as well and drop the now-invalid handle
                   so that a repeated destroy call cannot touch the aborted
                   communicator. */
                cudaStreamDestroy(team->stream);
                team->nccl_comm  = NULL;
                team->comm_state = UCC_ERR_NO_MESSAGE;
                return UCC_ERR_NO_MESSAGE;
            }
            ncclCommDestroy(team->nccl_comm);
            team->comm_state = UCC_OK;
#else
            ncclCommDestroy(team->nccl_comm);
#endif
        }
        cudaStreamDestroy(team->stream);
    }

    UCC_CLASS_DELETE_FUNC_NAME(ucc_tl_nccl_team_t)(tl_team);
    return UCC_OK;
}
Expand All @@ -85,6 +118,14 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
ncclResult_t nccl_status;
ncclUniqueId errorid;

#if NCCL_USE_NON_BLOCKING
ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER;

if (team->comm_state == UCC_INPROGRESS) {
goto ncclInitStage;
}
#endif

status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req);
if (status == UCC_INPROGRESS) {
return UCC_INPROGRESS;
Expand All @@ -108,19 +149,40 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)

CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
cudaStreamNonBlocking), free_unique_id, status);
#if NCCL_USE_NON_BLOCKING
nccl_cfg.blocking = 0;
nccl_status = ncclCommInitRankConfig(&team->nccl_comm,
UCC_TL_TEAM_SIZE(team),
team->unique_id[0],
UCC_TL_TEAM_RANK(team),
&nccl_cfg);
if (nccl_status != ncclInProgress && nccl_status != ncclSuccess) {
goto free_stream;
}
ncclInitStage:
ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (nccl_status == ncclInProgress){
team->comm_state = UCC_INPROGRESS;
return UCC_INPROGRESS;
}
#else
nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team),
team->unique_id[0], UCC_TL_TEAM_RANK(team));
#endif
if (nccl_status != ncclSuccess) {
tl_debug(tl_team->context->lib, "NCCL error %d %s",
nccl_status, ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
goto free_stream;
}
ucc_free(team->unique_id);
tl_debug(tl_team->context->lib, "initialized tl team: %p", team);
return UCC_OK;

free_stream:
tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
#if NCCL_USE_NON_BLOCKING
ncclCommAbort(team->nccl_comm);
#endif
cudaStreamDestroy(team->stream);
free_unique_id:
ucc_free(team->unique_id);
Expand Down
8 changes: 3 additions & 5 deletions test/mpi/test_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,9 @@ void UccTestMpi::destroy_team(ucc_test_team_t &team)
ucc_status_t status;

team.free_ee();
while (UCC_INPROGRESS == (status = ucc_team_destroy(team.team))) {
if (UCC_OK != status) {
std::cerr << "ucc_team_destroy failed\n";
break;
}
while (UCC_INPROGRESS == (status = ucc_team_destroy(team.team))) {}
if (UCC_OK != status) {
std::cerr << "ucc_team_destroy failed\n";
}
if (team.comm != MPI_COMM_WORLD) {
MPI_Comm_free(&team.comm);
Expand Down

0 comments on commit 4de13ec

Please sign in to comment.