diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c index b2be47c074..bf02ea1372 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -12,6 +12,9 @@ #include "coll_score/ucc_coll_score.h" #include "utils/arch/cuda_def.h" +#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3) +#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB + UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, const ucc_base_team_params_t *params) { @@ -57,16 +60,6 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t) { tl_debug(self->super.super.context->lib, "finalizing tl team: %p", self); - if (self->nccl_comm) { - if (self->comm_state != UCC_OK) { - /* if communication error was detected ncclCommAbort should be used - since ncclCommDestroy could block */ - ncclCommAbort(self->nccl_comm); - } else { - ncclCommDestroy(self->nccl_comm); - } - cudaStreamDestroy(self->stream); - } } UCC_CLASS_DEFINE_DELETE_FUNC(ucc_tl_nccl_team_t, ucc_base_team_t); @@ -74,6 +67,46 @@ UCC_CLASS_DEFINE(ucc_tl_nccl_team_t, ucc_tl_team_t); ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) { + ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t); + +#if NCCL_USE_NON_BLOCKING + ncclResult_t nccl_status; + + if (team->nccl_comm && team->comm_state == UCC_INPROGRESS) { + goto check_finalize; + } +#endif + + if (team->nccl_comm) { + if (team->comm_state != UCC_OK && team->comm_state != UCC_INPROGRESS) { + /* if communication error was detected ncclCommAbort should be used + since ncclCommDestroy could block */ + ncclCommAbort(team->nccl_comm); + } else { +#if NCCL_USE_NON_BLOCKING + ncclCommFinalize(team->nccl_comm); +check_finalize: + ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (nccl_status == ncclInProgress) { + team->comm_state = UCC_INPROGRESS; + return UCC_INPROGRESS; + } + if (nccl_status != ncclSuccess) { + tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status, + ncclGetErrorString(nccl_status)); + ncclCommAbort(team->nccl_comm); + return UCC_ERR_NO_MESSAGE; + } else { + ncclCommDestroy(team->nccl_comm); + } + team->comm_state = UCC_OK; +#else + ncclCommDestroy(team->nccl_comm); +#endif + } + cudaStreamDestroy(team->stream); + } + UCC_CLASS_DELETE_FUNC_NAME(ucc_tl_nccl_team_t)(tl_team); return UCC_OK; } @@ -85,6 +118,14 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) ncclResult_t nccl_status; ncclUniqueId errorid; +#if NCCL_USE_NON_BLOCKING + ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER; + + if (team->comm_state == UCC_INPROGRESS) { + goto ncclInitStage; + } +#endif + status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req); if (status == UCC_INPROGRESS) { return UCC_INPROGRESS; @@ -108,12 +149,27 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream, cudaStreamNonBlocking), free_unique_id, status); +#if NCCL_USE_NON_BLOCKING + nccl_cfg.blocking = 0; + nccl_status = ncclCommInitRankConfig(&team->nccl_comm, + UCC_TL_TEAM_SIZE(team), + team->unique_id[0], + UCC_TL_TEAM_RANK(team), + &nccl_cfg); + if (nccl_status != ncclInProgress && nccl_status != ncclSuccess) { + goto free_stream; + } +ncclInitStage: + ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (nccl_status == ncclInProgress){ + team->comm_state = UCC_INPROGRESS; + return UCC_INPROGRESS; + } +#else nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team), team->unique_id[0], UCC_TL_TEAM_RANK(team)); +#endif if (nccl_status != ncclSuccess) { - tl_debug(tl_team->context->lib, "NCCL error %d %s", - nccl_status, ncclGetErrorString(nccl_status)); - status = UCC_ERR_NO_MESSAGE; goto free_stream; } ucc_free(team->unique_id); @@ -121,6 +177,12 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) return UCC_OK; free_stream: + tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status, + ncclGetErrorString(nccl_status)); + status = UCC_ERR_NO_MESSAGE; +#if NCCL_USE_NON_BLOCKING + ncclCommAbort(team->nccl_comm); +#endif cudaStreamDestroy(team->stream); free_unique_id: ucc_free(team->unique_id); diff --git a/test/mpi/test_mpi.cc b/test/mpi/test_mpi.cc index be3464e0e4..cfbae3a1d6 100644 --- a/test/mpi/test_mpi.cc +++ b/test/mpi/test_mpi.cc @@ -246,11 +246,9 @@ void UccTestMpi::destroy_team(ucc_test_team_t &team) ucc_status_t status; team.free_ee(); - while (UCC_INPROGRESS == (status = ucc_team_destroy(team.team))) { - if (UCC_OK != status) { - std::cerr << "ucc_team_destroy failed\n"; - break; - } + while (UCC_INPROGRESS == (status = ucc_team_destroy(team.team))) {} + if (UCC_OK != status) { + std::cerr << "ucc_team_destroy failed\n"; } if (team.comm != MPI_COMM_WORLD) { MPI_Comm_free(&team.comm);