[fleet executor] Comm init for dist model inf #39012
Merged
@@ -15,7 +15,9 @@
 #include <glog/logging.h>

 #include "paddle/fluid/distributed/fleet_executor/dist_model.h"
+#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -37,24 +39,173 @@ bool IsPersistable(const framework::VarDesc *var) {

 bool DistModel::Init() {
   /* TODO(fleet exe dev): implement this funct */
-  place_ = paddle::platform::CUDAPlace(config_.device_id);
-  if (!PrepareScope()) {
-    return false;
-  }
-  if (!PrepareProgram()) {
-    return false;
-  }
+  bool init_method = (!config_.model_dir.empty() || config_.program_desc);
+  PADDLE_ENFORCE_EQ(init_method, true,
+                    platform::errors::InvalidArgument(
+                        "One of model dir or program desc must be provided to "
+                        "dist model inference."));
+  if (config_.program_desc) {
+    PADDLE_ENFORCE_NOT_NULL(
+        config_.scope, platform::errors::InvalidArgument(
+                           "Scope must be provided to dist model inference if "
+                           "program desc has been provided."));
+  }
+  if (!PreparePlace()) {
+    return false;
+  }
+  if (!config_.program_desc) {
+    if (config_.scope) {
+      LOG(WARNING) << "The provided scope will be ignored if model dir has "
+                      "also been provided.";
+    }
+    if (!PrepareScope()) {
+      return false;
+    }
+    if (!PrepareProgram()) {
+      return false;
+    }
+  } else {
+    program_.reset(config_.program_desc);
+    scope_.reset(config_.scope);
+  }
+  if (!CommInit()) {
+    return false;
+  }
   return true;
 }
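As a side note, here is a minimal sketch (not part of this PR) of how a caller might exercise the two Init() paths above. The DistModelConfig field names (model_dir, program_desc, scope, place, device_id) are taken from this diff; the constructor call, paths, and values are assumptions for illustration only.

#include "paddle/fluid/distributed/fleet_executor/dist_model.h"

using paddle::distributed::DistModel;
using paddle::distributed::DistModelConfig;

int main() {
  DistModelConfig config;
  // Path 1: provide a model directory; Init() then builds the scope and
  // program itself via PrepareScope()/PrepareProgram().
  config.model_dir = "./saved_model";  // hypothetical path
  // Path 2 (alternative): hand over an existing program and scope. Init()
  // enforces that a scope is provided whenever program_desc is set.
  // config.program_desc = &my_program;  // framework::ProgramDesc *
  // config.scope = &my_scope;           // framework::Scope *
  config.place = "GPU";  // or "CPU"; see PreparePlace() below
  config.device_id = 0;
  DistModel model(config);  // assumed constructor signature
  return model.Init() ? 0 : 1;
}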

+bool DistModel::PreparePlace() {
+  if (config_.place == "GPU") {
+    place_ = paddle::platform::CUDAPlace(config_.device_id);
+  } else if (config_.place == "CPU") {
+    place_ = paddle::platform::CPUPlace();
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Place must be chosen from GPU or CPU, but got %s.", config_.place));
+  }
+  return true;
+}

 bool DistModel::CommInit() {
-  // TODO(fleet executor): init the comm
+  // NOTE (Yuang Liu): The peer endpoints will be obtained with the assumption
+  // that the mp part is always on the inner side and the pp part is always on
+  // the outer side.
+  // TODO(fleet exe dev): The peer endpoints could be configured by users.
+  PADDLE_ENFORCE_EQ(
+      config_.pp_degree * config_.mp_degree, config_.nranks,
+      platform::errors::InvalidArgument(
+          "The product of mp_degree and pp_degree is not equal to nranks."));
+  std::unique_ptr<framework::ProgramDesc> comm_init_program(
+      new framework::ProgramDesc());
+  framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0);
+  if (config_.mp_degree > 1) {
+    PADDLE_ENFORCE_GE(
+        config_.mp_ring_id, 0,
+        platform::errors::InvalidArgument(
+            "mp ring id must be provided for inference under mp."));
+    VLOG(3) << "Init comm group for mp.";
+    std::vector<std::string> peer_endpoints;
+    for (int64_t
+             idx = (config_.local_rank / config_.mp_degree) * config_.mp_degree,
+             i = 0;
+         i < config_.mp_degree; ++idx, ++i) {
+      if (config_.trainer_endpoints[idx] == config_.current_endpoint) {
+        continue;
+      }
+      peer_endpoints.emplace_back(config_.trainer_endpoints[idx]);
+    }
+    // get the nranks of the mp group and the in-group rank of the local rank
+    int64_t mp_group_nranks = config_.nranks / config_.pp_degree;
+    int64_t mp_group_rank = config_.local_rank % config_.mp_degree;
+    InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
+                 comm_init_block, config_.mp_ring_id);
+  }
+  if (config_.pp_degree) {
+    // NOTE: the last pp stage doesn't need to init a pp comm
+    VLOG(3) << "Init comm group for pp.";
+    if (config_.local_rank - config_.mp_degree >= 0) {
+      PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
+                        platform::errors::InvalidArgument(
+                            "pp upstream ring id must be provided for "
+                            "non-first pp stage if inference under pp."));
+      // not the first pp stage, so it has an upstream
+      std::vector<std::string> upstream_peer_endpoints;
+      upstream_peer_endpoints.emplace_back(
+          config_.trainer_endpoints[config_.local_rank - config_.mp_degree]);
+      InsertCommOp("pp_upstream_comm_id", 2, 1, upstream_peer_endpoints,
+                   comm_init_block, config_.pp_upstream_ring_id);
+    }
+
+    if (config_.local_rank + config_.mp_degree < config_.nranks) {
+      PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
+                        platform::errors::InvalidArgument(
+                            "pp downstream ring id must be provided for "
+                            "non-last pp stage if inference under pp."));
+      // not the last pp stage, so it has a downstream
+      std::vector<std::string> downstream_peer_endpoints;
+      downstream_peer_endpoints.emplace_back(
+          config_.trainer_endpoints[config_.local_rank + config_.mp_degree]);
+      InsertCommOp("pp_downstream_comm_id", 2, 0, downstream_peer_endpoints,
+                   comm_init_block, config_.pp_downstream_ring_id);
+    }
+  }
+  framework::NaiveExecutor e(place_);
+  e.CreateVariables(*comm_init_program, 0, true, scope_.get());
+  e.Prepare(scope_.get(), *comm_init_program, 0, false);
+  e.Run();
+  VLOG(3) << "Comm init successful.";
+  return true;
 }

Review comment on "framework::NaiveExecutor e(place_);":
Actually there is no need to run these ops through an executor; calling the API directly would also work. But this approach causes no problems either.

Reply:
This way is more concise, and if other ops are needed later they can be added right here.
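To make the peer indexing concrete, here is a standalone worked example (plain C++, no Paddle dependency; the endpoints and degrees are made up) that mirrors the arithmetic above for nranks = 4, mp_degree = 2 (inner), pp_degree = 2 (outer), as seen from local_rank = 2.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> eps = {"127.0.0.1:6170", "127.0.0.1:6171",
                                  "127.0.0.1:6172", "127.0.0.1:6173"};
  int64_t mp_degree = 2, local_rank = 2;  // second pp stage, first mp slot
  std::string current = eps[local_rank];

  // mp peers: the contiguous block of mp_degree ranks containing local_rank.
  for (int64_t idx = (local_rank / mp_degree) * mp_degree, i = 0;
       i < mp_degree; ++idx, ++i) {
    if (eps[idx] != current) std::cout << "mp peer: " << eps[idx] << "\n";
  }
  // pp peers: the same mp slot one stage away, i.e. local_rank +- mp_degree.
  if (local_rank - mp_degree >= 0)
    std::cout << "pp upstream peer: " << eps[local_rank - mp_degree] << "\n";
  if (local_rank + mp_degree < static_cast<int64_t>(eps.size()))
    std::cout << "pp downstream peer: " << eps[local_rank + mp_degree] << "\n";
  // Output: mp peer: 127.0.0.1:6173
  //         pp upstream peer: 127.0.0.1:6170
  // (no downstream: rank 2 sits in the last pp stage)
  return 0;
}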

+void DistModel::InsertCommOp(std::string tmp_var_name, int nranks, int rank,
+                             const std::vector<std::string> &peer_endpoints,
+                             framework::BlockDesc *block, int ring_id) {
+  /*
+   * tmp_var_name: the name of the var that holds the comm id
+   * nranks: the total number of ranks in the comm group
+   * rank: the rank of the local rank within the comm group
+   * peer_endpoints: the peers' endpoints
+   * block: the block into which the comm ops are inserted
+   * ring_id: the ring_id to be initialized
+   */
+  std::string &endpoint = config_.current_endpoint;
+  std::stringstream ss;
+  ss << "Init comm with tmp var: " << tmp_var_name
+     << ". The ring id is: " << ring_id << ". The group has: " << nranks
+     << " ranks. Current rank in the group is: " << rank
+     << ". The endpoint is: " << endpoint << ". Peer endpoints are: ";
+  for (auto ep : peer_endpoints) {
+    ss << ep << ", ";
+  }
+  VLOG(3) << ss.str();
+  if (config_.place == "GPU") {
+    framework::VarDesc *new_var = block->Var(tmp_var_name);
+    new_var->SetType(framework::proto::VarType::RAW);
+    new_var->SetPersistable(true);
+    framework::OpDesc *gen_nccl_id_op = block->AppendOp();
+    gen_nccl_id_op->SetType("c_gen_nccl_id");
+    gen_nccl_id_op->SetOutput("Out", {tmp_var_name});
+    gen_nccl_id_op->SetAttr("rank", rank);
+    gen_nccl_id_op->SetAttr("endpoint", config_.current_endpoint);
+    gen_nccl_id_op->SetAttr("other_endpoints", peer_endpoints);
+    gen_nccl_id_op->SetAttr("ring_id", ring_id);
+    gen_nccl_id_op->SetAttr("op_role",
+                            static_cast<int>(framework::OpRole::kForward));
+    gen_nccl_id_op->CheckAttrs();
+    framework::OpDesc *comm_init_op = block->AppendOp();
+    comm_init_op->SetType("c_comm_init");
+    comm_init_op->SetInput("X", {tmp_var_name});
+    comm_init_op->SetAttr("rank", rank);
+    comm_init_op->SetAttr("nranks", nranks);
+    comm_init_op->SetAttr("ring_id", ring_id);
+    comm_init_op->SetAttr("op_role",
+                          static_cast<int>(framework::OpRole::kForward));
+    comm_init_op->CheckAttrs();
+  } else {
+    LOG(WARNING) << "DistModelInf doesn't init comm.";
+    // TODO(fleet exe dev): comm init for more devices
+  }
+}
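Continuing the made-up 4-rank example from above, rank 2 would end up issuing the following InsertCommOp calls. The argument values are derived from the config exactly as CommInit() computes them; comm_init_block is the block CommInit() builds, and the endpoints are the hypothetical ones from that example.

// mp ring: group size = nranks / pp_degree = 2, group rank = local_rank % mp_degree = 0
InsertCommOp("mp_comm_id", /*nranks=*/2, /*rank=*/0, {"127.0.0.1:6173"},
             comm_init_block, config_.mp_ring_id);
// pp upstream ring: always 2 ranks; the current (downstream) stage takes rank 1
InsertCommOp("pp_upstream_comm_id", /*nranks=*/2, /*rank=*/1,
             {"127.0.0.1:6170"}, comm_init_block, config_.pp_upstream_ring_id);
// No downstream ring: local_rank + mp_degree == nranks, so this is the last pp stage.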

 bool DistModel::PrepareScope() {
   scope_.reset(new framework::Scope());
   return true;
@@ -119,6 +270,8 @@ bool DistModel::LoadParameters() {
       new_var->SetLoDLevel(var->GetLoDLevel());
       new_var->SetPersistable(true);
       params.push_back(new_var->Name());
+      // NOTE: if the params are stored in different files, 'load' op should be
+      // added here
     }
   }
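On that NOTE: a hedged sketch of what the per-file variant could look like, emitting one 'load' op per parameter instead of a single combined load. The op name 'load' and its 'Out' output and 'file_path' attribute exist in Paddle; the load_block pointer and the <model_dir>/<param_name> path layout are assumptions for illustration.

// Hypothetical: one 'load' op per parameter, each reading its own file.
framework::OpDesc *load_op = load_block->AppendOp();
load_op->SetType("load");
load_op->SetOutput("Out", {new_var->Name()});
// assumed layout: <model_dir>/<param_name>
load_op->SetAttr("file_path", config_.model_dir + "/" + new_var->Name());
load_op->CheckAttrs();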
Review comment:
Remember CoordSys? It would be good to abstract this out later on.

Reply:
It doesn't feel necessary: an inference topology has at most the pp and mp dimensions, and introducing a coord sys just for those two seems redundant. The real reason is that the C++-side coord sys was moved to the Python side earlier... I'd rather not move it back 😂