[PIR][AutoParallel] Support merge fetch results among micro_scopes #58528
LGTM overall
Scope* scope,
FetchUnmergedList* fetch_list);

void MergeFetchTensors(const FetchUnmergedList& fetch_list,
Suggested change:
- void MergeFetchTensors(const FetchUnmergedList& fetch_list,
+ void MergeFetchList(const FetchUnmergedList& fetch_list,
@@ -71,11 +54,17 @@ class Job final {
skip_gc_vars_ = skip_gc_vars;
}

void SetFetchVarName(std::string fetch_var_name) {
Suggested change:
- void SetFetchVarName(std::string fetch_var_name) {
+ void SetFetchVarName(const std::string& fetch_var_name) {
Done.
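For illustration, a minimal sketch of the setter after this suggestion. MiniJob, the fetch_var_name_ member, and the FetchVarName getter are hypothetical stand-ins, not the actual Job class from this PR. The point is only that a const reference avoids copying the argument at the call boundary, so the string is copied once, into the member.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the Job class touched by this PR.
class MiniJob final {
 public:
  // Taking a const reference avoids one copy of the argument;
  // the only copy made is the assignment into the member.
  void SetFetchVarName(const std::string& fetch_var_name) {
    fetch_var_name_ = fetch_var_name;
  }

  const std::string& FetchVarName() const { return fetch_var_name_; }

 private:
  std::string fetch_var_name_;  // assumed member name, for illustration only
};
```

Taking the parameter by value and moving it into the member is another common idiom for setters; which one fits better depends on how callers pass the string.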
if (need_fetch) {
for (auto& var_name : fetch_var_names_) {
auto* var = inner_scope->FindVar(var_name);
VLOG(0) << "fetch " << var_name << "[" << var << "]";
Suggested change:
- VLOG(0) << "fetch " << var_name << "[" << var << "]";
+ VLOG(6) << "fetch " << var_name << "[" << var << "]";
Done.
feed_names[i],
numel_size,
micro_batch_num));
int64_t split_size = (numel_size + micro_batch_num - 1) / micro_batch_num;
Suggested change:
- int64_t split_size = (numel_size + micro_batch_num - 1) / micro_batch_num;
+ int64_t split_size = numel_size / micro_batch_num;
Done.
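A small sketch of why the simpler form is enough, under one assumption: that the check whose arguments appear in the context lines (feed_names[i], numel_size, micro_batch_num) guarantees numel_size is divisible by micro_batch_num. The visible diff does not show that check in full, so treat this as an assumption. CeilSplitSize and ExactSplitSize are hypothetical helper names, not functions from the PR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: the original rounding-up form of the split size.
int64_t CeilSplitSize(int64_t numel_size, int64_t micro_batch_num) {
  assert(micro_batch_num > 0);
  return (numel_size + micro_batch_num - 1) / micro_batch_num;
}

// Hypothetical helper: the suggested plain division.
// Assumes the caller has already enforced divisibility.
int64_t ExactSplitSize(int64_t numel_size, int64_t micro_batch_num) {
  assert(micro_batch_num > 0);
  assert(numel_size % micro_batch_num == 0);
  return numel_size / micro_batch_num;
}
```

When numel_size is a multiple of micro_batch_num the two forms return the same value. They differ only for non-divisible sizes, where the rounded-up split would make the slices cover more elements than the tensor holds.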
VLOG(4) << "Split feed data:" << feed_names[i] << ", dims:("
<< feed_tensor.dims() << "), micro_batch_num:" << micro_batch_num;
for (int64_t j = 0; j < micro_batch_num; ++j) {
(*out)[j].resize(i + 1);
Space for feed_tensors.size elements can be allocated directly.
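One way to read this comment, sketched with hypothetical types: size each per-micro-batch list once, before the split loop, instead of calling resize(i + 1) for every feed. FakeTensor, SplitFeeds, and AllocateSplitFeeds are illustrative names only; the real code works on phi::DenseTensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for phi::DenseTensor, used only for this sketch.
struct FakeTensor {
  std::vector<float> data;
};

// out[j][i] holds the slice of feed i that belongs to micro batch j.
using SplitFeeds = std::vector<std::vector<FakeTensor>>;

// Size every per-micro-batch list once, before the split loop runs,
// so the loop body can assign to out[j][i] without resizing.
SplitFeeds AllocateSplitFeeds(size_t micro_batch_num, size_t feed_count) {
  assert(micro_batch_num > 0);
  return SplitFeeds(micro_batch_num, std::vector<FakeTensor>(feed_count));
}
```

With the grid allocated up front, the inner loop only writes into existing slots, which removes the repeated resize calls.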
FetchList* out);

void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
platform::Place dst_place,
Suggested change:
- platform::Place dst_place,
+ const platform::Place& dst_place,
Done.
LGTM
…addlePaddle#58528)
* [PIR][AutoParallel] Support MicroScope Merge Fetch Result
* fix func name
* revert plan
* fix feed
* tiny fix
* rm print&update ut
* rm comment
* fix splitfeed & ut
* add mergetensor func
* fix ut
* update split_size
PR types
Others
PR changes
Others
Description
Pcard-76459
FetchTensors: get the fetch results of all micro scopes
MergeFetchTensors: merge fetch results that have the same name but come from different scopes
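The merge step can be sketched roughly as follows. This is a simplified illustration, not the PR's implementation: Tensor, UnmergedList, MergedList, and MergeByConcat are hypothetical names, tensors are modeled as flat float vectors, and merging is assumed to mean concatenating the per-scope results of each fetch variable in micro-scope order.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical simplified types. The PR itself works with
// phi::DenseTensor, FetchUnmergedList and FetchList.
using Tensor = std::vector<float>;
using UnmergedList = std::vector<std::vector<Tensor>>;  // [micro_scope][fetch_var]
using MergedList = std::vector<Tensor>;                 // [fetch_var]

// Concatenate, for each fetch variable, the results from every micro scope.
// Concatenation in scope order is an assumption made for this sketch.
MergedList MergeByConcat(const UnmergedList& unmerged) {
  MergedList merged;
  if (unmerged.empty()) return merged;

  const size_t num_vars = unmerged.front().size();
  merged.resize(num_vars);
  for (const auto& scope_results : unmerged) {
    // Every micro scope is expected to fetch the same set of variables.
    assert(scope_results.size() == num_vars);
    for (size_t i = 0; i < num_vars; ++i) {
      merged[i].insert(merged[i].end(),
                       scope_results[i].begin(),
                       scope_results[i].end());
    }
  }
  return merged;
}
```

For two micro scopes that each fetch the same two variables, the result holds two tensors, each containing the first scope's values followed by the second scope's values.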