[Dy2St][PIR] Hold backward program in GradNode #63694

Merged (25 commits) · May 6, 2024
Changes from 4 commits
20 changes: 20 additions & 0 deletions paddle/fluid/eager/to_static/run_program_op_func.h
@@ -317,6 +317,26 @@ inline void pir_run_program_ad_func(
// Set Attributes
grad_node->SetAttrMap(attrs);

auto* backward_program =
PADDLE_GET_CONST(::pir::Program*, attrs.at("backward_program"));

auto* forward_program =
PADDLE_GET_CONST(::pir::Program*, attrs.at("forward_program"));

auto testkey = PADDLE_GET_CONST(std::string, attrs.at("testkey"));

VLOG(1) << "[pir_run_program_ad_func] testkey: " << testkey;
// TODO(gouzil): compare the address of attrs here
VLOG(1) << "[pir_run_program_ad_func] attrs: " << &attrs;

VLOG(1) << "[pir_run_program_ad_func] backward_program: "
<< backward_program;
VLOG(1) << "[pir_run_program_ad_func] backward_program: "
<< backward_program->num_ops();
VLOG(1) << "[pir_run_program_ad_func] forward_program: " << forward_program;
VLOG(1) << "[pir_run_program_ad_func] forward_program: "
<< forward_program->num_ops();

// Clear unused x vars
auto filter_x = pir_filter_unused_input_var_in_backward(x_tmp, "bx", attrs);
// Set TensorWrappers
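
For context, what this hunk is setting up: the grad node is supposed to keep the forward and backward PIR programs themselves alive (not merely their global blocks) until backward actually runs; the `testkey` attribute and the VLOG calls are debugging scaffolding that the review comments below ask to remove. A minimal, Paddle-free sketch of that intended ownership shape, using the hypothetical stand-ins `Program` and `GradNodeLike` rather than Paddle's real classes (note that the diff as shown still passes a raw `::pir::Program*` through the attribute map, which is exactly what a later review comment questions):

```cpp
#include <iostream>
#include <memory>

// Hypothetical stand-in for ::pir::Program; not Paddle code.
struct Program {
  int num_ops() const { return 3; }
};

// A grad-node-like object that shares ownership of the backward program,
// so the program is still alive whenever Backward() eventually runs.
struct GradNodeLike {
  std::shared_ptr<Program> backward_program;

  void Backward() const {
    std::cout << "backward ops: " << backward_program->num_ops() << "\n";
  }
};

int main() {
  GradNodeLike node;
  {
    auto prog = std::make_shared<Program>();
    node.backward_program = prog;  // the grad node captures ownership here
  }                                // the creating scope is gone ...
  node.Backward();                 // ... but the program is still alive
  return 0;
}
```
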
82 changes: 56 additions & 26 deletions paddle/fluid/eager/to_static/run_program_op_node.h
@@ -467,21 +467,18 @@ inline void PirRunProgramAPI(
   auto param_values =
       PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp"));
 
-  auto *forward_global_block =
-      PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block"));
-  auto *backward_global_block =
-      PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
-
   auto *forward_program =
-      forward_global_block->GetParentOp()->GetParentProgram();
+      PADDLE_GET_CONST(::pir::Program *, attrs.at("forward_program"));
+  auto *backward_program =
+      PADDLE_GET_CONST(::pir::Program *, attrs.at("backward_program"));
+
+  ::pir::Block *forward_global_block = forward_program->block();
 
   if (FLAGS_print_ir) {
     std::ostringstream print_stream;
     print_stream << "ForwardProgram is :\n";
     forward_program->Print(print_stream);
     if (!is_test) {
-      auto *backward_program =
-          backward_global_block->GetParentOp()->GetParentProgram();
       print_stream << "BackwardProgram is:\n";
       backward_program->Print(print_stream);
     } else {
@@ -1046,11 +1043,6 @@ inline void PirRunProgramGradAPI(

   VLOG(4) << "global_inner_scope:" << global_inner_scope;
 
-  auto *backward_global_block =
-      PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block"));
-  auto *backward_program =
-      backward_global_block->GetParentOp()->GetParentProgram();
-
   auto output_grad_values =
       PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g"));
   auto forward_input_values =
@@ -1068,27 +1060,63 @@ inline void PirRunProgramGradAPI(

   details::Trans2ContiguousTensorsInplace(out_grad);
 
+  auto testkey = PADDLE_GET_CONST(std::string, attrs.at("testkey"));
+
+  std::cout << "backward_program get start" << std::endl;
+  auto *backward_program =
+      PADDLE_GET_CONST(::pir::Program *, attrs.at("backward_program"));
+
+  auto *forward_program =
+      PADDLE_GET_CONST(::pir::Program *, attrs.at("forward_program"));
+
+  VLOG(0) << "[PirRunProgramGradAPI] testkey: " << testkey;
+  VLOG(0) << "[PirRunProgramGradAPI] backward_program addr: "
+          << backward_program;
+  VLOG(0) << "[PirRunProgramGradAPI] forward_program addr: " << forward_program;
+  VLOG(0) << backward_program->num_ops();
+  VLOG(0) << forward_program->num_ops();
+  std::cout << "backward_program get end" << std::endl;
+
+  std::cout << "backward_program block get start" << std::endl;
+  auto pb = backward_program->block();
+  VLOG(0) << pb;
+  VLOG(0) << pb->num_ops();
+  VLOG(0) << pb->empty();
+  std::cout << "backward_program block get end" << std::endl;
+
   // share x, param, middles, output_grads, out into scope.
-  details::ShareTensorsIntoScopeByValue(
-      backward_global_block, out_grad, output_grad_values, global_inner_scope);
-  details::ShareTensorsIntoScopeByValue(
-      backward_global_block, x, forward_input_values, global_inner_scope);
-  details::ShareTensorsIntoScopeByValue(backward_global_block,
+  VLOG(1) << "out_grad start";
+  details::ShareTensorsIntoScopeByValue(backward_program->block(),
+                                        out_grad,
+                                        output_grad_values,
+                                        global_inner_scope);
+  VLOG(1) << "out_grad end";
+  details::ShareTensorsIntoScopeByValue(
+      backward_program->block(), x, forward_input_values, global_inner_scope);
+  VLOG(1) << "x end";
+  details::ShareTensorsIntoScopeByValue(backward_program->block(),
                                         middles,
                                         forward_middle_values,
                                         global_inner_scope);
-  details::ShareTensorsIntoScopeByValue(
-      backward_global_block, out, forward_output_values, global_inner_scope);
-  details::ShareTensorsIntoScopeByValue(
-      backward_global_block, params, parameter_values, global_inner_scope);
+  VLOG(1) << "middles end";
+  details::ShareTensorsIntoScopeByValue(backward_program->block(),
+                                        out,
+                                        forward_output_values,
+                                        global_inner_scope);
+  VLOG(1) << "out end";
+  details::ShareTensorsIntoScopeByValue(
+      backward_program->block(), params, parameter_values, global_inner_scope);
+  VLOG(1) << "params end";
 
   // Clear out and middles to avoid hold memory until backward finish.
   out.clear();
   middles.clear();
+  VLOG(1) << "out and middles clear end";
Review comment (Member): Did you forget to clean this one up?


auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr;

if (!cache.Has(program_id,
global_inner_scope,
place_hash_key,
@@ -1138,10 +1166,10 @@ inline void PirRunProgramGradAPI(
     // get all eager gc vars
     std::set<std::string> skip_eager_delete_vars;
     auto skip_names = details::GetNameFromValue(
-        backward_global_block, x_grad_values, false, true);
+        backward_program->block(), x_grad_values, false, true);
     skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
     skip_names = details::GetNameFromValue(
-        backward_global_block, p_grad_values, false, true);
+        backward_program->block(), p_grad_values, false, true);
     skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end());
     interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
     cache.UpdateSkipEagerDeleteVars(program_id,
@@ -1174,7 +1202,7 @@ inline void PirRunProgramGradAPI(
     }
   }
 
-  if (!backward_global_block->empty()) {
+  if (!backward_program->block()->empty()) {
     paddle::platform::RecordEvent record_event(
         "interpreter_core_run",
         paddle::platform::TracerEventType::UserDefined,
@@ -1189,9 +1217,11 @@ inline void PirRunProgramGradAPI(
"fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1);
// Step 4. get outputs
details::ShareTensorsFromScopeByValue(
backward_global_block, x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(
backward_global_block, params_grad, p_grad_values, global_inner_scope);
backward_program->block(), x_grad, x_grad_values, global_inner_scope);
details::ShareTensorsFromScopeByValue(backward_program->block(),
params_grad,
p_grad_values,
global_inner_scope);
VLOG(4) << "after backward gc all vars";
global_inner_scope->SetCanReused(true);
details::GcScope(global_inner_scope);
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"

@@ -977,6 +978,9 @@ struct SetAttrDescVisitor {
void operator()(const std::vector<pir::Block *> &v) const {
// just do nothing.
}
void operator()(const std::shared_ptr<pir::Program> &v) const {
// just do nothing.
}
void operator()(const std::vector<VarDesc *> &v) const {
std::vector<std::string> var_names;
for (auto var : v) {
4 changes: 3 additions & 1 deletion paddle/fluid/framework/type_defs.cc
@@ -39,7 +39,9 @@ template class variant<paddle::blank,
                        paddle::experimental::Scalar,
                        std::vector<paddle::experimental::Scalar>,
                        ::pir::Block*,
-                       std::vector<::pir::Value>>;
+                       std::vector<::pir::Value>,
+                       ::pir::Program*,
+                       ProgramDesc*>;
 }  // namespace paddle
REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap);
REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute);
6 changes: 5 additions & 1 deletion paddle/fluid/framework/type_defs.h
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"
#include "paddle/utils/small_vector.h"
@@ -40,6 +41,7 @@ class InferShapeContext;
class InferVarTypeContext;
class VarDesc;
class BlockDesc;
class ProgramDesc;
Review comment (Member): Is this necessary? Only pir::Program is added below, so why forward-declare the legacy IR ProgramDesc here?

class Variable;
class InferNoNeedBufferVarsFN;

@@ -67,7 +69,9 @@ using Attribute = paddle::variant<paddle::blank,
                                   paddle::experimental::Scalar,
                                   std::vector<paddle::experimental::Scalar>,
                                   ::pir::Block*,
-                                  std::vector<::pir::Value>>;
+                                  std::vector<::pir::Value>,
+                                  ::pir::Program*,
+                                  ProgramDesc*>;
 using AttributeMap = std::unordered_map<std::string, Attribute>;

using OpCreator =
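
The type_defs.cc and type_defs.h edits belong together: `Attribute` is a `paddle::variant`, so a `::pir::Program*` can only be stored in and read back out of an `AttributeMap` once it is listed as an alternative, and the explicit template instantiation in the .cc file has to match the alias in the header. A small self-contained illustration of the same rule, written against `std::variant` instead of Paddle's `paddle::variant`, with `Program` as a hypothetical stand-in:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <variant>

// Hypothetical stand-in for ::pir::Program.
struct Program {
  int num_ops = 42;
};

// The pointer type must appear in the variant's alternative list before an
// AttributeMap entry can hold it; this mirrors adding ::pir::Program* above.
using Attribute = std::variant<bool, int64_t, std::string, Program*>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

int main() {
  Program prog;
  AttributeMap attrs;
  attrs["backward_program"] = &prog;  // store the raw pointer as an attribute

  // Read it back, analogous to PADDLE_GET_CONST(::pir::Program*, attrs.at(...)).
  auto *backward = std::get<Program*>(attrs.at("backward_program"));
  std::cout << backward->num_ops << "\n";  // prints 42
  return 0;
}
```
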
34 changes: 30 additions & 4 deletions paddle/fluid/pybind/op_function_common.cc
@@ -36,6 +36,7 @@
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/block.h"
#include "paddle/pir/include/core/op_result.h"
#include "paddle/pir/include/core/region.h"
#include "paddle/pir/include/core/value.h"

namespace paddle {
@@ -858,6 +859,30 @@ void CastPyArg2AttrIRBlock(PyObject* obj,
attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]);
}

void CastPyArg2AttrIRProgram(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
const std::string& op_type,
ssize_t arg_pos) {
VLOG(1) << "After Process pir::Program*";
::pybind11::detail::instance* inst =
(::pybind11::detail::instance*)obj; // NOLINT
void** vh = inst->simple_layout ? inst->simple_value_holder
: &inst->nonsimple.values_and_holders[0];

::pybind11::handle(obj).inc_ref();
Review comment (Member): You inc_ref here, but where is the matching dec_ref?

// ::pir::Program* program = reinterpret_cast<::pir::Program*>(vh[0]);
std::shared_ptr<::pir::Program> program =
reinterpret_cast<std::shared_ptr<::pir::Program>&>(vh[0]);
Review comment (Member): Can it actually be pulled out like this?

// TODO(gouzil): try whether pybind11 can take a smart pointer as the parameter
// pir::IrMapping mapper;
attrs[key] = program.get();
Review comment (Member): Why convert back to a raw pointer?

// attrs[key] = program.get();
// attrs[key] = program->Clone(mapper);
// attrs[key] = reinterpret_cast<::pir::Program*>(vh[0]);
// attrs[key] = vh[0];
}

void CastPyArg2AttrValues(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key,
@@ -998,6 +1023,7 @@ void ConstructAttrMapForRunProgram(
attr_end));

PyObject* obj = nullptr;
attrs["testkey"] = std::string("testvalue");
Review comment (Member): Forgot to remove this?

Reply (Author): Done

for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) {
VLOG(1) << "Start Process " << arg_pos;
Py_ssize_t key_len = 0;
@@ -1020,11 +1046,11 @@ void ConstructAttrMapForRunProgram(

     if (std::set<std::string>({"cuda_graph_capture_mode"}).count(key)) {
       CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
-    } else if (std::set<std::string>({"global_block",
-                                      "forward_global_block",
-                                      "backward_global_block"})
-                   .count(key)) {
+    } else if (std::set<std::string>({"global_block"}).count(key)) {
       CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos);
+    } else if (std::set<std::string>({"forward_program", "backward_program"})
+                   .count(key)) {
+      CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos);
     } else if (std::set<std::string>({"is_test", "use_interpretorcore"})
                    .count(key)) {
       CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
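
The three review questions in this file (the unmatched `inc_ref()`, whether the pybind11 value holder can really be reinterpreted as a `std::shared_ptr`, and why it is converted back to a raw pointer) all come down to who owns the `pir::Program`. A minimal sketch of the ownership pattern plain pybind11 provides, deliberately kept outside Paddle's internals: when a class is bound with a `std::shared_ptr` holder, `py::cast` can hand out a `shared_ptr` that co-owns the object, so no manual `inc_ref()`/`dec_ref()` bookkeeping is required. `Program`, `grab_program`, and `keeps_alive` are illustrative names, not Paddle or pybind11 APIs:

```cpp
#include <pybind11/pybind11.h>

#include <memory>

namespace py = pybind11;

// Illustrative stand-in class; not pir::Program.
struct Program {
  int num_ops() const { return 3; }
};

// Extract a shared_ptr from a Python object. The returned shared_ptr shares
// ownership with the Python-side holder, so the C++ side can keep the object
// alive without touching the PyObject reference count by hand.
std::shared_ptr<Program> grab_program(py::handle obj) {
  return py::cast<std::shared_ptr<Program>>(obj);
}

PYBIND11_MODULE(ownership_demo, m) {
  // Bind Program with a shared_ptr holder; the py::init factory then returns
  // make_shared, matching the make_unique -> make_shared change in pir.cc below.
  py::class_<Program, std::shared_ptr<Program>>(m, "Program")
      .def(py::init([]() { return std::make_shared<Program>(); }))
      .def("num_ops", &Program::num_ops);

  m.def("keeps_alive", [](py::handle obj) {
    auto prog = grab_program(obj);  // co-owns the Program
    return prog->num_ops();         // safe even if Python drops its reference
  });
}
```

If the attribute map ultimately stores only a raw `::pir::Program*` (as in this revision), that `shared_ptr` has to be kept alive somewhere else for as long as the pointer is used, which is the lifetime concern the comments above are raising.
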
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
@@ -255,7 +255,7 @@ void BindProgram(py::module *m) {
)DOC");
program
.def(py::init([]() {
return std::make_unique<Program>(pir::IrContext::Instance());
return std::make_shared<Program>(pir::IrContext::Instance());
}))
.def("__str__",
[](const std::shared_ptr<Program> &self) {
1 change: 1 addition & 0 deletions paddle/pir/src/core/program.cc
@@ -23,6 +23,7 @@ Program::Program(IrContext* context) {
}

Program::~Program() {
VLOG(1) << "[Program] Destroy Program";
if (module_) {
module_.Destroy();
}
8 changes: 4 additions & 4 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -898,10 +898,10 @@ def _prune_unused_params(self, program):

     def _prepare_attributes(self):
         attrs = [
-            'forward_global_block',
-            self.program.forward_program.global_block(),
-            'backward_global_block',
-            self.program.backward_program.global_block(),
+            'forward_program',
+            self.program.forward_program,
+            'backward_program',
+            self.program.backward_program,
             'is_test',
             not self.training,
             'program_id',
5 changes: 4 additions & 1 deletion test/dygraph_to_static/test_no_gradient.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+breakpoint()  # noqa: T100
 import unittest
 
 import numpy
-from dygraph_to_static_utils import Dy2StTestBase
+from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_pir_only
 
 import paddle

@@ -33,6 +34,8 @@ def main_func(x, index):


 class TestNoGradientCase(Dy2StTestBase):
+    @test_ast_only
+    @test_pir_only
     def test_no_gradient(self):
         paddle.disable_static()
         x = paddle.randn([10, 3])