Skip to content

Commit

Permalink
ensure thread-safety for register timer in ctran mapper (NVIDIA#33)
Browse files Browse the repository at this point in the history
Summary:

When concurrent collective/p2p are sent via multiple NCCL communicators, ctran mapper register/deregister/search paths can be called by multiple threads concurrently. Thus, we need ensure thread-safety for the global timer for registration.

This patch fixes it by adding mutex for all accesses to the global variables used.

Differential Revision: D51083701
  • Loading branch information
minsii authored and facebook-github-bot committed Nov 13, 2023
1 parent 7eac3ec commit 5bd01dd
Showing 1 changed file with 95 additions and 75 deletions.
170 changes: 95 additions & 75 deletions src/ctran/mapper/ctranMapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@
#include "ctranMapperImpl.h"
#include "comm.h"
#include "nccl_cvars.h"

static std::vector<double> registerDurs;
static std::vector<double> deregisterDurs;
static std::vector<double> lookupHitDurs;
static std::vector<double> lookupMissDurs;
static std::unordered_map<uint64_t, ctranMapper*> allCommHashCtranMapperMap;
#include <unordered_map>

/*
=== BEGIN_NCCL_CVAR_INFO_BLOCK ===
Expand Down Expand Up @@ -71,7 +66,81 @@ static std::unordered_map<uint64_t, ctranMapper*> allCommHashCtranMapperMap;
=== END_NCCL_CVAR_INFO_BLOCK ===
*/

ctranMapper::ctranMapper(ncclComm *comm) {
enum GlobalRegistDurationType { REG_MEM, DEREG_MEM, LOOKUP_HIT, LOOKUP_MISS };

static std::unordered_map<GlobalRegistDurationType, std::string>
globalRegistDurationTypeNameMap = {
{REG_MEM, "registration"},
{DEREG_MEM, "deregistration"},
{LOOKUP_HIT, "lookup-hit"},
{LOOKUP_MISS, "lookup-miss"},
};
static std::unordered_map<uint64_t, ctranMapper*> allCommHashCtranMapperMap;
static std::unordered_map<GlobalRegistDurationType, std::vector<double>>
allCommRegistDurationsMap;
static std::mutex allCommMutex;

static double sumDurations(std::vector<double>& durs) {
double total = 0;
for (auto& dur : durs) {
total += dur;
}
return total;
}

static void reportGlobalRegSnapshot(void) {
const std::lock_guard<std::mutex> lock(allCommMutex);

// Counts per communicator
for (auto& it : allCommHashCtranMapperMap) {
auto& mapper = it.second;
mapper->reportRegSnapshot();
}

// Timers accumulated from all communicators
for (auto& it : allCommRegistDurationsMap) {
auto& key = it.first;
auto& durs = it.second;
size_t numDurs = durs.size();
if (numDurs) {
double totalLat = sumDurations(durs);
INFO(
NCCL_INIT,
"CTRAN-MAPPER: [register snapshot] total %s latency across all comms %.2f ms, average %.2f ms across %lu %s",
globalRegistDurationTypeNameMap[key].c_str(),
totalLat,
totalLat / numDurs,
numDurs,
globalRegistDurationTypeNameMap[key].c_str());
}
}
}

static void recordRegistDuration(
GlobalRegistDurationType key,
double duration) {
allCommMutex.lock();
allCommRegistDurationsMap[key].push_back(duration);

// Allow periodical snapshot report during long job running
bool shouldReport = false;
if (key == GlobalRegistDurationType::REG_MEM &&
NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT > 0 &&
(allCommRegistDurationsMap[key].size() %
NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT ==
0)) {
shouldReport = true;
}
allCommMutex.unlock();

// Call report after unlock since we will lock again inside
// reportGlobalRegSnapshot
if (shouldReport) {
reportGlobalRegSnapshot();
}
}

ctranMapper::ctranMapper(ncclComm* comm) {
this->pimpl = std::unique_ptr<impl>(new impl());

/* mapperRegElemList */
Expand Down Expand Up @@ -142,61 +211,9 @@ ctranMapper::ctranMapper(ncclComm *comm) {
return this->regMem(buf, len, hdl);
});
if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) {
allCommMutex.lock();
allCommHashCtranMapperMap[this->commHash] = this;
}
}

static double sumDurations(std::vector<double>& durs) {
double total = 0;
for (auto& dur : durs) {
total += dur;
}
return total;
}

void reportGlobalRegSnapshot(void) {
// Counts per communicator
for (auto& it : allCommHashCtranMapperMap) {
auto& mapper = it.second;
mapper->reportRegSnapshot();
}

// Timers accumulated from all communicators
if (registerDurs.size()) {
double totalRegisLat = sumDurations(registerDurs);
INFO(
NCCL_INIT,
"CTRAN-MAPPER: [register snapshot] total registration latency across all comms %.2f ms, average %.2f ms across %lu registrations",
totalRegisLat,
totalRegisLat / registerDurs.size(),
registerDurs.size());
}
if (deregisterDurs.size()) {
double totalDeregistLat = sumDurations(deregisterDurs);
INFO(
NCCL_INIT,
"CTRAN-MAPPER: [register snapshot] total deregistration latency across all comms %.2f ms, average %.2f ms across %lu registrations",
totalDeregistLat,
totalDeregistLat / deregisterDurs.size(),
deregisterDurs.size());
}
if (lookupHitDurs.size()) {
double totalLookupHitLat = sumDurations(lookupHitDurs);
INFO(
NCCL_INIT,
"CTRAN-MAPPER: [register snapshot] total hit lookup latency across all comms %.2f ms, average %.2f ms across %lu hits",
totalLookupHitLat,
totalLookupHitLat / lookupHitDurs.size(),
lookupHitDurs.size());
}
if (lookupMissDurs.size()) {
double totalLookupMissLat = sumDurations(lookupMissDurs);
INFO(
NCCL_INIT,
"CTRAN-MAPPER: [register snapshot] total missed lookup latency across all comms %.2f ms, average %.2f ms across %lu misses",
totalLookupMissLat,
totalLookupMissLat / lookupMissDurs.size(),
lookupMissDurs.size());
allCommMutex.unlock();
}
}

Expand Down Expand Up @@ -338,11 +355,18 @@ ctranMapper::~ctranMapper() {
if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) {
// Report summary of this communicator before destroying it
this->reportRegSnapshot();
allCommHashCtranMapperMap.erase(this->commHash);

// Report global counters only after all communicators have been destroyed
if (allCommHashCtranMapperMap.empty()) {
reportGlobalRegSnapshot();
bool lastMapper = false;
allCommMutex.lock();
allCommHashCtranMapperMap.erase(this->commHash);
lastMapper = allCommHashCtranMapperMap.empty();
allCommMutex.unlock();

// Report global counters after all communicators have been destroyed
// Call report after unlock since we will lock again inside
// reportGlobalRegSnapshot
if (lastMapper) {
reportGlobalRegSnapshot();
}
}

Expand Down Expand Up @@ -374,17 +398,11 @@ ncclResult_t ctranMapper::impl::regMem(struct ctranMapperRegElem *mapperRegElem)
if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) {
this->numRegistrations++;
this->totalNumRegistrations++;
registerDurs.push_back(dur.durationMs());
recordRegistDuration(GlobalRegistDurationType::REG_MEM, dur.durationMs());
}

INFO(NCCL_COLL, "CTRAN-MAPPER: register buffer %p len %ld", mapperRegElem->buf, mapperRegElem->len);

// Allow snapshot report during long job running
if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT > 0 &&
registerDurs.size() % NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT == 0) {
reportGlobalRegSnapshot();
}

exit:
return res;
}
Expand All @@ -403,7 +421,7 @@ ncclResult_t ctranMapper::impl::deregMem(struct ctranMapperRegElem *mapperRegEle

if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) {
this->numRegistrations--;
deregisterDurs.push_back(dur.durationMs());
recordRegistDuration(GlobalRegistDurationType::DEREG_MEM, dur.durationMs());
}

INFO(NCCL_COLL, "CTRAN-MAPPER: deregiter buffer %p len %ld", mapperRegElem->buf, mapperRegElem->len);
Expand Down Expand Up @@ -502,10 +520,12 @@ ncclResult_t ctranMapper::searchRegHandle(const void *buf, std::size_t len, void

if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) {
if (lookupHit) {
lookupHitDurs.push_back(dur.durationMs());
recordRegistDuration(
GlobalRegistDurationType::LOOKUP_HIT, dur.durationMs());
this->pimpl->totalNumRegLookupHit++;
} else {
lookupMissDurs.push_back(dur.durationMs());
recordRegistDuration(
GlobalRegistDurationType::LOOKUP_MISS, dur.durationMs());
this->pimpl->totalNumRegLookupMiss++;
if (*dynamicRegist) {
this->pimpl->totalNumDynamicRegistrations++;
Expand Down

0 comments on commit 5bd01dd

Please sign in to comment.