diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index e2223b63b7a1d8..c7f983a849b76b 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -988,8 +988,8 @@ void ProcessGroupNCCL::Restart() { phi::distributed::P2POption p2p_opts = place_to_p2p_opts_.at(place_key); phi::distributed::CommContextManager::RecreateNCCLComm( store_, store_key, rank_, std::to_string(create_count_), &p2p_opts); - create_count_++; } + create_count_++; } void ProcessGroupNCCL::SyncCalcStream(const Place& place, diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index d28a72a631d9bf..283dcba005e3da 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -276,7 +277,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { uint64_t comm_seq_{0}; std::unordered_map p2p_comm_seq_; - std::unordered_map place_to_group_key_; + std::map place_to_group_key_; // TODO(sunyilun): attrs below will be removed later std::mutex mutex_; diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index ed3a0607a8fc45..741a06fa378b24 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -130,7 +130,7 @@ void CommContextManager::CreateNCCLCommContext( void CommContextManager::RecreateNCCLComm(const std::shared_ptr& store, const std::string& unique_comm_key, int rank, - const std::string& hash_key, + const std::string& recreate_key, const P2POption* p2p_opt) { auto& comm_context_manager = CommContextManager::GetInstance(); @@ -139,7 +139,8 @@ void CommContextManager::RecreateNCCLComm(const std::shared_ptr& store, PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id)); } - std::string unique_key = "NCCLCommContext/" + unique_comm_key + hash_key; + std::string unique_key = + "NCCLCommContext/" + unique_comm_key + "/" + recreate_key; if (rank == 0 || (p2p_opt && p2p_opt->is_p2p_op && p2p_opt->p2p_rank == 0)) { std::vector nccl_id_wrapper( reinterpret_cast(&nccl_id),