diff --git a/src/ctran/mapper/ctranMapper.cc b/src/ctran/mapper/ctranMapper.cc index b9996939d..e2fa9becd 100644 --- a/src/ctran/mapper/ctranMapper.cc +++ b/src/ctran/mapper/ctranMapper.cc @@ -8,12 +8,7 @@ #include "ctranMapperImpl.h" #include "comm.h" #include "nccl_cvars.h" - -static std::vector registerDurs; -static std::vector deregisterDurs; -static std::vector lookupHitDurs; -static std::vector lookupMissDurs; -static std::unordered_map allCommHashCtranMapperMap; +#include /* === BEGIN_NCCL_CVAR_INFO_BLOCK === @@ -71,7 +66,81 @@ static std::unordered_map allCommHashCtranMapperMap; === END_NCCL_CVAR_INFO_BLOCK === */ -ctranMapper::ctranMapper(ncclComm *comm) { +enum GlobalRegistDurationType { REG_MEM, DEREG_MEM, LOOKUP_HIT, LOOKUP_MISS }; + +static std::unordered_map + globalRegistDurationTypeNameMap = { + {REG_MEM, "registration"}, + {DEREG_MEM, "deregistration"}, + {LOOKUP_HIT, "lookup-hit"}, + {LOOKUP_MISS, "lookup-miss"}, +}; +static std::unordered_map allCommHashCtranMapperMap; +static std::unordered_map> + allCommRegistDurationsMap; +static std::mutex allCommMutex; + +static double sumDurations(std::vector& durs) { + double total = 0; + for (auto& dur : durs) { + total += dur; + } + return total; +} + +static void reportGlobalRegSnapshot(void) { + const std::lock_guard lock(allCommMutex); + + // Counts per communicator + for (auto& it : allCommHashCtranMapperMap) { + auto& mapper = it.second; + mapper->reportRegSnapshot(); + } + + // Timers accumulated from all communicators + for (auto& it : allCommRegistDurationsMap) { + auto& key = it.first; + auto& durs = it.second; + size_t numDurs = durs.size(); + if (numDurs) { + double totalLat = sumDurations(durs); + INFO( + NCCL_INIT, + "CTRAN-MAPPER: [register snapshot] total %s latency across all comms %.2f ms, average %.2f ms across %lu %s", + globalRegistDurationTypeNameMap[key].c_str(), + totalLat, + totalLat / numDurs, + numDurs, + globalRegistDurationTypeNameMap[key].c_str()); + } + } +} + +static void recordRegistDuration( + GlobalRegistDurationType key, + double duration) { + allCommMutex.lock(); + allCommRegistDurationsMap[key].push_back(duration); + + // Allow periodical snapshot report during long job running + bool shouldReport = false; + if (key == GlobalRegistDurationType::REG_MEM && + NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT > 0 && + (allCommRegistDurationsMap[key].size() % + NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT == + 0)) { + shouldReport = true; + } + allCommMutex.unlock(); + + // Call report after unlock since we will lock again inside + // reportGlobalRegSnapshot + if (shouldReport) { + reportGlobalRegSnapshot(); + } +} + +ctranMapper::ctranMapper(ncclComm* comm) { this->pimpl = std::unique_ptr(new impl()); /* mapperRegElemList */ @@ -142,61 +211,9 @@ ctranMapper::ctranMapper(ncclComm *comm) { return this->regMem(buf, len, hdl); }); if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) { + allCommMutex.lock(); allCommHashCtranMapperMap[this->commHash] = this; - } -} - -static double sumDurations(std::vector& durs) { - double total = 0; - for (auto& dur : durs) { - total += dur; - } - return total; -} - -void reportGlobalRegSnapshot(void) { - // Counts per communicator - for (auto& it : allCommHashCtranMapperMap) { - auto& mapper = it.second; - mapper->reportRegSnapshot(); - } - - // Timers accumulated from all communicators - if (registerDurs.size()) { - double totalRegisLat = sumDurations(registerDurs); - INFO( - NCCL_INIT, - "CTRAN-MAPPER: [register snapshot] total registration latency across all comms %.2f ms, average %.2f ms across %lu registrations", - totalRegisLat, - totalRegisLat / registerDurs.size(), - registerDurs.size()); - } - if (deregisterDurs.size()) { - double totalDeregistLat = sumDurations(deregisterDurs); - INFO( - NCCL_INIT, - "CTRAN-MAPPER: [register snapshot] total deregistration latency across all comms %.2f ms, average %.2f ms across %lu registrations", - totalDeregistLat, - totalDeregistLat / deregisterDurs.size(), - deregisterDurs.size()); - } - if (lookupHitDurs.size()) { - double totalLookupHitLat = sumDurations(lookupHitDurs); - INFO( - NCCL_INIT, - "CTRAN-MAPPER: [register snapshot] total hit lookup latency across all comms %.2f ms, average %.2f ms across %lu hits", - totalLookupHitLat, - totalLookupHitLat / lookupHitDurs.size(), - lookupHitDurs.size()); - } - if (lookupMissDurs.size()) { - double totalLookupMissLat = sumDurations(lookupMissDurs); - INFO( - NCCL_INIT, - "CTRAN-MAPPER: [register snapshot] total missed lookup latency across all comms %.2f ms, average %.2f ms across %lu misses", - totalLookupMissLat, - totalLookupMissLat / lookupMissDurs.size(), - lookupMissDurs.size()); + allCommMutex.unlock(); } } @@ -338,11 +355,18 @@ ctranMapper::~ctranMapper() { if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) { // Report summary of this communicator before destroying it this->reportRegSnapshot(); - allCommHashCtranMapperMap.erase(this->commHash); - // Report global counters only after all communicators have been destroyed - if (allCommHashCtranMapperMap.empty()) { - reportGlobalRegSnapshot(); + bool lastMapper = false; + allCommMutex.lock(); + allCommHashCtranMapperMap.erase(this->commHash); + lastMapper = allCommHashCtranMapperMap.empty(); + allCommMutex.unlock(); + + // Report global counters after all communicators have been destroyed + // Call report after unlock since we will lock again inside + // reportGlobalRegSnapshot + if (lastMapper) { + reportGlobalRegSnapshot(); } } @@ -374,17 +398,11 @@ ncclResult_t ctranMapper::impl::regMem(struct ctranMapperRegElem *mapperRegElem) if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) { this->numRegistrations++; this->totalNumRegistrations++; - registerDurs.push_back(dur.durationMs()); + recordRegistDuration(GlobalRegistDurationType::REG_MEM, dur.durationMs()); } INFO(NCCL_COLL, "CTRAN-MAPPER: register buffer %p len %ld", mapperRegElem->buf, mapperRegElem->len); - // Allow snapshot report during long job running - if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT > 0 && - registerDurs.size() % NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT == 0) { - reportGlobalRegSnapshot(); - } - exit: return res; } @@ -403,7 +421,7 @@ ncclResult_t ctranMapper::impl::deregMem(struct ctranMapperRegElem *mapperRegEle if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) { this->numRegistrations--; - deregisterDurs.push_back(dur.durationMs()); + recordRegistDuration(GlobalRegistDurationType::DEREG_MEM, dur.durationMs()); } INFO(NCCL_COLL, "CTRAN-MAPPER: deregiter buffer %p len %ld", mapperRegElem->buf, mapperRegElem->len); @@ -502,10 +520,12 @@ ncclResult_t ctranMapper::searchRegHandle(const void *buf, std::size_t len, void if (NCCL_CTRAN_REGISTER_REPORT_SNAPSHOT_COUNT >= 0) { if (lookupHit) { - lookupHitDurs.push_back(dur.durationMs()); + recordRegistDuration( + GlobalRegistDurationType::LOOKUP_HIT, dur.durationMs()); this->pimpl->totalNumRegLookupHit++; } else { - lookupMissDurs.push_back(dur.durationMs()); + recordRegistDuration( + GlobalRegistDurationType::LOOKUP_MISS, dur.durationMs()); this->pimpl->totalNumRegLookupMiss++; if (*dynamicRegist) { this->pimpl->totalNumDynamicRegistrations++;