[PIR][AutoParallel] Support merge fetch results among micro_scopes #58528
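In brief, as inferred from the diff: with pipeline parallelism each micro batch executes in its own micro scope, so this PR splits feed data across micro batches (SplitFeedTensors), collects each micro batch's fetch results into its own slot (FetchTensors), and concatenates the slots back into a single FetchList along dim 0 (MergeFetchTensors). A minimal standalone sketch of that flow — plain C++ with std::vector<float> standing in for phi::DenseTensor, illustrative only and not the Paddle API:

// flow_sketch.cc -- simplified model of SplitFeedTensors -> per-micro-batch
// run -> MergeFetchTensors. std::vector<float> stands in for phi::DenseTensor.
#include <cstddef>
#include <cstdint>
#include <vector>

using Tensor = std::vector<float>;

int main() {
  const int64_t micro_batch_num = 2;

  // Split one feed tensor (dim[0] == 4) evenly across micro batches.
  Tensor feed = {1.f, 2.f, 3.f, 4.f};
  const std::size_t split_size = feed.size() / micro_batch_num;
  std::vector<std::vector<Tensor>> micro_feeds(micro_batch_num);
  for (int64_t j = 0; j < micro_batch_num; ++j) {
    micro_feeds[j] = {Tensor(feed.begin() + j * split_size,
                             feed.begin() + (j + 1) * split_size)};
  }

  // Each micro batch runs in its own micro scope and fetches into its own
  // slot of the unmerged list; the "run" here is just the identity.
  std::vector<std::vector<Tensor>> fetch_list = micro_feeds;

  // Merge: concatenate the per-micro-batch results along dim 0.
  std::vector<Tensor> merged(fetch_list[0].size());
  for (std::size_t i = 0; i < merged.size(); ++i) {
    for (std::size_t j = 0; j < fetch_list.size(); ++j) {
      merged[i].insert(merged[i].end(), fetch_list[j][i].begin(),
                       fetch_list[j][i].end());
    }
  }
  return 0;  // merged[0] == {1, 2, 3, 4}
}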

Merged: 11 commits, Nov 3, 2023
Changes from all commits
219 changes: 190 additions & 29 deletions paddle/fluid/framework/new_executor/feed_fetch_utils.cc
@@ -15,6 +15,7 @@
#include <map>
#include <vector>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/new_executor/feed_fetch_utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

@@ -48,42 +49,202 @@ void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
}
}

void SplitFeedTensors(const std::vector<std::string>& feed_names,
const int64_t micro_batch_num,
Scope* scope,
std::vector<std::vector<phi::DenseTensor>>* out) {
std::vector<phi::DenseTensor> feed_tensors;
for (size_t i = 0; i < feed_names.size(); ++i) {
auto feed_name = feed_names[i];
auto feed_var = scope->GetVar(feed_name);
PADDLE_ENFORCE_NOT_NULL(
feed_var,
platform::errors::NotFound("Variable %s should not be nullptr.",
feed_names[i]));
feed_tensors.push_back(feed_var->Get<phi::DenseTensor>());
}

out->resize(micro_batch_num);
if (micro_batch_num < 2) {
(*out)[0] = std::move(feed_tensors);
return;
}

for (size_t i = 0; i < feed_tensors.size(); ++i) {
auto& feed_tensor = feed_tensors[i];
int64_t numel_size = feed_tensor.dims()[0];
PADDLE_ENFORCE_EQ(numel_size % micro_batch_num,
0,
phi::errors::InvalidArgument(
"Split expects feed data (%s)'s dim[0] (%d) is "
"diviable by micro_batch_num (%d).",
feed_names[i],
numel_size,
micro_batch_num));
int64_t split_size = numel_size / micro_batch_num;
VLOG(4) << "Split feed data:" << feed_names[i] << ", dims:("
<< feed_tensor.dims() << "), micro_batch_num:" << micro_batch_num;
for (int64_t j = 0; j < micro_batch_num; ++j) {
(*out)[j].resize(i + 1);
[Review comment, Contributor]: You could allocate feed_tensors.size() elements up front here, instead of resizing incrementally.
(*out)[j][i].ShareDataWith(
feed_tensor.Slice(j * split_size, j * split_size + split_size));
}
}
}
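As a quick check of the split arithmetic above — a standalone sketch under the assumption of a plain array in place of phi::DenseTensor: dim[0] must divide evenly, each micro batch gets a contiguous run of split_size rows, and the real code's ShareDataWith makes each slice a zero-copy view of the feed tensor.

// split_sketch.cc -- the Slice bounds used by SplitFeedTensors.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t micro_batch_num = 4;
  const std::vector<int> feed = {0, 1, 2, 3, 4, 5, 6, 7};  // dim[0] == 8

  assert(feed.size() % micro_batch_num == 0);  // mirrors the PADDLE_ENFORCE_EQ
  const int64_t split_size = feed.size() / micro_batch_num;

  for (int64_t j = 0; j < micro_batch_num; ++j) {
    // Matches feed_tensor.Slice(j * split_size, j * split_size + split_size).
    std::cout << "micro batch " << j << ": rows [" << j * split_size << ", "
              << j * split_size + split_size << ")\n";
  }
  return 0;
}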

void FetchTensors(const std::vector<std::string>& job_fetch_names,
const std::vector<std::string>& fetch_var_names,
const int64_t micro_batch_id,
Scope* scope,
FetchUnmergedList* fetch_list) {
PADDLE_ENFORCE_GT(
fetch_list->size(),
micro_batch_id,
phi::errors::Unavailable("The fetch list size (%lld) should be greater "
"than micro_batch_id (%lld)",
fetch_list->size(),
micro_batch_id));

fetch_list->at(micro_batch_id).resize(fetch_var_names.size());
for (auto& var_name : job_fetch_names) {
int col = find(fetch_var_names.begin(), fetch_var_names.end(), var_name) -
fetch_var_names.begin();
auto* var = scope->FindVar(var_name);
auto& src = var->Get<phi::DenseTensor>();
auto* dst =
&(PADDLE_GET(phi::DenseTensor, fetch_list->at(micro_batch_id)[col]));
TensorCopy(src, platform::CPUPlace(), dst);
}
}
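The column lookup in FetchTensors keeps per-job results aligned with the global fetch order: a job may fetch only a subset of variables, but each result lands in the column that fetch_var_names assigns to it. A small sketch with hypothetical variable names (loss, acc, lr are assumptions, not from the diff):

// fetch_col_sketch.cc -- mapping a job's fetch names to global columns.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> fetch_var_names = {"loss", "acc", "lr"};
  const std::vector<std::string> job_fetch_names = {"lr", "loss"};

  for (const auto& var_name : job_fetch_names) {
    // As in FetchTensors; the real code assumes var_name is always present.
    const int col = static_cast<int>(
        std::find(fetch_var_names.begin(), fetch_var_names.end(), var_name) -
        fetch_var_names.begin());
    std::cout << var_name << " -> column " << col << "\n";
  }
  return 0;  // prints: lr -> column 2, loss -> column 0
}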

void MergeFetchTensors(const FetchUnmergedList& fetch_list,
const int64_t micro_batch_num,
FetchList* out) {
if (fetch_list.size() == 0) return;

PADDLE_ENFORCE_EQ(
fetch_list.size(),
micro_batch_num,
phi::errors::Unavailable("The fetch_list size (%lld) shoule be equal to "
"the micro_batch_num (%lld)",
fetch_list.size(),
micro_batch_num));

if (micro_batch_num < 2) {
*out = std::move(fetch_list[0]);
return;
}

out->resize(fetch_list[0].size());
for (size_t i = 0; i < fetch_list[0].size(); ++i) {
std::vector<const phi::DenseTensor*> tensors_ptr;
for (auto micro_batch_id = 0; micro_batch_id < micro_batch_num;
++micro_batch_id) {
tensors_ptr.push_back(
&PADDLE_GET_CONST(phi::DenseTensor, fetch_list[micro_batch_id][i]));
}
phi::DenseTensor merged_tensor;
MergeTensors(tensors_ptr, platform::CPUPlace(), &merged_tensor);
out->at(i) = std::move(merged_tensor);
}
}

void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
const platform::Place dst_place,
phi::DenseTensor* target) {
PADDLE_ENFORCE_EQ(
tensors.empty(),
false,
phi::errors::InvalidArgument("The tensors to be merged are empty."));

DDim new_dim = tensors[0]->dims();
proto::VarType::Type new_type = proto::VarType::FP32;
phi::DataLayout new_layout = tensors[0]->layout();
for (auto* t : tensors) {
if (t->numel() && t->IsInitialized()) {
new_dim = t->dims();
new_type = framework::TransToProtoVarType(t->dtype());
new_layout = t->layout();
break;
}
}

auto rank = tensors[0]->dims().size();
if (rank == 0) {
std::vector<int> init_shape = {1};
new_dim = new_dim.reshape(init_shape);
}

for (size_t i = 1; i < tensors.size(); ++i) {
auto* t = tensors[i];
if (t->numel() && t->IsInitialized()) {
PADDLE_ENFORCE_EQ(
new_type,
framework::TransToProtoVarType(t->dtype()),
phi::errors::InvalidArgument(
"phi::DenseTensor data type does not match, expected type is %s, "
"actual "
"type is %s.",
DataTypeToString(new_type),
DataTypeToString(framework::TransToProtoVarType(t->dtype()))));
PADDLE_ENFORCE_EQ(
new_layout,
t->layout(),
phi::errors::InvalidArgument(
"phi::DenseTensor layout does not match, expected layout is %s, "
"actual layout is %s.",
phi::DataLayoutToString(new_layout),
phi::DataLayoutToString(t->layout())));
if (rank > 0) {
auto tensor_dims = t->dims();
PADDLE_ENFORCE_EQ(tensor_dims.size(),
new_dim.size(),
phi::errors::InvalidArgument(
"dimensions of DenseTensor does not match"));
for (int j = 1; j < t->dims().size(); j++) {
PADDLE_ENFORCE_EQ(
tensor_dims[j],
new_dim[j],
phi::errors::InvalidArgument(
"DenseTensor.ddim[%d] should eaqual to %d, but is %d",
j,
new_dim[j],
tensor_dims[j]));
}
new_dim[0] += t->dims()[0];
} else if (rank == 0) {
auto tensor_dims = t->dims();
PADDLE_ENFORCE_EQ(tensor_dims.size(),
0,
phi::errors::InvalidArgument(
"dimensions of DenseTensor does not match"));
PADDLE_ENFORCE_EQ(new_dim.size(),
1,
phi::errors::InvalidArgument(
"dimensions of DenseTensor does not match"));
new_dim[0] += 1;
}
}
}

target->Resize(new_dim);
target->set_layout(new_layout);
target->mutable_data(dst_place, TransToPhiDataType(new_type));

int begin = 0;
for (auto* src : tensors) {
int src_dim = 1;
if (src->dims()[0] > 0) {
src_dim = src->dims()[0];
}
int end = static_cast<int>(begin + src_dim);
if (end == begin) {
continue;
}
auto dst = target->Slice(begin, end);
TensorCopy(*src, dst_place, &dst);
begin = end;
}
}
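One subtlety in MergeTensors' shape bookkeeping worth spelling out: rank-0 (scalar) tensors are first reshaped to {1}, so merging N scalars yields a 1-D tensor of shape {N}, while rank > 0 tensors concatenate along dim 0 with all trailing dims required to match. A standalone sketch, under the assumption of DDim reduced to std::vector<int64_t>:

// merge_dims_sketch.cc -- resulting dims for scalar and non-scalar merges.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using Dim = std::vector<int64_t>;

int main() {
  // Three rank-0 scalars: promoted to {1}, then dim[0] grows by 1 per tensor.
  const std::size_t num_scalars = 3;
  Dim merged_scalar = {1};
  for (std::size_t i = 1; i < num_scalars; ++i) merged_scalar[0] += 1;
  std::cout << "scalars -> {" << merged_scalar[0] << "}\n";  // {3}

  // Two [2, 5] tensors: trailing dims must match, dim 0 accumulates.
  const std::vector<Dim> dims = {{2, 5}, {2, 5}};
  Dim merged = dims[0];
  for (std::size_t i = 1; i < dims.size(); ++i) merged[0] += dims[i][0];
  std::cout << "tensors -> {" << merged[0] << ", " << merged[1] << "}\n";
  return 0;  // {4, 5}
}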

} // namespace framework
22 changes: 18 additions & 4 deletions paddle/fluid/framework/new_executor/feed_fetch_utils.h
@@ -27,10 +27,24 @@ void SetColAttrForFeedFetchOps(std::shared_ptr<ProgramDesc> program_desc,
const int64_t micro_batch_num,
const int64_t micro_batch_id);

void SplitFeedTensors(const std::vector<std::string>& feed_names,
const int64_t micro_batch_num,
Scope* scope,
std::vector<std::vector<phi::DenseTensor>>* out);

void FetchTensors(const std::vector<std::string>& job_fetch_names,
const std::vector<std::string>& fetch_var_names,
const int64_t micro_batch_id,
Scope* scope,
FetchUnmergedList* fetch_list);

void MergeFetchTensors(const FetchUnmergedList& fetch_list,
const int64_t micro_batch_num,
FetchList* out);

[Review comment, Contributor] Suggested change — rename MergeFetchTensors to MergeFetchList:
- void MergeFetchTensors(const FetchUnmergedList& fetch_list,
+ void MergeFetchList(const FetchUnmergedList& fetch_list,

void MergeTensors(const std::vector<const phi::DenseTensor*>& tensors,
const platform::Place dst_place,
phi::DenseTensor* target);

} // namespace framework
} // namespace paddle
25 changes: 7 additions & 18 deletions paddle/fluid/framework/new_executor/interpreter/job.h
@@ -31,27 +31,10 @@ class Job final {

const std::string& Type() const { return type_; }

int64_t MicroBatchId() const { return micro_batch_id_; }

std::set<std::string> SkipGcVars() const { return skip_gc_vars_; }

void SetMicroBatchId(int64_t micro_batch_id) {
PADDLE_ENFORCE_GE(
micro_batch_id,
@@ -71,11 +54,17 @@ class Job final {
skip_gc_vars_ = skip_gc_vars;
}

void SetFetchVarName(const std::string& fetch_var_name) {
fetch_var_names_.push_back(fetch_var_name);
}

std::vector<std::string> FetchVarNames() { return fetch_var_names_; }

private:
const std::string type_;
int64_t micro_batch_id_;
std::set<std::string> skip_gc_vars_;
std::vector<std::string> fetch_var_names_;
};

} // namespace interpreter
3 changes: 2 additions & 1 deletion paddle/fluid/framework/new_executor/interpreter_base_impl.h
@@ -67,7 +67,8 @@ class InterpreterBaseImpl {
virtual ~InterpreterBaseImpl() = default;
virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch = true) = 0;

virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
5 changes: 3 additions & 2 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -65,8 +65,9 @@ InterpreterCore::~InterpreterCore() {

FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch) {
return impl_->Run(feed_names, feed_tensors, need_fetch);
}

FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
3 changes: 2 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.h
@@ -47,7 +47,8 @@

paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch = true);

paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true);
15 changes: 9 additions & 6 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -1023,7 +1023,8 @@ void PirInterpreter::ConstructEventForJitInput() {

paddle::framework::FetchList PirInterpreter::Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch) {
auto FeedInput = [&] {
VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) {
@@ -1100,10 +1101,12 @@ paddle::framework::FetchList PirInterpreter::Run(
if (FLAGS_enable_new_ir_in_executor) {
framework::FetchList fetch_res;

if (need_fetch) {
for (auto& var_name : fetch_var_names_) {
auto* var = inner_scope->FindVar(var_name);
VLOG(4) << "fetch " << var_name << "[" << var << "]";
fetch_res.push_back(var->Get<phi::DenseTensor>());
}
}

VLOG(4) << "get fetch list size: " << fetch_res.size();
@@ -1192,7 +1195,7 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
if (need_fetch) {
for (auto& var_name : fetch_var_names_) {
auto* var = inner_scope->FindVar(var_name);
VLOG(0) << "fetch " << var_name << "[" << var << "]";
VLOG(4) << "fetch " << var_name << "[" << var << "]";
fetch_res.push_back(var->Get<phi::DenseTensor>());
}
}
3 changes: 2 additions & 1 deletion paddle/fluid/framework/new_executor/pir_interpreter.h
@@ -51,7 +51,8 @@ class PirInterpreter : public InterpreterBaseImpl {

paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
bool need_fetch = true) override;

paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override;