
Commit

Merge branch 'incubate/new_frl' of https://github.com/PaddlePaddle/Paddle into infer_merge_to_train
wufeisheng committed Nov 27, 2023
2 parents 8465edf + c940ea7 commit 446cfd2
Showing 18 changed files with 601 additions and 78 deletions.
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/collective/process_group.cc
@@ -28,6 +28,12 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
auto map = ProcessGroupMapFromGid::getInstance();
map->insert(gid_, this);
}
const char* global_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
global_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
global_rank_ = std::atoi(global_rank);
}

// TODO(sunyilun): methods below will be removed later
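The constructor now derives the process's global rank from the PADDLE_TRAINER_ID environment variable that the distributed launcher sets for each trainer process. A minimal sketch of that lookup, with plain C++ error handling standing in for PADDLE_ENFORCE_NOT_NULL and phi::errors (the helper name ReadGlobalRank is illustrative):

#include <cstdlib>    // std::getenv, std::atoi
#include <stdexcept>  // std::runtime_error

// Read this trainer's global rank from the environment, failing loudly when
// the variable is missing, as the enforcement above does.
int ReadGlobalRank() {
  const char* global_rank = std::getenv("PADDLE_TRAINER_ID");
  if (global_rank == nullptr) {
    throw std::runtime_error(
        "The environment variable 'PADDLE_TRAINER_ID' cannot be found.");
  }
  return std::atoi(global_rank);
}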
1 change: 1 addition & 0 deletions paddle/fluid/distributed/collective/process_group.h
@@ -484,6 +484,7 @@ class ProcessGroup {
}

protected:
int global_rank_{-1};
int rank_;
int size_;
int gid_;
64 changes: 60 additions & 4 deletions paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -21,6 +21,7 @@
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/comm_task_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_task.h"
#include "paddle/phi/core/distributed/nccl_tools.h"
@@ -819,6 +820,7 @@ void ProcessGroupNCCL::BroadcastUniqueNCCLID(ncclUniqueId* nccl_id,
const auto& nccl_id_wrapper = store_->get(store_key);
std::memcpy(nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
}
place_to_group_key_[p2p_key] = store_key;
}

void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
@@ -860,6 +862,48 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
comm_ctx->set_nccl_comm(nccl_comm);

if (FLAGS_enable_async_trace) {
// gather global ranks in current group
int* gpu_global_rank = nullptr;
size_t gpu_global_rank_size = sizeof(int);
CUDA_CHECK(cudaMalloc(&gpu_global_rank, gpu_global_rank_size));

CUDA_CHECK(cudaMemcpy(gpu_global_rank,
&global_rank_,
gpu_global_rank_size,
cudaMemcpyHostToDevice));

int* gpu_global_ranks = nullptr;
size_t gpu_global_ranks_size = num_ranks * sizeof(int);
CUDA_CHECK(cudaMalloc(&gpu_global_ranks, gpu_global_ranks_size));

NCCL_CHECK(phi::dynload::ncclAllGather(gpu_global_rank,
gpu_global_ranks,
1,
ncclInt,
nccl_comm,
comm_ctx->stream()));

std::vector<int> global_ranks(num_ranks);
CUDA_CHECK(cudaMemcpy(global_ranks.data(),
gpu_global_ranks,
gpu_global_ranks_size,
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaFree(gpu_global_rank));
CUDA_CHECK(cudaFree(gpu_global_ranks));

// store global_ranks in current group_key
std::once_flag flag;
std::call_once(flag, [this]() {
phi::distributed::CommContextManager::GetInstance().SetStore(store_);
phi::distributed::CommTaskManager::GetInstance().SetTimeout(pg_timeout_);
});

std::string group_key = place_to_group_key_.at(place_key);
phi::distributed::CommContextManager::GetInstance().AddGroupRanks(
group_key, global_ranks);
}

auto* calc_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
place_to_calc_event_.emplace(
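The FLAGS_enable_async_trace branch above records which global ranks make up the new communicator by all-gathering each rank's global id over the freshly created ncclComm_t. A condensed sketch of that exchange, assuming an already-initialized communicator and stream; error checking via CUDA_CHECK/NCCL_CHECK is omitted and the helper name GatherGlobalRanks is illustrative:

#include <vector>
#include <cuda_runtime.h>
#include <nccl.h>

// Gather every participant's global rank over an existing NCCL communicator.
// Each rank contributes one int; the result is ordered by communicator rank.
std::vector<int> GatherGlobalRanks(int my_global_rank, int num_ranks,
                                   ncclComm_t comm, cudaStream_t stream) {
  int* send_buf = nullptr;
  int* recv_buf = nullptr;
  cudaMalloc(&send_buf, sizeof(int));
  cudaMalloc(&recv_buf, num_ranks * sizeof(int));
  cudaMemcpy(send_buf, &my_global_rank, sizeof(int), cudaMemcpyHostToDevice);

  ncclAllGather(send_buf, recv_buf, 1, ncclInt, comm, stream);
  cudaStreamSynchronize(stream);  // make the gathered ints visible to the host

  std::vector<int> global_ranks(num_ranks);
  cudaMemcpy(global_ranks.data(), recv_buf, num_ranks * sizeof(int),
             cudaMemcpyDeviceToHost);
  cudaFree(send_buf);
  cudaFree(recv_buf);
  return global_ranks;
}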
@@ -913,8 +957,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if (!FLAGS_enable_async_trace) {
fn(nccl_comm, nccl_stream);
} else {
std::string group_key = place_to_group_key_.at(key);
auto comm_task =
std::make_shared<phi::distributed::NCCLCommTask>(place,
group_key,
rank_,
size_,
gid_,
@@ -973,22 +1019,29 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Point2Point(
bool is_batch_p2p = s_group_call_counter > 0;
std::string key = "";

int p2p_nrank = 0;
if (is_batch_p2p) {
key = GetKeyFromPlace(place);
p2p_rank = rank_;
p2p_target_rank = peer;
p2p_nrank = GetSize();
} else {
int low_rank = rank_ < peer ? rank_ : peer;
int high_rank = rank_ < peer ? peer : rank_;
key = std::to_string(low_rank) + "->" + std::to_string(high_rank);
p2p_rank = rank_ < peer ? 0 : 1;
p2p_target_rank = 1 - p2p_rank;
p2p_nrank = 2;
}

platform::CUDADeviceGuard cuda_guard(place);
if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
CreateNCCLEnvCache(place, key, comm_type, p2p_rank);
}
if (p2p_comm_seq_.find(key) == p2p_comm_seq_.end()) {
p2p_comm_seq_[key] = 0;
}
p2p_comm_seq_[key]++;

if (!use_calc_stream) {
SyncCalcStream(place, key);
@@ -1002,18 +1055,21 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Point2Point(
auto nccl_comm = comm_ctx->nccl_comm();
auto nccl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream();

std::string group_key = place_to_group_key_.at(key);
auto comm_task =
std::make_shared<phi::distributed::NCCLCommTask>(place,
rank_,
size_,
group_key,
p2p_rank,
p2p_nrank,
gid_,
comm_seq_,
p2p_comm_seq_[key],
tensor.numel(),
sync_op,
use_calc_stream,
nccl_comm,
nccl_stream,
comm_type);
comm_type,
pg_timeout_);

if (!FLAGS_enable_async_trace) {
fn(nccl_comm, nccl_stream, p2p_target_rank);
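For the non-batched point-to-point path above, both endpoints must derive the same per-pair key and agree on their position inside the dedicated two-rank communicator. A small sketch of that derivation; the P2PInfo and MakeP2PInfo names are illustrative, not part of the patch:

#include <string>

// Both sides of a send/recv pair derive identical bookkeeping from the two
// global ranks involved, so they look up and cache the same communicator.
struct P2PInfo {
  std::string key;      // "low->high", e.g. "3->7" for global ranks 3 and 7
  int p2p_rank;         // 0 for the lower global rank, 1 for the higher
  int p2p_target_rank;  // the other participant inside the pair
  int p2p_nrank;        // always 2 for a dedicated send/recv pair
};

P2PInfo MakeP2PInfo(int rank, int peer) {
  const int low = rank < peer ? rank : peer;
  const int high = rank < peer ? peer : rank;
  P2PInfo info;
  info.key = std::to_string(low) + "->" + std::to_string(high);
  info.p2p_rank = rank < peer ? 0 : 1;
  info.p2p_target_rank = 1 - info.p2p_rank;
  info.p2p_nrank = 2;
  return info;
}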
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/collective/process_group_nccl.h
@@ -234,6 +234,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
place_to_comm_ctx_;

uint64_t comm_seq_{0};
std::unordered_map<std::string, uint64_t> p2p_comm_seq_;
std::unordered_map<std::string, std::string> place_to_group_key_;

// TODO(sunyilun): attrs below will be removed later
std::mutex mutex_;
2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -69,6 +69,8 @@ message PpConfig {

message DygraphShardingConfig {
optional bool split_param = 1 [ default = false ];
optional bool comm_overlap = 2 [ default = false ];
optional int32 accumulate_steps = 3 [ default = -1 ];
}


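The two new sharding knobs are ordinary proto2 optional fields, so they can be set through the generated message class like the existing split_param option. A hypothetical snippet; the generated header path and C++ namespace are assumed from the .proto location and package, which this hunk does not show:

// Hypothetical use of the generated proto2 message.
#include "paddle/fluid/framework/distributed_strategy.pb.h"

void ConfigureSharding() {
  paddle::distributed::DygraphShardingConfig sharding_cfg;
  sharding_cfg.set_split_param(true);
  sharding_cfg.set_comm_overlap(true);   // overlap sharding gradient communication
  sharding_cfg.set_accumulate_steps(8);  // default -1 means "not set"
}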
20 changes: 20 additions & 0 deletions paddle/phi/core/distributed/comm_context_manager.cc
@@ -109,5 +109,25 @@ bool CommContextManager::Has(int ring_id) const {
return id_to_comm_context_.find(ring_id) != id_to_comm_context_.end();
}

void CommContextManager::SetGroupSize(const std::string& pg_key, int size) {
pg_key_size_[pg_key] = size;
}

void CommContextManager::AddGroupRanks(const std::string& pg_key,
std::vector<int> global_ranks) {
if (pg_key_ranks_.find(pg_key) == pg_key_ranks_.end()) {
pg_key_ranks_[pg_key] = global_ranks;
}
}

std::vector<int> CommContextManager::GetGroupRanks(
const std::string& pg_key) const {
PADDLE_ENFORCE_NE(
pg_key_ranks_.find(pg_key),
pg_key_ranks_.end(),
errors::NotFound("Cannot find pg_key %s in GroupRanks.", pg_key));
return pg_key_ranks_.at(pg_key);
}

} // namespace distributed
} // namespace phi
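Together, the new methods give CommContextManager a per-group view of the participating global ranks, keyed by the process-group key registered by the NCCL group above. An illustrative usage sketch; the key string is made up:

#include <vector>
#include "paddle/phi/core/distributed/comm_context_manager.h"

void Example() {
  // Registering twice is harmless: AddGroupRanks only inserts a missing key.
  phi::distributed::CommContextManager::GetInstance().AddGroupRanks(
      "pg_key_0", {0, 1, 2, 3});
  std::vector<int> ranks =
      phi::distributed::CommContextManager::GetInstance().GetGroupRanks(
          "pg_key_0");  // {0, 1, 2, 3}
}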
12 changes: 12 additions & 0 deletions paddle/phi/core/distributed/comm_context_manager.h
@@ -16,6 +16,7 @@

#include <iostream>
#include <memory>
#include <set>
#include <unordered_map>

#include "paddle/phi/core/distributed/comm_context.h"
@@ -44,6 +45,12 @@ class CommContextManager {

bool Has(int ring_id) const;

void SetGroupSize(const std::string& pg_key, int size);

void AddGroupRanks(const std::string& pg_key, std::vector<int> global_ranks);

std::vector<int> GetGroupRanks(const std::string& pg_key) const;

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
int dev_id,
@@ -64,6 +71,11 @@

std::unordered_map<int, std::unique_ptr<CommContext>> id_to_comm_context_;
std::shared_ptr<Store> store_;

// process group key to global ranks map
std::unordered_map<std::string, std::vector<int>> pg_key_ranks_;
// process group key to group size map
std::unordered_map<std::string, int> pg_key_size_;
};

} // namespace distributed
26 changes: 25 additions & 1 deletion paddle/phi/core/distributed/comm_task.h
@@ -37,6 +37,7 @@ class CommTask {
public:
CommTask(const std::string& backend = "",
const phi::Place& place = phi::Place(),
const std::string& group_key = "",
int rank = -1,
int size = 0,
int gid = 0,
@@ -47,6 +48,7 @@
CommType comm_type = CommType::UNKNOWN)
: backend_(backend),
place_(place),
group_key_(group_key),
rank_(rank),
size_(size),
gid_(gid),
@@ -65,9 +67,10 @@
virtual ~CommTask() = default;

std::string UniqueKey() {
return "op:" + CommTypeToString(comm_type_) +
return "group_key:" + group_key_ + ",op:" + CommTypeToString(comm_type_) +
",gid:" + std::to_string(gid_) + ",seq:" + std::to_string(seq_);
}
std::string GroupKey() { return group_key_; }
std::string GetBackend() { return backend_; }
phi::Place GetPlace() { return place_; }
int GetGlobalRank() { return global_rank_; }
@@ -104,6 +107,12 @@ class CommTask {
return;
}

virtual void ClearRecord() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}

virtual std::string GetCommErrors() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
@@ -124,6 +133,16 @@
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual void SetUpdated(bool updated) {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return;
}
virtual bool IsUpdated() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
return false;
}
virtual void AbortComm() {
PADDLE_THROW(
phi::errors::Unimplemented("%s is not implemented.", __func__));
@@ -133,6 +152,7 @@
protected:
std::string backend_;
phi::Place place_;
std::string group_key_;
int global_rank_;
int rank_;
int size_;
@@ -144,7 +164,11 @@
CommType comm_type_;
bool start_trace_updated_{false};

// task status
bool started_ = false;
bool completed_ = false;
// task status changed
bool updated_ = true;
bool aborted_{false};
std::chrono::time_point<std::chrono::steady_clock> start_time_;
std::shared_ptr<Store> store_;
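With group_key_ stored on the task, the identifier returned by UniqueKey() now says which process group an in-flight operation belongs to. A sketch of the resulting string, assuming CommTypeToString yields names such as "SEND"; the helper name below is illustrative:

#include <cstdint>
#include <string>

// Mirrors the string CommTask::UniqueKey() builds after this change.
std::string BuildUniqueKey(const std::string& group_key,
                           const std::string& comm_type, int gid, uint64_t seq) {
  return "group_key:" + group_key + ",op:" + comm_type +
         ",gid:" + std::to_string(gid) + ",seq:" + std::to_string(seq);
}
// BuildUniqueKey("2->5", "SEND", 3, 42)
//   == "group_key:2->5,op:SEND,gid:3,seq:42"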
(The remaining changed files are not shown.)
