Skip to content

Commit

Permalink
【Fix PIR Unittest No.66】refine pir unique_name and set_parameter (#64312)
Browse files Browse the repository at this point in the history

* refine pir unique_name and set_parameter
  • Loading branch information
wanghuancoder authored May 17, 2024
1 parent ae1bd9a commit 3e9f0e3
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 9 deletions.
24 changes: 22 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,29 @@ pir::Value parameter(const std::string& name) {
}

// Registers `parameter` under `name` in the current program.
//
// If a parameter with this name is already registered, its type must match
// `parameter`'s type; in that case no new op is emitted (the registration is
// treated as a benign duplicate). Otherwise a new pir::Parameter is recorded
// with ApiBuilder and a pir::SetParameterOp is built into the program.
//
// Raises InvalidArgument if `name` is already registered with a different type.
void set_parameter(const pir::Value& parameter, const std::string& name) {
  pir::Parameter* param = ApiBuilder::Instance().GetParameter(name);
  if (param) {
    // Duplicate registration is only legal when the type is unchanged.
    PADDLE_ENFORCE_EQ(param->type(),
                      parameter.type(),
                      phi::errors::InvalidArgument(
                          "Duplicate parameter %s with different type.", name));
  } else {
    std::unique_ptr<pir::Parameter> param_new(
        new pir::Parameter(nullptr, 0, parameter.type()));
    ApiBuilder::Instance().SetParameter(name, std::move(param_new));
    ApiBuilder::Instance().GetBuilder()->Build<pir::SetParameterOp>(parameter,
                                                                    name);
  }
}

// Updates the parameter already registered under `name`: replaces the stored
// pir::Parameter with one carrying `parameter`'s type and builds a
// pir::SetParameterOp into the program.
//
// Unlike set_parameter(), the name must already exist; updating a
// non-existent parameter raises InvalidArgument.
// NOTE(review): "updata" is a typo for "update", but the name is part of the
// public API surface (mirrored in manual_api.h and the Python binding), so it
// is kept for compatibility.
void updata_parameter(const pir::Value& parameter, const std::string& name) {
  pir::Parameter* param = ApiBuilder::Instance().GetParameter(name);
  PADDLE_ENFORCE_NOT_NULL(param,
                          phi::errors::InvalidArgument(
                              "Parameter %s not exist, can not updata.", name));
  // Replace the stored parameter with one reflecting the (possibly new) type.
  std::unique_ptr<pir::Parameter> param_new(
      new pir::Parameter(nullptr, 0, parameter.type()));
  ApiBuilder::Instance().SetParameter(name, std::move(param_new));
  ApiBuilder::Instance().GetBuilder()->Build<pir::SetParameterOp>(parameter,
                                                                  name);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pir::Value parameter(const std::string& name);

void set_parameter(const pir::Value& parameter, const std::string& name);

void updata_parameter(const pir::Value& parameter, const std::string& name);

void shadow_output(const pir::Value& persist_value, const std::string& name);

pir::Value embedding_grad(const pir::Value& x,
Expand Down
30 changes: 30 additions & 0 deletions paddle/fluid/pybind/manual_static_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,32 @@ static PyObject *static_api_set_parameter(PyObject *self,
}
}

// Python binding for paddle::dialect::updata_parameter.
//
// args: (parameter: pir.Value, name: str). Returns None on success;
// any C++ exception is converted to a Python exception.
static PyObject *static_api_updata_parameter(PyObject *self,
                                             PyObject *args,
                                             PyObject *kwargs) {
  try {
    // Fixed misspelling "uodata_parameter" so logs and call-stack
    // attribution carry the real op name.
    VLOG(6) << "Add updata_parameter op into program";
    VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);

    // Get Value from args
    PyObject *parameter_obj = PyTuple_GET_ITEM(args, 0);
    auto parameter = CastPyArg2Value(parameter_obj, "parameter", 0);

    // Parse Attributes
    PyObject *name_obj = PyTuple_GET_ITEM(args, 1);
    std::string name = CastPyArg2String(name_obj, "name", 1);

    // Call ir static api, recording the Python call stack so it can be
    // attached to the generated op.
    CallStackRecorder callstack_recoder("updata_parameter");
    callstack_recoder.Record();
    paddle::dialect::updata_parameter(parameter, name);
    callstack_recoder.AttachToOps();
    Py_RETURN_NONE;
  } catch (...) {
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }
}

static PyObject *static_api_set_persistable_value(PyObject *self,
PyObject *args,
PyObject *kwargs) {
Expand Down Expand Up @@ -949,6 +975,10 @@ static PyMethodDef ManualOpsAPI[] = {
(PyCFunction)(void (*)(void))static_api_set_parameter,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for set_parameter."},
{"updata_parameter",
(PyCFunction)(void (*)(void))static_api_updata_parameter,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for updata_parameter."},
{"set_persistable_value",
(PyCFunction)(void (*)(void))static_api_set_persistable_value,
METH_VARARGS | METH_KEYWORDS,
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _pir_transform(t, dtype):
param = op.operand(0).source()
cast_param = paddle.cast(param, dtype)
cast_param.persistable = True
paddle._pir_ops.set_parameter(cast_param, t.name)
paddle._pir_ops.updata_parameter(cast_param, t.name)
block.remove_op(op)
break
main.set_parameters_from(startup)
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/base/layer_helper_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
default_main_program,
default_startup_program,
in_dygraph_mode,
in_dynamic_or_pir_mode,
in_pir_mode,
)
from .initializer import _global_bias_initializer, _global_weight_initializer
Expand Down Expand Up @@ -377,7 +378,7 @@ def create_parameter(
else default_initializer
)
if attr.name is None:
if in_dygraph_mode():
if in_dynamic_or_pir_mode():
attr.name = unique_name.generate(".".join([self.name, suffix]))
else:
attr.name = self.main_program._name_generator.generate(
Expand Down
15 changes: 10 additions & 5 deletions test/auto_parallel/pir/test_ir_dist_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class TestBuildFakeProgram(unittest.TestCase):
def test_build_api(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
Expand All @@ -55,7 +56,8 @@ def test_build_api(self):
def test_build_replicated_program(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
Expand Down Expand Up @@ -122,7 +124,8 @@ def test_build_replicated_program(self):
def test_build_col_parallel_program(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input', shape=[BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE]
Expand Down Expand Up @@ -171,7 +174,8 @@ def test_build_col_parallel_program(self):
def test_build_row_parallel_program(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input',
Expand Down Expand Up @@ -223,7 +227,8 @@ def test_build_row_parallel_program(self):
def test_build_with_shard_tensor(self):
with paddle.pir_utils.IrGuard():
main_program = paddle.base.Program()
with paddle.base.program_guard(main_program):
start_program = paddle.base.Program()
with paddle.base.program_guard(main_program, start_program):
mesh = dist.ProcessMesh([0, 1], dim_names=['mp'])
input = paddle.static.data(
name='input',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,5 @@ def test_gan_float32(self):


if __name__ == '__main__':
paddle.enable_static()
unittest.main()

0 comments on commit 3e9f0e3

Please sign in to comment.