Skip to content

Commit d3f2da1

Browse files
authored
[TransferEngine] Ascend supports asymmetric amount of registered memory (#758)
* [TransferEngine] Supports asymmetric amount of registered memory for ascend * format code * add env and check * refine
1 parent 4e055f3 commit d3f2da1

File tree

1 file changed

+44
-9
lines changed

1 file changed

+44
-9
lines changed

mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ bool printEnabled() {
5959
return env != nullptr && std::string(env) == "1";
6060
}
6161

62+
/**
 * @brief Returns the maximum number of registered memory regions to expect.
 *
 * The limit is taken from the environment variable
 * ASCEND_TRANSPORT_MAX_REG_MEMORY_NUM when it holds a positive decimal
 * integer that fits in size_t; otherwise the built-in default (8192) is used.
 *
 * The variable is read and parsed once, on the first call; later changes to
 * the environment are not observed.
 *
 * Unlike a bare std::stoi() call, this never throws: an empty value, a value
 * containing any non-digit character (sign, whitespace, trailing text), a
 * value of zero, or a value that would overflow size_t all fall back to the
 * default instead of aborting the process or wrapping to a huge size.
 *
 * @return The configured limit, or 8192 when the variable is unset or invalid.
 */
size_t getMaxRegMemoryNum() {
    static const size_t g_default_max_reg_memory_num = 8192;

    // Function-local static: the lambda runs exactly once, so the environment
    // is consulted and parsed a single time for the whole process.
    static const size_t g_max_reg_memory_num = []() -> size_t {
        const char *env = getenv("ASCEND_TRANSPORT_MAX_REG_MEMORY_NUM");
        if (env == nullptr || *env == '\0') {
            return g_default_max_reg_memory_num;
        }

        const size_t size_max = static_cast<size_t>(-1);
        size_t value = 0;
        for (const char *p = env; *p != '\0'; ++p) {
            // Digits only: rejects "-1" (which would otherwise wrap to a
            // huge unsigned value) as well as any other malformed text.
            if (*p < '0' || *p > '9') {
                return g_default_max_reg_memory_num;
            }
            const size_t digit = static_cast<size_t>(*p - '0');
            // Overflow guard for value * 10 + digit.
            if (value > (size_max - digit) / 10) {
                return g_default_max_reg_memory_num;
            }
            value = value * 10 + digit;
        }

        // A limit of zero can never be satisfied by callers that compare a
        // count against it, so treat it as a misconfiguration.
        return value != 0 ? value : g_default_max_reg_memory_num;
    }();

    return g_max_reg_memory_num;
}
70+
6271
uint16_t findAvailableTcpPort(int &sockfd, bool use_ipv6) {
6372
static std::random_device rand_gen;
6473
std::mt19937 gen(rand_gen());
@@ -646,14 +655,26 @@ int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info,
646655
}
647656
}
648657
}
658+
659+
size_t max_m_num = getMaxRegMemoryNum();
660+
if (m_num >= max_m_num) {
661+
LOG(ERROR) << "The number of registered memory exceeds the expected "
662+
"maximum size "
663+
<< max_m_num
664+
<< ". To resolve this issue, you can increase the maximum "
665+
"size by setting the environment variable "
666+
"ASCEND_TRANSPORT_MAX_REG_MEMORY_NUM.";
667+
return -1;
668+
}
649669
hccl::TransportMem::RmaMemDescs localRmaMemDescs;
650670
localRmaMemDescs.array = rmaMemDescs.data();
651671
localRmaMemDescs.arrayLength = rmaMemDescs.size();
652672
uint32_t actualNumOfRemote = 0;
653-
std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray(m_num);
673+
std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray(
674+
max_m_num);
654675
hccl::TransportMem::RmaMemDescs remoteRmaMemDescs;
655676
remoteRmaMemDescs.array = remoteRmaMemDescArray.data();
656-
remoteRmaMemDescs.arrayLength = m_num;
677+
remoteRmaMemDescs.arrayLength = max_m_num;
657678
ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs,
658679
actualNumOfRemote);
659680
if (ret) {
@@ -662,8 +683,9 @@ int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info,
662683
<< ", remote_rank: " << remote_rank_info->devicePhyId;
663684
return ret;
664685
}
665-
std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray(m_num);
666-
for (uint32_t i = 0; i < m_num; ++i) {
686+
std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray(
687+
actualNumOfRemote);
688+
for (uint32_t i = 0; i < actualNumOfRemote; ++i) {
667689
ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i],
668690
remoteRmaMemArray[i]);
669691
if (ret) {
@@ -1021,14 +1043,26 @@ int transportMemAccept(RankInfo *local_rank_info) {
10211043
}
10221044
}
10231045
}
1046+
1047+
size_t max_m_num = getMaxRegMemoryNum();
1048+
if (m_num >= max_m_num) {
1049+
LOG(ERROR) << "The number of registered memory exceeds the expected "
1050+
"maximum size "
1051+
<< max_m_num
1052+
<< ". To resolve this issue, you can increase the maximum "
1053+
"size by setting the environment variable "
1054+
"ASCEND_TRANSPORT_MAX_REG_MEMORY_NUM.";
1055+
return -1;
1056+
}
10241057
hccl::TransportMem::RmaMemDescs localRmaMemDescs;
10251058
localRmaMemDescs.array = rmaMemDescs.data();
10261059
localRmaMemDescs.arrayLength = rmaMemDescs.size();
10271060
uint32_t actualNumOfRemote = 0;
1028-
std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray(m_num);
1061+
std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray(
1062+
max_m_num);
10291063
hccl::TransportMem::RmaMemDescs remoteRmaMemDescs;
10301064
remoteRmaMemDescs.array = remoteRmaMemDescArray.data();
1031-
remoteRmaMemDescs.arrayLength = m_num;
1065+
remoteRmaMemDescs.arrayLength = max_m_num;
10321066
ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs,
10331067
actualNumOfRemote);
10341068
if (ret) {
@@ -1037,8 +1071,9 @@ int transportMemAccept(RankInfo *local_rank_info) {
10371071
<< ", remote_rank: " << remote_control_info.devicePhyId;
10381072
return ret;
10391073
}
1040-
std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray(m_num);
1041-
for (uint32_t i = 0; i < m_num; ++i) {
1074+
std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray(
1075+
actualNumOfRemote);
1076+
for (uint32_t i = 0; i < actualNumOfRemote; ++i) {
10421077
ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i],
10431078
remoteRmaMemArray[i]);
10441079
if (ret) {
@@ -1063,4 +1098,4 @@ int regLocalRmaMem(void *addr, uint64_t length) {
10631098

10641099
#ifdef __cplusplus
10651100
}
1066-
#endif // __cplusplus
1101+
#endif // __cplusplus

0 commit comments

Comments
 (0)