[PGNCCL] Fix bugs in non-blocking mode (#137741)
### Fix 1: Throw async error during init wait

Previously we just busy-waited for `ncclSuccess`; if the nonblocking init encountered an error, we never reported it. Added detection of async errors via `ncclCommGetAsyncError`.
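
For illustration, here is a minimal standalone sketch of the detection pattern against the raw NCCL API (not the PyTorch helpers; `pollInit`, the timeout argument, and the 2 ms sleep are assumptions for the example):

```cpp
#include <nccl.h>

#include <chrono>
#include <stdexcept>
#include <string>
#include <thread>

// Poll a nonblocking communicator until init leaves ncclInProgress.
// Unlike a bare busy-wait for ncclSuccess, this surfaces async errors.
void pollInit(ncclComm_t comm, std::chrono::seconds timeout) {
  const auto start = std::chrono::steady_clock::now();
  ncclResult_t state = ncclInProgress;
  while (true) {
    // Reports the result of the last nonblocking NCCL call on this comm.
    ncclCommGetAsyncError(comm, &state);
    if (state != ncclInProgress) {
      break; // either success or an async error
    }
    if (std::chrono::steady_clock::now() - start > timeout) {
      throw std::runtime_error("NCCL timeout in communicator initialization.");
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(2));
  }
  if (state != ncclSuccess) {
    // Previously this case was spun on silently; now it is reported.
    throw std::runtime_error(
        std::string("NCCL init failed: ") + ncclGetErrorString(state));
  }
}
```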

### Fix 2: Add wait after comm split

```
  // After calling ncclCommSplit in non-blocking mode, we should wait for the
  // source communicator to be out of ncclInProgress state.
  // Reason 1:
  //   it's unsafe to call new operations on the parent comm while it's in
  //   ncclInProgress state.
  // Reason 2:
  //   as of NCCL 2.23, the ptr value of child comm will not be filled until the
  //   state of parent comm is ncclSuccess. This may change in the future. See:
  //   https://github.com/NVIDIA/nccl/issues/1472
```
This wait does not mean the child comm is ready for use, nor does it block until that point.
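
A standalone sketch of the resulting call pattern in raw NCCL (the PyTorch macros and timeout handling are elided; `splitAndWaitOnParent` is an illustrative name):

```cpp
#include <nccl.h>

// Split in non-blocking mode, then poll the *parent* communicator until it
// leaves ncclInProgress. This makes further ops on the parent safe and, as
// of NCCL 2.23, guarantees the child handle has been filled in.
ncclComm_t splitAndWaitOnParent(
    ncclComm_t parent, int color, int key, ncclConfig_t* config) {
  ncclComm_t child = nullptr;
  ncclCommSplit(parent, color, key, &child, config);
  ncclResult_t state = ncclInProgress;
  while (state == ncclInProgress) {
    ncclCommGetAsyncError(parent, &state); // wait on the parent, not the child
  }
  // `child` now holds a valid pointer, but the child communicator itself may
  // still be initializing; wait on `child` before issuing collectives on it.
  return child;
}
```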

Pull Request resolved: #137741
Approved by: https://github.com/shuqiangzhang
kwen2501 authored and pytorchmergebot committed Oct 15, 2024
1 parent 370d66d commit 35fc24f
Showing 3 changed files with 59 additions and 77 deletions.
66 changes: 33 additions & 33 deletions torch/csrc/distributed/c10d/NCCLUtils.cpp
@@ -13,10 +13,6 @@

 #include <nlohmann/json.hpp>

-namespace {
-constexpr int64_t kCommInitBusyWaitMillis = 10;
-} // namespace
-
 namespace c10d {

 ncclComm_t NCCLComm::getNcclComm() {
@@ -34,37 +30,25 @@ ncclComm_t NCCLComm::getNcclComm() {
         ". ",
         commFailureMsg));
   }
-  // only wait for initialization if nonblocking mode is enabled
-  if (!initialized_ && nccl_use_nonblocking()) {
-    waitUntilInitialized(nccl_nonblocking_timeout());
+  if (!initialized_) {
+    waitUntilInitialized();
   }

   return ncclComm_;
 }

-void NCCLComm::waitUntilInitialized(int timeoutSecs) {
-  auto startTimepoint = std::chrono::steady_clock::now();
-  while (!initialized_) {
-    if (ncclComm_) {
-      ncclResult_t result{};
-      ncclCommGetAsyncError(ncclComm_, &result);
-      if (result == ncclSuccess) {
-        LOG(INFO) << "Rank " << rank_ << ": NCCL communicator is initialized.";
-        initialized_ = true;
-        break;
-      }
-    }
-    auto currentTimepoint = std::chrono::steady_clock::now();
-    auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
-                           currentTimepoint - startTimepoint)
-                           .count();
-    if (timeElapsed > timeoutSecs) {
-      std::string err = "NCCL timeout in communicator initialization.";
-      TORCH_CHECK_WITH(DistBackendError, false, err);
-    }
-    std::this_thread::sleep_for(
-        std::chrono::milliseconds(kCommInitBusyWaitMillis));
-  }
+void NCCLComm::waitUntilInitialized() {
+  // Wait for initialization to complete if in nonblocking mode.
+  // If timeout is reached, throw an exception.
+  if (nccl_use_nonblocking()) {
+    C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt);
+    // ncclComm_ should be initialized by now
+  }
+  // In blocking mode, this function is nothing but setting initialized_ to
+  // true.
+  initialized_ = true;
+  LOG(INFO) << "Rank " << rank_ << ": NCCL communicator " << repr()
+            << " is initialized.";
 }

 // TODO: why do we have `!defined(FBCODE_CAFFE2)` here?
@@ -80,14 +64,30 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
   auto comm = std::make_shared<NCCLComm>();
   // This call will block until the source communicator is initialized
   auto sourceComm = source->getNcclComm();
 #ifndef NCCL_HAS_COMM_NONBLOCKING
   C10D_NCCL_CHECK(
       ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config),
       std::nullopt);
 #else
+  // After calling ncclCommSplit in non-blocking mode, we should wait for the
+  // source communicator to be out of ncclInProgress state.
+  // Reason 1:
+  //   it's unsafe to call new operations on the parent comm while it's in
+  //   ncclInProgress state.
+  // Reason 2:
+  //   as of NCCL 2.23, the ptr value of child comm will not be filled until the
+  //   state of parent comm is ncclSuccess. This may change in the future. See:
+  //   https://github.com/NVIDIA/nccl/issues/1472
+  C10D_NCCL_CHECK_TIMEOUT_SLEEP(
+      ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config),
+      sourceComm, // wait on parent comm
+      std::nullopt);
+  // comm->ncclComm_ should have valid ptr by now, but not necessarily
+  // initialized. Rely on getNcclComm() -> waitUntilInitialized() to wait for
+  // its initialization.
 #endif
   ++source->ncclCommSplitCounter_;
   comm->rank_ = rank;
+  if (!nccl_use_nonblocking()) {
+    comm->initialized_ = true;
+  }
   return comm;
 }
 #endif
29 changes: 26 additions & 3 deletions torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -2,6 +2,7 @@

 #ifdef USE_C10D_NCCL

+#include <sched.h>
 #include <stdio.h>
 #include <stdlib.h>

@@ -16,6 +17,8 @@
 #include <torch/csrc/distributed/c10d/TraceUtils.h>
 #include <optional>

+constexpr int64_t kCommInitBusyWaitMillis = 2;
+
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
     (NCCL_MINOR >= 14)
 #define NCCL_HAS_COMM_NONBLOCKING
@@ -102,7 +105,7 @@
   } while (0)

 // Macro to throw on a non-successful NCCL return value, non-blocking.
-#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \
+#define C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, yield_fn) \
   ncclResult_t result = cmd; \
   auto startTimepoint = std::chrono::steady_clock::now(); \
   while (result == ncclInProgress) { \
@@ -119,6 +122,7 @@
         TORCH_CHECK_WITH(DistBackendError, false, err); \
       } \
     } \
+    yield_fn; \
     ncclCommGetAsyncError(comm, &result); \
   } \
   if (result != ncclSuccess) { \
@@ -128,6 +132,24 @@
     TORCH_CHECK_WITH(DistBackendError, false, err); \
   }

+#define C10D_SCHED_SLEEP() \
+  std::this_thread::sleep_for( \
+      std::chrono::milliseconds(kCommInitBusyWaitMillis))
+
+// Macro to throw exception on a non-successful NCCL return value or timeout.
+// This macro uses sched_yield() to yield the CPU.
+// Thus suitable for NCCL calls that would quickly turn ncclSuccess, e.g.
+// collectives.
+#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \
+  C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, sched_yield())
+
+// Macro to throw exception on a non-successful NCCL return value or timeout.
+// This macro uses sleep to yield the CPU.
+// Thus suitable for NCCL calls that would take longer to turn ncclSuccess, e.g.
+// ncclCommInitRankConfig, ncclCommFinalize, etc.
+#define C10D_NCCL_CHECK_TIMEOUT_SLEEP(cmd, comm, failureReason) \
+  C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, C10D_SCHED_SLEEP())
+
 #define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \
   ncclResult_t state = cmd; \
   auto startTimepoint = std::chrono::steady_clock::now(); \
@@ -146,6 +168,7 @@
         TORCH_CHECK_WITH(DistBackendError, false, err); \
       } \
     } \
+    sched_yield(); \
     ncclCommGetAsyncError(comm->getNcclComm(), &state); \
   } while (state == ncclInProgress); \
 } \
@@ -518,8 +541,6 @@ class NCCLComm {
   friend class ProcessGroupNCCL;

 protected:
-  // a helper function to wait until the communicator is initialized;
-  void waitUntilInitialized(int timeoutSecs);
   // Unique nccl_id for this communicator.
   ncclUniqueId ncclId_;
   bool aborted_;
@@ -539,6 +560,8 @@

 private:
   ncclComm_t ncclComm_;
+  // a helper function to wait until the communicator is initialized;
+  void waitUntilInitialized();
 };

 // Helper that automatically cleans up premul sums.
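
The two `C10D_NCCL_CHECK_TIMEOUT*` variants above differ only in the `yield_fn` handed to the base macro: `sched_yield()` for calls expected to finish quickly (collectives), and a 2 ms sleep for longer-running ones (`ncclCommInitRankConfig`, `ncclCommFinalize`). A rough standalone illustration of the same trade-off, written as a function template rather than the actual macros:

```cpp
#include <nccl.h>
#include <sched.h>

#include <chrono>
#include <thread>

// Poll `poll` until it stops returning ncclInProgress, invoking `yield`
// between checks; the yield strategy trades wake-up latency for CPU time.
template <typename Poll, typename Yield>
ncclResult_t pollUntilDone(Poll poll, Yield yield) {
  ncclResult_t state = poll();
  while (state == ncclInProgress) {
    yield();
    state = poll();
  }
  return state;
}

// Fast path, for quick calls:
//   pollUntilDone(poll, [] { sched_yield(); });
// Slow path, for init/finalize:
//   pollUntilDone(poll, [] {
//     std::this_thread::sleep_for(std::chrono::milliseconds(2));
//   });
```
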
41 changes: 0 additions & 41 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -9737,47 +9737,6 @@ def forward(self, inp):
         ddp._check_reducer_finalized()
         ddp(input)

-    @skip_if_lt_x_gpu(2)
-    @skip_but_pass_in_sandcastle_if(
-        BACKEND != "nccl",
-        "TORCH_NCCL_USE_COMM_NONBLOCKING only applies to NCCL"
-    )
-    def test_nccl_init_abort(self):
-        """
-        Tests that we can abort a NCCL communicator during initialization and
-        recover appropriately.
-        """
-        # Reinitialize global process group with TORCH_NCCL_USE_COMM_NONBLOCKING=1
-        os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
-        dist.destroy_process_group()
-        timeout = timedelta(seconds=1)
-        dist.init_process_group(
-            init_method=INIT_METHOD,
-            backend=BACKEND,
-            world_size=int(os.environ["WORLD_SIZE"]),
-            rank=self.rank,
-            timeout=timeout,
-        )
-
-        # Abort pg in background thread.
-        running = True
-
-        def abort(device):
-            pg = _get_default_group()
-            while running:
-                pg._get_backend(torch.device(device))._shutdown()
-                time.sleep(1)
-
-        if self.rank != 1:
-            import threading
-            t = threading.Thread(target=abort, args=(self.rank,))
-            t.start()
-            with self.assertRaises(RuntimeError):
-                # First collective triggers initialization via ncclCommInitRank.
-                torch.distributed.barrier()
-            running = False
-            t.join()
-
     def _run_ddp_update_process_group(self, new_pg):
         def get_num_torch_recompiles():
             guard_failures = torch._dynamo.utils.guard_failures
