【PIR / Dy2static】Fix ir program deconstruct bugs. #59764

Closed
2 changes: 2 additions & 0 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -210,6 +210,7 @@ inline void pir_run_program_ad_func(
const std::vector<paddle::Tensor>& params,
std::vector<paddle::Tensor*>& out, // NOLINT
std::vector<paddle::framework::Scope*>& step_scope, // NOLINT
+ const std::vector<PyObject*>& blocks_to_hold,
const paddle::framework::AttributeMap& attrs) {
// Prepare Autograd Meta
VLOG(2) << "start run pir run_program ad function.";
@@ -245,6 +246,7 @@ inline void pir_run_program_ad_func(
grad_node = std::make_shared<PirGradNodeRunProgram>(1, 2);
grad_node->GetMiddle().resize(middle_size);
grad_node->GetOutputs().resize(output_size);
+ grad_node->SetBlocks(blocks_to_hold);
for (size_t i = 0; i < middle_size; ++i) {
grad_node->GetMiddle()[i] =
paddle::Tensor(std::make_shared<phi::DenseTensor>());
105 changes: 62 additions & 43 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -14,6 +14,7 @@

#pragma once

+ #include <Python.h>
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/tensor_wrapper.h"
@@ -456,21 +457,16 @@ inline void PirRunProgramAPI(
auto param_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));

- auto *forward_global_block =
-     PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
- auto *backward_global_block =
-     PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
-
- auto *forward_program =
-     forward_global_block->GetParentOp()->GetParentProgram();
+ auto *forward_program = reinterpret_cast<::pir::Program *>(
+     PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_program")));
+ auto *backward_program = reinterpret_cast<::pir::Program *>(
+     PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_program")));

if (FLAGS_print_ir) {
std::ostringstream print_stream;
print_stream << "ForwardProgram is :\n";
forward_program->Print(print_stream);
if (!is_test) {
- auto *backward_program =
-     backward_global_block->GetParentOp()->GetParentProgram();
print_stream << "BackwardProgram is:\n";
backward_program->Print(print_stream);
} else {
@@ -496,9 +492,9 @@ inline void PirRunProgramAPI(
<< program_id;
// Step 1. share input_vars & parameters into scope
details::ShareTensorsIntoScopeByValue(
-     forward_global_block, x, input_values, global_inner_scope);
+     forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
-     forward_global_block, params, param_values, global_inner_scope);
+     forward_program->block(), params, param_values, global_inner_scope);
// Step 2. create new interpretercore
auto kernel_forward_program =
paddle::dialect::PdOpLowerToKernelPass(forward_program, place);
@@ -514,20 +510,20 @@ inline void PirRunProgramAPI(
// *backward_program);

// update interpretercore skip_gc_var
- auto skip_names =
-     details::GetNameFromValue(forward_global_block, middle_values, false);
+ auto skip_names = details::GetNameFromValue(
+     forward_program->block(), middle_values, false);
auto skip_names_set =
std::set<std::string>(skip_names.begin(), skip_names.end());
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
attrs.at("no_need_buffers"));
auto no_need_buffer_names = details::GetNameFromValue(
-     forward_global_block, no_need_buffer_values, false);
+     forward_program->block(), no_need_buffer_values, false);
for (auto &name : no_need_buffer_names) {
VLOG(4) << "Find no need buffer vars with name:" << name;
skip_names_set.erase(name);
}
- skip_names =
-     details::GetNameFromValue(forward_global_block, output_values, false);
+ skip_names = details::GetNameFromValue(
+     forward_program->block(), output_values, false);
skip_names_set.insert(skip_names.begin(), skip_names.end());
details::print_collection(skip_names_set);
interpreter_core->SetSkipGcVars(skip_names_set);
@@ -550,9 +546,9 @@ inline void PirRunProgramAPI(
interpreter_core = cached_value.core_;
// Step 2. update scope for cache interpretercore
details::ShareTensorsIntoScopeByValue(
-     forward_global_block, x, input_values, global_inner_scope);
+     forward_program->block(), x, input_values, global_inner_scope);
details::ShareTensorsIntoScopeByValue(
-     forward_global_block, params, param_values, global_inner_scope);
+     forward_program->block(), params, param_values, global_inner_scope);
// TODO(xiongkun): new ir how to build scope.
// if (interpreter_core->GetVariableScope()->GetMutableScope() !=
// global_inner_scope) {
@@ -563,7 +559,7 @@ inline void PirRunProgramAPI(
}

// interpretercore run
- if (!forward_global_block->empty()) {
+ if (!forward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -576,9 +572,9 @@ inline void PirRunProgramAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Get Output, and Middle Outputs
details::ShareTensorsFromScopeByValue(
-     forward_global_block, out, output_values, global_inner_scope);
+     forward_program->block(), out, output_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
-     forward_global_block, middles, middle_values, global_inner_scope);
+     forward_program->block(), middles, middle_values, global_inner_scope);

VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());

@@ -991,10 +987,8 @@ inline void PirRunProgramGradAPI(

VLOG(4) << "global_inner_scope:" << global_inner_scope;

- auto *backward_global_block =
-     PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
- auto *backward_program =
-     backward_global_block->GetParentOp()->GetParentProgram();
+ auto *backward_program = reinterpret_cast<::pir::Program *>(
+     PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_program")));

auto output_grad_values =
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g"));
@@ -1012,18 +1006,22 @@ inline void PirRunProgramGradAPI(
PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp_g"));

// share x, param, middles, output_grads, out into scope.
- details::ShareTensorsIntoScopeByValue(
-     backward_global_block, out_grad, output_grad_values, global_inner_scope);
+ details::ShareTensorsIntoScopeByValue(backward_program->block(),
+     out_grad,
+     output_grad_values,
+     global_inner_scope);
details::ShareTensorsIntoScopeByValue(
-     backward_global_block, x, forward_input_values, global_inner_scope);
+     backward_program->block(), x, forward_input_values, global_inner_scope);
- details::ShareTensorsIntoScopeByValue(backward_global_block,
+ details::ShareTensorsIntoScopeByValue(backward_program->block(),
    middles,
    forward_middle_values,
    global_inner_scope);
- details::ShareTensorsIntoScopeByValue(
-     backward_global_block, out, forward_output_values, global_inner_scope);
+ details::ShareTensorsIntoScopeByValue(backward_program->block(),
+     out,
+     forward_output_values,
+     global_inner_scope);
details::ShareTensorsIntoScopeByValue(
-     backward_global_block, params, parameter_values, global_inner_scope);
+     backward_program->block(), params, parameter_values, global_inner_scope);

auto &interpretercore_info_cache =
paddle::framework::InterpreterCoreInfoCache::Instance();
@@ -1060,11 +1058,11 @@ inline void PirRunProgramGradAPI(

// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
- auto skip_names =
-     details::GetNameFromValue(backward_global_block, x_grad_values, false);
+ auto skip_names = details::GetNameFromValue(
+     backward_program->block(), x_grad_values, false);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
- skip_names =
-     details::GetNameFromValue(backward_global_block, p_grad_values, false);
+ skip_names = details::GetNameFromValue(
+     backward_program->block(), p_grad_values, false);
skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
interpretercore_info_cache.UpdateSkipEagerDeleteVars(
@@ -1086,14 +1084,11 @@ inline void PirRunProgramGradAPI(

if (interpreter_core->GetVariableScope()->GetMutableScope() !=
global_inner_scope) {
- // update scope (TODO(xiongkun): do we need this??)
- // details::BuildScopeByBlock(
- //     *interpreter_core.get(), *backward_global_block, global_inner_scope);
interpreter_core->reset_scope(global_inner_scope);
}
}

- if (!backward_global_block->empty()) {
+ if (!backward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
paddle::platform::TracerEventType::UserDefined,
@@ -1108,9 +1103,11 @@ inline void PirRunProgramGradAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Step 4. get outputs
details::ShareTensorsFromScopeByValue(
-     backward_global_block, x_grad, x_grad_values, global_inner_scope);
- details::ShareTensorsFromScopeByValue(
-     backward_global_block, params_grad, p_grad_values, global_inner_scope);
+     backward_program->block(), x_grad, x_grad_values, global_inner_scope);
+ details::ShareTensorsFromScopeByValue(backward_program->block(),
+     params_grad,
+     p_grad_values,
+     global_inner_scope);
VLOG(4) << "after backward gc all vars";
global_inner_scope->SetCanReused(true);
details::GcScope(global_inner_scope);
@@ -1309,6 +1306,7 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
}
middles_.clear();
outputs_.clear();
+ ClearBlocks();
}
}
// Functor: perform backward computations
@@ -1371,6 +1369,10 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
params_grad_ptr);
VLOG(3) << "End Eager Backward Node: PirGradNodeRunProgram";

+ egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&x_grad,
+     this->OutputMeta()[0]);
+ egr::EagerUtils::FillZeroForEmptyOptionalGradOutput(&params_grad,
+     this->OutputMeta()[1]);
executed_ = true;
return {x_grad, params_grad};
}
@@ -1456,16 +1458,33 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node = std::shared_ptr<PirGradNodeRunProgram>(
new PirGradNodeRunProgram(*this));
+ copied_node->SetBlocks(blocks_);
return copied_node;
}

+ public:
+ void SetBlocks(const std::vector<PyObject *> blocks) {
+   blocks_ = blocks;
+   for (auto &obj : blocks_) {
+     VLOG(4) << "program is not NULL, we increase the program ref counter.";
+     Py_INCREF(obj);
+   }
+ }
+
+ void ClearBlocks() {
+   for (auto &obj : blocks_) {
+     Py_DECREF(obj);
+   }
+ }

private:
// TensorWrappers
std::vector<paddle::Tensor> x_;
std::vector<paddle::Tensor> params_;
std::vector<paddle::Tensor> middles_;
std::vector<paddle::Tensor> outputs_;
std::vector<paddle::framework::Scope *> step_scope_;
+ std::vector<PyObject *> blocks_;

// Attribute Map
paddle::framework::AttributeMap attrs_;
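
Note on the change above: the grad node now pins the Python-side program objects that back the raw ::pir::Program pointers cached in attrs_. SetBlocks takes a strong reference on each PyObject with Py_INCREF when the node is built, and ClearBlocks drops those references with Py_DECREF when the node is cleared, so the programs cannot be destructed while backward may still run. Below is a minimal sketch of the same reference-holding pattern; ProgramHolder is hypothetical (not the Paddle implementation) and it assumes the caller holds the GIL.

#include <Python.h>

#include <vector>

// Keeps a set of Python objects alive for as long as this holder owns them.
class ProgramHolder {
 public:
  void Hold(const std::vector<PyObject*>& objs) {
    objs_ = objs;
    for (PyObject* obj : objs_) {
      Py_INCREF(obj);  // strong reference: the program cannot be destructed
    }
  }

  void Release() {
    for (PyObject* obj : objs_) {
      Py_DECREF(obj);  // drop the reference once backward no longer needs it
    }
    objs_.clear();
  }

 private:
  std::vector<PyObject*> objs_;
};
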
5 changes: 3 additions & 2 deletions paddle/fluid/pybind/eager_legacy_custom_python_api.h
@@ -88,12 +88,13 @@ static PyObject *pir_eager_api_run_program(PyObject *self,
// TODO(zengjinle): support CUDA Graph on eager mode
VLOG(1) << "Start Pir ConstructAttrMapFromPyArgs";

+ std::vector<PyObject *> block_objs;
ConstructAttrMapForRunProgram(
"run_program", args, 5, PyTuple_GET_SIZE(args), attrs);
"run_program", args, 5, PyTuple_GET_SIZE(args), block_objs, attrs);

VLOG(1) << "Finish Pir ConstructAttrMapFromPyArgs";
tstate = PyEval_SaveThread();
- pir_run_program_ad_func(X, Params, Out, OutScope, attrs);
+ pir_run_program_ad_func(X, Params, Out, OutScope, block_objs, attrs);
PyEval_RestoreThread(tstate);
tstate = nullptr;
Py_RETURN_NONE;
7 changes: 4 additions & 3 deletions paddle/fluid/pybind/op_function_common.cc
@@ -976,6 +976,7 @@ void ConstructAttrMapForRunProgram(
PyObject* args,
ssize_t attr_start,
ssize_t attr_end,
+ std::vector<PyObject*>& blocks_to_hold, // NOLINT
paddle::framework::AttributeMap& attrs) { // NOLINT
PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2,
0,
@@ -1008,11 +1009,11 @@ void ConstructAttrMapForRunProgram(

if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
- } else if (std::set<std::string>({"global_block",
-                                   "forward_global_block",
-                                   "backward_global_block"})
+ } else if (std::set<std::string>(
+                {"global_block", "forward_program", "backward_program"})
.count(key)) {
CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
+ blocks_to_hold.push_back(obj);
} else if (std::set<std::string>({"is_test", "use_interpretorcore"})
.count(key)) {
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
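
For context on this hunk: CastPyArg2AttrIRBlock stores only a raw pointer into attrs, which carries no ownership and does not keep the Python program alive; recording the originating PyObject* in blocks_to_hold is what lets the caller pin it later via SetBlocks. A simplified, self-contained sketch of that pairing follows; RecordProgramAttr and the std::map stand-in for AttributeMap are illustrative, not Paddle APIs.

#include <Python.h>

#include <map>
#include <string>
#include <vector>

// Stand-in for paddle::framework::AttributeMap: stores non-owning pointers.
using AttrMap = std::map<std::string, void*>;

// Record the raw program pointer as an attribute and remember the owning
// PyObject* so its reference count can be bumped by whoever must keep the
// program alive (here: the grad node via SetBlocks).
void RecordProgramAttr(const std::string& key,
                       PyObject* program_obj,
                       void* raw_program,
                       std::vector<PyObject*>* blocks_to_hold,
                       AttrMap* attrs) {
  (*attrs)[key] = raw_program;             // non-owning pointer only
  blocks_to_hold->push_back(program_obj);  // ownership handled by the caller
}
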
1 change: 1 addition & 0 deletions paddle/fluid/pybind/op_function_common.h
@@ -199,6 +199,7 @@ void ConstructAttrMapForRunProgram(
PyObject* args,
ssize_t attr_start,
ssize_t attr_end,
+ std::vector<PyObject*>& blocks_to_hold, // NOLINT
paddle::framework::AttributeMap& attrs); // NOLINT

unsigned long GetUnsignedLongFromArgs( // NOLINT
8 changes: 4 additions & 4 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -848,10 +848,10 @@ def _prune_unused_params(self, program):

def _prepare_attributes(self):
attrs = [
- 'forward_global_block',
- self.program.forward_program.global_block(),
- 'backward_global_block',
- self.program.backward_program.global_block(),
+ 'forward_program',
+ self.program.forward_program,
+ 'backward_program',
+ self.program.backward_program,
'is_test',
not self.training,
'program_id',
1 change: 0 additions & 1 deletion test/dygraph_to_static/test_len.py
@@ -180,7 +180,6 @@ def test_len_legacy(self):
)
self.assertEqual(selected_rows_var_len, var_tensor_len)

- @test_pir_only
@test_ast_only
@test_pir_only
def test_len(self):
6 changes: 5 additions & 1 deletion test/dygraph_to_static/test_no_gradient.py
@@ -15,7 +15,10 @@
import unittest

import numpy
- from dygraph_to_static_utils import Dy2StTestBase
+ from dygraph_to_static_utils import (
+     Dy2StTestBase,
+     test_legacy_and_pt_and_pir,
+ )

import paddle

@@ -33,6 +36,7 @@ def main_func(x, index):


class TestNoGradientCase(Dy2StTestBase):
+ @test_legacy_and_pt_and_pir
def test_no_gradient(self):
paddle.disable_static()
x = paddle.randn([10, 3])