@@ -59,6 +59,15 @@ bool printEnabled() {
5959 return env != nullptr && std::string (env) == " 1" ;
6060}
6161
62+ size_t getMaxRegMemoryNum () {
63+ static const size_t g_default_max_reg_memory_num = 8192 ;
64+ static char *env = getenv (" ASCEND_TRANSPORT_MAX_REG_MEMORY_NUM" );
65+ if (env != nullptr ) {
66+ return std::stoi (env);
67+ }
68+ return g_default_max_reg_memory_num;
69+ }
70+
6271uint16_t findAvailableTcpPort (int &sockfd, bool use_ipv6) {
6372 static std::random_device rand_gen;
6473 std::mt19937 gen (rand_gen ());
@@ -646,14 +655,26 @@ int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info,
646655 }
647656 }
648657 }
658+
659+ size_t max_m_num = getMaxRegMemoryNum ();
660+ if (m_num >= max_m_num) {
661+ LOG (ERROR) << " The number of registered memory exceeds the expected "
662+ " maximum size "
663+ << max_m_num
664+ << " . To resolve this issue, you can increase the maximum "
665+ " size by setting the environment variable "
666+ " ASCEND_TRANSPORT_MAX_REG_MEMORY_NUM." ;
667+ return -1 ;
668+ }
649669 hccl::TransportMem::RmaMemDescs localRmaMemDescs;
650670 localRmaMemDescs.array = rmaMemDescs.data ();
651671 localRmaMemDescs.arrayLength = rmaMemDescs.size ();
652672 uint32_t actualNumOfRemote = 0 ;
653- std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray (m_num);
673+ std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray (
674+ max_m_num);
654675 hccl::TransportMem::RmaMemDescs remoteRmaMemDescs;
655676 remoteRmaMemDescs.array = remoteRmaMemDescArray.data ();
656- remoteRmaMemDescs.arrayLength = m_num ;
677+ remoteRmaMemDescs.arrayLength = max_m_num ;
657678 ret = transport_mem->ExchangeMemDesc (localRmaMemDescs, remoteRmaMemDescs,
658679 actualNumOfRemote);
659680 if (ret) {
@@ -662,8 +683,9 @@ int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info,
662683 << " , remote_rank: " << remote_rank_info->devicePhyId ;
663684 return ret;
664685 }
665- std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray (m_num);
666- for (uint32_t i = 0 ; i < m_num; ++i) {
686+ std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray (
687+ actualNumOfRemote);
688+ for (uint32_t i = 0 ; i < actualNumOfRemote; ++i) {
667689 ret = transport_mem->EnableMemAccess (remoteRmaMemDescArray[i],
668690 remoteRmaMemArray[i]);
669691 if (ret) {
@@ -1021,14 +1043,26 @@ int transportMemAccept(RankInfo *local_rank_info) {
10211043 }
10221044 }
10231045 }
1046+
1047+ size_t max_m_num = getMaxRegMemoryNum ();
1048+ if (m_num >= max_m_num) {
1049+ LOG (ERROR) << " The number of registered memory exceeds the expected "
1050+ " maximum size "
1051+ << max_m_num
1052+ << " . To resolve this issue, you can increase the maximum "
1053+ " size by setting the environment variable "
1054+ " ASCEND_TRANSPORT_MAX_REG_MEMORY_NUM." ;
1055+ return -1 ;
1056+ }
10241057 hccl::TransportMem::RmaMemDescs localRmaMemDescs;
10251058 localRmaMemDescs.array = rmaMemDescs.data ();
10261059 localRmaMemDescs.arrayLength = rmaMemDescs.size ();
10271060 uint32_t actualNumOfRemote = 0 ;
1028- std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray (m_num);
1061+ std::vector<hccl::TransportMem::RmaMemDesc> remoteRmaMemDescArray (
1062+ max_m_num);
10291063 hccl::TransportMem::RmaMemDescs remoteRmaMemDescs;
10301064 remoteRmaMemDescs.array = remoteRmaMemDescArray.data ();
1031- remoteRmaMemDescs.arrayLength = m_num ;
1065+ remoteRmaMemDescs.arrayLength = max_m_num ;
10321066 ret = transport_mem->ExchangeMemDesc (localRmaMemDescs, remoteRmaMemDescs,
10331067 actualNumOfRemote);
10341068 if (ret) {
@@ -1037,8 +1071,9 @@ int transportMemAccept(RankInfo *local_rank_info) {
10371071 << " , remote_rank: " << remote_control_info.devicePhyId ;
10381072 return ret;
10391073 }
1040- std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray (m_num);
1041- for (uint32_t i = 0 ; i < m_num; ++i) {
1074+ std::vector<hccl::TransportMem::RmaMem> remoteRmaMemArray (
1075+ actualNumOfRemote);
1076+ for (uint32_t i = 0 ; i < actualNumOfRemote; ++i) {
10421077 ret = transport_mem->EnableMemAccess (remoteRmaMemDescArray[i],
10431078 remoteRmaMemArray[i]);
10441079 if (ret) {
@@ -1063,4 +1098,4 @@ int regLocalRmaMem(void *addr, uint64_t length) {
10631098
10641099#ifdef __cplusplus
10651100}
1066- #endif // __cplusplus
1101+ #endif // __cplusplus
0 commit comments