
[phi] support pir run in dy2static AST mode. #57357

Merged · 32 commits · Sep 20, 2023

Commits
d647a5e  [NewIR] Support Ir run program node (#56791)  (2742195759, Sep 14, 2023)
d912142  new pr  (2742195759, Sep 15, 2023)
b8ccffc  fix  (2742195759, Sep 15, 2023)
442f83c  fix  (2742195759, Sep 15, 2023)
45fe960  fix segment error  (2742195759, Sep 15, 2023)
6bf5107  merge  (2742195759, Sep 15, 2023)
43e5709  fix  (2742195759, Sep 15, 2023)
93bcbe1  add dependences  (2742195759, Sep 16, 2023)
92efa87  Merge remote-tracking branch 'upstream/develop' into revert-revert-ir…  (2742195759, Sep 16, 2023)
c09974e  fix  (2742195759, Sep 16, 2023)
3e955ea  fix link error.  (2742195759, Sep 16, 2023)
378f9fb  fix some cmake problem  (2742195759, Sep 17, 2023)
ba82434  Merge remote-tracking branch 'upstream/develop' into revert-revert-ir…  (2742195759, Sep 17, 2023)
0f99731  fix  (2742195759, Sep 17, 2023)
4910330  fix  (2742195759, Sep 18, 2023)
8568b62  Merge remote-tracking branch 'upstream/develop' into revert-revert-ir…  (2742195759, Sep 18, 2023)
5df6766  fix dependecy  (2742195759, Sep 18, 2023)
54ac315  fix  (2742195759, Sep 18, 2023)
3fce816  fix  (2742195759, Sep 18, 2023)
7253c40  fix circle dependence  (2742195759, Sep 18, 2023)
c0f1fee  fix  (2742195759, Sep 18, 2023)
49289da  fix  (2742195759, Sep 18, 2023)
9a6fff0  fix merge  (2742195759, Sep 19, 2023)
71e9e4a  fix rocm  (2742195759, Sep 19, 2023)
4a4696a  fix  (2742195759, Sep 19, 2023)
ec8b675  add python library  (2742195759, Sep 20, 2023)
0eec397  fix cmake  (2742195759, Sep 20, 2023)
33cd117  Merge remote-tracking branch 'upstream/develop' into revert-revert-ir…  (2742195759, Sep 20, 2023)
c3a8fd0  merge  (2742195759, Sep 20, 2023)
4d59616  fix  (2742195759, Sep 20, 2023)
99b33cc  fix  (2742195759, Sep 20, 2023)
af69ca3  fix conflict  (2742195759, Sep 20, 2023)
105 changes: 105 additions & 0 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -118,6 +118,7 @@ inline void run_program_ad_func(
std::vector<paddle::Tensor*>& dout, // NOLINT
const paddle::framework::AttributeMap& attrs) {
// Prepare Autograd Meta
VLOG(2) << "start run run_program ad function.";
auto deref_out = details::DereferenceTensors(out);
std::vector<egr::AutogradMeta*> p_autograd_x =
egr::EagerUtils::nullable_autograd_meta(x);
@@ -197,3 +198,107 @@
egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node);
}
}

inline void newir_run_program_ad_func(
const std::vector<paddle::Tensor>& x,
const std::vector<paddle::Tensor>& params,
std::vector<paddle::Tensor*>& out, // NOLINT
std::vector<paddle::framework::Scope*>& step_scope, // NOLINT
std::vector<paddle::Tensor*>& dout, // NOLINT
const paddle::framework::AttributeMap& attrs) {
// Prepare Autograd Meta
VLOG(2) << "start run newir run_program ad function.";
auto deref_out = details::DereferenceTensors(out);
std::vector<egr::AutogradMeta*> p_autograd_x =
egr::EagerUtils::nullable_autograd_meta(x);
std::vector<egr::AutogradMeta*> p_autograd_params =
egr::EagerUtils::nullable_autograd_meta(params);
std::vector<egr::AutogradMeta*> p_autograd_outs =
egr::EagerUtils::nullable_autograd_meta(deref_out);

bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
trace_backward, &p_autograd_x, &p_autograd_params);
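// require_any_grad is true only when grad tracing is enabled and at least
// one input or parameter needs a gradient.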

// Create middle outputs for the GradNode. "fm" and "fo" hold the
// pir::Values of the forward middles and forward outputs.
auto middle_size =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")).size();
auto output_size =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")).size();
auto middles = std::vector<paddle::Tensor*>();
std::shared_ptr<NewIRGradNodeRunProgram> grad_node;
VLOG(2) << "start run run_program with require_any_grad = "
<< require_any_grad;

if (require_any_grad) {
// Create GradOpNode (1 backward-input slot: [out_grad];
// 2 backward-output slots: [x_grad, param_grad]).
grad_node = std::make_shared<NewIRGradNodeRunProgram>(1, 2);
grad_node->GetMiddle().resize(middle_size);
grad_node->GetOutputs().resize(output_size);
for (size_t i = 0; i < middle_size; ++i) {
grad_node->GetMiddle()[i] =
paddle::Tensor(std::make_shared<phi::DenseTensor>());
middles.push_back(&grad_node->GetMiddle()[i]);
}
for (size_t i = 0; i < output_size; ++i) {
grad_node->GetOutputs()[i] = *out[i];
}
}
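// At this point the grad node owns one placeholder DenseTensor per middle
// value (to be filled in by the forward run below) and a copy of each
// forward output.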

// Call the forward function.
// If require_any_grad is false, no middle vars are saved.
NewIRRunProgramAPI(
x, params, out, middles, step_scope, dout, require_any_grad, attrs);
if (require_any_grad) {
// auto x_names =
// PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));

egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);

// Set Attributes
grad_node->SetAttrMap(attrs);

// auto* forward_global_block = PADDLE_GET_CONST(
// paddle::framework::BlockDesc*, attrs.at("forward_global_block"));
// auto* backward_global_block = PADDLE_GET_CONST(
// paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
// Clear unused x vars
// auto filter_x =
// filter_unused_input_var_in_backward(x, x_names, backward_global_block);
// Set TensorWrappers
grad_node->SetFwdX(x);
// Clear unused out vars
// clear_unused_out_var_in_backward(out, backward_global_block,
// step_scope[0]);

grad_node->SetFwdParams(params);
grad_node->SetStepScope(step_scope); // just to keep the step scopes usable.

// Set the grad-out rank to match the forward inputs and propagate
// stop_gradient to the backward pass.
// NOTE(@xiongkun): Not every tensor in x (a list of tensors) requires a
// gradient. For example, if x[1] contributes to no output, x[1] is ignored.

// TODO(@xiongkun): rewrite by new ir representation.
std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
x_require_grad.push_back(&x[i]);
}

grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0);
grad_node->SetGradOutMeta(params, /*slot id*/ 1);
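// Slot 0 carries gradients w.r.t. x and slot 1 gradients w.r.t. params,
// matching the two backward-output slots the grad node was built with.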

// VLOG(2) << "clear_no_grad_edges.";
// clear_no_grad_edges_with_partial_block(params,
// forward_global_block,
// backward_global_block,
// grad_node.get(),
// [>slot id<] 1);

grad_node->SetGradInMeta(deref_out, 0);

egr::EagerUtils::SetOutRankWithSlot(&p_autograd_outs, 0);

// Set history for the outputs: record the current grad node on each output.
egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node);
}
}
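
For orientation, here is a minimal caller sketch (not part of the PR). It assumes the dy2static layer has already packed the "fm" and "fo" std::vector<::pir::Value> entries into the attribute map, since newir_run_program_ad_func reads both; the wrapper name CallNewIRRunProgram is hypothetical.

#include <vector>

#include "paddle/fluid/eager/to_static/run_program_op_func.h"

// Hypothetical wrapper: forwards pre-built inputs, parameters, scopes, and
// the attribute map into the new PIR-aware AD function added by this PR.
void CallNewIRRunProgram(const std::vector<paddle::Tensor>& x,
                         const std::vector<paddle::Tensor>& params,
                         std::vector<paddle::Tensor*>& out,              // NOLINT
                         std::vector<paddle::framework::Scope*>& scopes, // NOLINT
                         std::vector<paddle::Tensor*>& dout,             // NOLINT
                         const paddle::framework::AttributeMap& attrs) {
  // attrs must already hold "fm" (forward middles) and "fo" (forward
  // outputs) as std::vector<::pir::Value>; their sizes drive grad-node setup.
  newir_run_program_ad_func(x, params, out, scopes, dout, attrs);
}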