From 7eb6b0d550e064fb869f0705dc1bacf97ca8ec00 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Tue, 9 Jan 2024 14:55:14 +0800 Subject: [PATCH] [Dy2St] fix `test_grad` in PIR mode (#60621) --------- Co-authored-by: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> --- paddle/fluid/pybind/pir.cc | 138 +++++++----------- python/paddle/autograd/backward_utils.py | 8 + python/paddle/autograd/ir_backward.py | 11 +- python/paddle/decomposition/decomp.py | 6 +- .../jit/dy2static/pir_partial_program.py | 32 ++-- python/paddle/nn/clip.py | 5 +- python/paddle/pir/__init__.py | 4 +- test/dygraph_to_static/test_container.py | 11 +- test/dygraph_to_static/test_grad.py | 48 +++--- 9 files changed, 128 insertions(+), 135 deletions(-) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index d9d7eb3abe186..8aa14748aee2f 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -171,12 +171,22 @@ std::string GetValueInfo(Value v) { } else if (auto arg = v.dyn_cast()) { ss << "block_arg, index = " << arg.index(); } - ss << ", dtype=" << v.type(); - if (v.type().isa()) { - ss << ", place=" - << v.type() - .dyn_cast() - .place(); + if (!v.type()) { + ss << ", dtype=<>"; + } else { + ss << ", dtype=" << v.type(); + if (v.type().isa()) { + ss << ", place=" + << v.type() + .dyn_cast() + .place(); + } + } + auto stop_gradient = v.attribute(kAttrStopGradients); + if (stop_gradient && !stop_gradient.data()) { + ss << ", stop_gradient=False"; + } else { + ss << ", stop_gradient=True"; } return ss.str(); } @@ -527,13 +537,20 @@ void BindOperation(py::module *m) { return op_list; }) .def("replace_all_uses_with", - [](Operation &self, const std::vector &op_results) { - self.ReplaceAllUsesWith(op_results); + [](Operation &self, const std::vector &values) { + self.ReplaceAllUsesWith(values); }) .def("as_if_op", [](Operation &self) { return PyIfOp(self.dyn_cast()); }) .def("as_while_op", - [](Operation &self) { return PyWhileOp(self.dyn_cast()); }); + [](Operation &self) { return PyWhileOp(self.dyn_cast()); }) + .def("__repr__", [](Operation &self) { + std::ostringstream print_stream; + print_stream << "Operation("; + self.Print(print_stream); + print_stream << ")"; + return print_stream.str(); + }); py::class_ block_container( *m, "Operation_BlockContainer", R"DOC( The Operation_BlockContainer only use to walk all blocks in the operation. 
@@ -555,6 +572,9 @@ py::str Value2String(Value self) { } phi::DataType GetValueDtype(Value value) { + if (!value.type()) { + PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr.")); + } if (value.type().isa()) { return paddle::dialect::TransToPhiDataType( value.type().dyn_cast().dtype()); @@ -572,6 +592,9 @@ phi::DataType GetValueDtype(Value value) { } const phi::DDim &GetValueDims(Value value) { + if (!value.type()) { + PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr.")); + } if (value.type().isa()) { return value.type().dyn_cast().dims(); } else if (value.type().isa()) { @@ -768,21 +791,6 @@ void BindValue(py::module *m) { .def("first_use", &Value::first_use, return_value_policy::reference) .def("has_one_use", &Value::HasOneUse) .def("use_empty", &Value::use_empty) - .def("__str__", - [](Value self) -> py::str { - std::ostringstream print_stream; - print_stream << "Value("; - print_stream << GetValueInfo(self); - auto stop_gradient = - self.attribute(kAttrStopGradients); - if (stop_gradient && !stop_gradient.data()) { - print_stream << ", stop_gradient=False"; - } else { - print_stream << ", stop_gradient=True"; - } - print_stream << ")"; - return print_stream.str(); - }) .def("apply", &apply) .def("is_same", &Value::operator==) .def("hash", [](Value self) { return std::hash{}(self); }) @@ -965,14 +973,14 @@ using SplitedProgram = std::vector>; using SplitedAttribute = std::map>; using SplitedResult = std::pair; -pir::OpResult FakeOpResult() { - // create a fake opresults to simplify `ForwardBackwardSplit`. - return pir::OpResult(nullptr); +pir::Value FakeValue() { + // create a fake value to simplify `ForwardBackwardSplit`. + return pir::Value(nullptr); } -bool IsFakeOpResult(const pir::OpResult &result) { - // create a fake opresults to simplify `ForwardBackwardSplit`. - return result.Value::impl() == nullptr || !result.Value::type(); +bool IsFakeValue(const pir::Value &value) { + // create a fake value to simplify `ForwardBackwardSplit`. 
+ return value.impl() == nullptr || !value.type(); } static auto GetNoNeedBufferValue(const ::pir::Block *whole_block, @@ -1032,7 +1040,7 @@ std::pair, OpResultMap> CloneProgram( } void AppendShadowOutput(Program *forward_program, - const pir::OpResult &result, + const pir::Value &value, const std::string &name, size_t start_point) { pir::IrContext *ctx = pir::IrContext::Instance(); @@ -1041,7 +1049,7 @@ void AppendShadowOutput(Program *forward_program, {"output_name", pir::StrAttribute::get(ctx, name)}, }; pir::Operation *operation = - pir::Operation::Create({result}, attribute_map, {}, op_info); + pir::Operation::Create({value}, attribute_map, {}, op_info); auto position = forward_program->block()->begin(); std::advance(position, start_point); if (position == forward_program->block()->end()) { @@ -1052,19 +1060,19 @@ void AppendShadowOutput(Program *forward_program, } int AppendShadowOutputs(Program *forward_program, - const std::vector &outputs_op_result, + const std::vector &outputs, int start_point, std::string name_prefix) { int counter = 0; - std::unordered_set added_op_result; + std::unordered_set added_value; - for (const auto &result : outputs_op_result) { - if (!added_op_result.count(result) || IsFakeOpResult(result)) { + for (const auto &value : outputs) { + if (!added_value.count(value) || IsFakeValue(value)) { std::string shadow_output_name = name_prefix + std::to_string(counter); AppendShadowOutput( - forward_program, result, shadow_output_name, start_point + counter); + forward_program, value, shadow_output_name, start_point + counter); counter += 1; - added_op_result.insert(result); + added_value.insert(value); } } // return the inserted op. @@ -1073,51 +1081,17 @@ int AppendShadowOutputs(Program *forward_program, SplitedResult SplitForwardBackward( const Program &program, - const std::vector &op_result_forward_inputs, - const std::vector &op_result_forward_params, - const std::vector &op_result_forward_outputs, - const std::vector &op_result_forward_inputs_grads, - const std::vector &op_result_forward_params_grads, - const std::vector &op_result_forward_outputs_grads, + const std::vector &forward_inputs, + const std::vector &forward_params, + const std::vector &forward_outputs, + const std::vector &forward_inputs_grads, + const std::vector &forward_params_grads, + const std::vector &forward_outputs_grads, const std::vector &forward_range, const std::vector &backward_range) { - // transform opresult -> value - std::vector forward_inputs, forward_outputs, forward_inputs_grads, - forward_outputs_grads, forward_params, forward_params_grads; - - auto op_result_to_value = [](const pir::OpResult &r) { - if (r.impl() == nullptr) return Value(nullptr); - return Value(r.Value::impl()); - }; - - std::transform(op_result_forward_inputs.begin(), - op_result_forward_inputs.end(), - std::back_inserter(forward_inputs), - op_result_to_value); - std::transform(op_result_forward_outputs.begin(), - op_result_forward_outputs.end(), - std::back_inserter(forward_outputs), - op_result_to_value); - std::transform(op_result_forward_inputs_grads.begin(), - op_result_forward_inputs_grads.end(), - std::back_inserter(forward_inputs_grads), - op_result_to_value); - std::transform(op_result_forward_outputs_grads.begin(), - op_result_forward_outputs_grads.end(), - std::back_inserter(forward_outputs_grads), - op_result_to_value); - std::transform(op_result_forward_params.begin(), - op_result_forward_params.end(), - std::back_inserter(forward_params), - op_result_to_value); - 
std::transform(op_result_forward_params_grads.begin(), - op_result_forward_params_grads.end(), - std::back_inserter(forward_params_grads), - op_result_to_value); - std::vector forward_in_out_values; - for (auto &v : std::vector *>( - {&forward_inputs, &forward_outputs, &forward_params})) { + for (auto &v : + std::vector({&forward_inputs, &forward_outputs, &forward_params})) { forward_in_out_values.insert( forward_in_out_values.end(), v->begin(), v->end()); } @@ -1393,8 +1367,8 @@ void BindUtils(pybind11::module *m) { m->def("reset_shadow_output_name", ResetShadowOutputName); m->def("split_program", SplitForwardBackward); m->def("append_shadow_outputs", AppendShadowOutputs); - m->def("fake_op_result", FakeOpResult); - m->def("is_fake_op_result", IsFakeOpResult); + m->def("fake_value", FakeValue); + m->def("is_fake_value", IsFakeValue); m->def("get_current_insertion_point", []() -> PyInsertionPoint { return {ApiBuilder::Instance().GetCurrentInsertionPoint()}; }); diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py index feb2500cb3718..c6ee84682359a 100644 --- a/python/paddle/autograd/backward_utils.py +++ b/python/paddle/autograd/backward_utils.py @@ -111,6 +111,10 @@ def __iter__(self): def __contains__(self, key): return ValueWrapper(key) in self._items + def __repr__(self) -> str: + items_str = ", ".join(f"{key}: {val}" for key, val in self.items()) + return f'ValueDict({items_str})' + class ValueSet: def __init__( @@ -153,6 +157,10 @@ def __iter__(self): def __contains__(self, val): return ValueWrapper(val) in self._set + def __repr__(self) -> str: + items_str = ", ".join(repr(item) for item in self) + return f'ValueSet({items_str})' + class State: """ diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index aa26c86ce31a4..4cee722bfe740 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -63,6 +63,9 @@ def check_all_puts(block, inputs, outputs): def append_full_like(float_value, copy_value, value, state, backward_ops): + if paddle.pir.is_fake_value(value): + state.value_to_valuegrad[value] = [[paddle.pir.fake_value()]] + return if copy_value.is_tensorarray(): value_grad = paddle._pir_ops.create_array_like( copy_value, @@ -174,9 +177,9 @@ def prepare_grad_outputs(grad_outputs, outputs, state): visited_output.add(opresult) continue else: - if paddle.pir.is_fake_op_result(opresult): + if paddle.pir.is_fake_value(opresult): state.value_to_valuegrad[opresult] = [ - [paddle.pir.fake_op_result()] + [paddle.pir.fake_value()] ] else: grad_value = append_full_like( @@ -702,9 +705,7 @@ def append_yield( new_value = return_map_value( value, control_flow_value_to_copyvalue_map ) - value_grad = append_full_like( - 0.0, new_value, value, state, backward_ops - ) + append_full_like(0.0, new_value, value, state, backward_ops) input_grad = state.value_to_valuegrad[value][0][0] inputs_grad.append(input_grad) diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index 3a1f4c9382bcd..81c65550103b2 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -605,7 +605,7 @@ def _prepare_grad_outputs(fwd_op, bwd_op): new_grad_outputs.append(grad_outputs[index]) index += 1 else: - new_grad_outputs.append([pir.fake_op_result()]) + new_grad_outputs.append([pir.fake_value()]) return new_grad_outputs @@ -679,7 +679,7 @@ def _decomp_bwd_with_vjp( if grad_input[0] is not None and grad_input[0].initialized(): 
res.append(grad_input[0]) else: - res.append(pir.fake_op_result()) + res.append(pir.fake_value()) assert len(res) == len( bwd_op.results() ), "results of original backward op do not match results of decomposed backward op" @@ -752,7 +752,7 @@ def _decomp_bwd_without_vjp( res.append(new_grad_inputs[input_grads_idx]) input_grads_idx += 1 else: - res.append(pir.fake_op_result()) + res.append(pir.fake_value()) # step4: upgrade grad_var_to_var _upgrade_grad_var_to_var( diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index 574821ab5b342..de15ff4689fb1 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -28,7 +28,7 @@ from paddle.base.data_feeder import check_type, convert_dtype from paddle.base.dygraph.base import switch_to_static_graph from paddle.optimizer.lr import LRScheduler -from paddle.pir import Value, fake_op_result, is_fake_op_result +from paddle.pir import Value, fake_value, is_fake_value from .utils import RETURN_NO_VALUE_MAGIC_NUM, backend_guard @@ -59,13 +59,13 @@ def _tolist(self): """ Flattens the nested sequences into single list and remove duplicate variables + non-variable elements. """ - variable_map = ValueDict() # opresult -> list idx + variable_map = ValueDict() # value -> list idx variable_list = [] for value in paddle.utils.flatten(self._raw_input): if not isinstance(value, Value): continue if value in variable_map: - # remove duplicate opresults. + # remove duplicate values. continue variable_map[value] = len(variable_list) variable_list.append(value) @@ -133,7 +133,7 @@ def get_value_name_map(self): @classmethod def _get_value_name_map_from_program(cls, program): ret = ValueDict() - ret[fake_op_result()] = "FakeVar" + ret[fake_value()] = "FakeVar" for op in program.global_block().ops: if op.name() == "builtin.set_parameter": ret[op.operand(0).source()] = op.attrs()["parameter_name"] @@ -757,7 +757,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram): forward_outputs_grads = [] for out_op_result in targets: if out_op_result.stop_gradient is True: - forward_outputs_grads.append(fake_op_result()) + forward_outputs_grads.append(fake_value()) else: value = paddle.full_like( out_op_result, @@ -791,7 +791,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram): ), grad_outputs=list( filter( - lambda x: not is_fake_op_result(x), + lambda x: not is_fake_value(x), forward_outputs_grads, ) ), @@ -810,24 +810,18 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram): # start_idx + 1, main_program, program # ) - mapping_op_result = ( - lambda x: x if isinstance(x, Value) else fake_op_result() - ) + mapping_value = lambda x: x if isinstance(x, Value) else fake_value() inputs_size = len(inputs) - x_grad_value = list( - map(mapping_op_result, grad_info_map[0:inputs_size]) - ) - p_grad_value = list(map(mapping_op_result, grad_info_map[inputs_size:])) - o_grad_value = list(map(mapping_op_result, forward_outputs_grads)) + x_grad_value = list(map(mapping_value, grad_info_map[0:inputs_size])) + p_grad_value = list(map(mapping_value, grad_info_map[inputs_size:])) + o_grad_value = list(map(mapping_value, forward_outputs_grads)) # insert grads name for RunableProgram (we need name for grad_inputs and grad_outputs) input_grads_to_append = list( - filter(lambda x: not is_fake_op_result(x), o_grad_value) + filter(lambda x: not is_fake_value(x), o_grad_value) ) output_grads_to_append = list( - 
filter( - lambda x: not is_fake_op_result(x), x_grad_value + p_grad_value - ) + filter(lambda x: not is_fake_value(x), x_grad_value + p_grad_value) ) backward_end_op_index = len(program.global_block().ops) paddle.base.libpaddle.pir.append_shadow_outputs( @@ -1020,7 +1014,7 @@ def _set_grad_type(self, params, train_program: RunableProgram): forward_params_grads = train_program.param_grad_values train_program = train_program.program for param, value in zip(params, forward_params_grads): - if is_fake_op_result(value): + if is_fake_value(value): continue if value.is_selected_row_type(): param._set_grad_type( diff --git a/python/paddle/nn/clip.py b/python/paddle/nn/clip.py index da85de1f9cee2..775a11fd8e398 100644 --- a/python/paddle/nn/clip.py +++ b/python/paddle/nn/clip.py @@ -678,7 +678,10 @@ def _dygraph_clip(self, params_grads): sum_square_list = [] sum_square_list_fp16 = [] sum_square_list_fp32 = [] - src_mesh = params_grads[0][0].process_mesh + if len(params_grads) > 0 and len(params_grads[0]) > 0: + src_mesh = params_grads[0][0].process_mesh + else: + src_mesh = None for p, g in params_grads: if g is None: diff --git a/python/paddle/pir/__init__.py b/python/paddle/pir/__init__.py index 3d21328e0ceed..8be6f115ca976 100644 --- a/python/paddle/pir/__init__.py +++ b/python/paddle/pir/__init__.py @@ -23,9 +23,9 @@ Value, check_unregistered_ops, create_shaped_type, - fake_op_result, + fake_value, get_current_insertion_point, - is_fake_op_result, + is_fake_value, parse_program, register_paddle_dialect, reset_insertion_point_to_end, diff --git a/test/dygraph_to_static/test_container.py b/test/dygraph_to_static/test_container.py index f657562d8b62d..e4ba864516af8 100644 --- a/test/dygraph_to_static/test_container.py +++ b/test/dygraph_to_static/test_container.py @@ -20,6 +20,7 @@ from dygraph_to_static_utils import Dy2StTestBase import paddle +from paddle.framework import use_pir_api class BufferLayers(paddle.nn.Layer): @@ -90,12 +91,16 @@ def _init_seed(self): def _run(self, to_static): self._init_seed() + net = self.net if to_static: - self.net = paddle.jit.to_static(self.net) + net = paddle.jit.to_static(net) x = paddle.rand([16, 10], 'float32') - out = self.net(x) + out = net(x) + # TODO(pir-save-load): Fix this after we support save/load in PIR + if use_pir_api(): + return out if to_static: - load_out = self._test_load(self.net, x) + load_out = self._test_load(net, x) np.testing.assert_allclose( load_out, out, diff --git a/test/dygraph_to_static/test_grad.py b/test/dygraph_to_static/test_grad.py index 15876ddb3f6a4..ada44758fbf67 100644 --- a/test/dygraph_to_static/test_grad.py +++ b/test/dygraph_to_static/test_grad.py @@ -17,9 +17,13 @@ import unittest import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, enable_to_static_guard +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_legacy_and_pt_and_pir, +) import paddle +from paddle.framework import use_pir_api class GradLayer(paddle.nn.Layer): @@ -67,24 +71,20 @@ def forward(self, x): class TestGrad(Dy2StTestBase): def setUp(self): - self.func = paddle.jit.to_static(GradLayer()) + self.func = GradLayer() self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') self.x.stop_gradient = False - def _run(self, func, to_static): - with enable_to_static_guard(to_static): - ret = func(self.x).numpy() - return ret - + @test_legacy_and_pt_and_pir def test_forward(self): - dygraph_res = self._run(self.func, to_static=False) - static_res = self._run(self.func, to_static=True) + dygraph_res = self.func(self.x).numpy() + 
static_res = paddle.jit.to_static(self.func)(self.x).numpy() np.testing.assert_allclose(static_res, dygraph_res, rtol=1e-05) class TestGradLinear(TestGrad): def setUp(self): - self.func = paddle.jit.to_static(GradLinearLayer()) + self.func = GradLinearLayer() self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') self.x.stop_gradient = False @@ -99,45 +99,53 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_legacy_and_pt_and_pir def test_save_infer_program(self): - self.setUp() # make self.func change to ast mode + # TODO(pir-save-load): Fix this after we support save/load in PIR + if use_pir_api(): + return + static_fn = paddle.jit.to_static(self.func) input_spec = [ paddle.static.InputSpec(shape=[10, 2, 5], dtype='float32') ] - paddle.jit.save(self.func, self.infer_model_path, input_spec=input_spec) + paddle.jit.save(static_fn, self.infer_model_path, input_spec=input_spec) load_func = paddle.jit.load(self.infer_model_path) - origin_res = self.func(self.x).numpy() + origin_res = static_fn(self.x).numpy() load_res = load_func(self.x).numpy() np.testing.assert_allclose(origin_res, load_res, rtol=1e-05) + @test_legacy_and_pt_and_pir def test_save_train_program(self): - self.setUp() # make self.func change to ast mode + static_fn = paddle.jit.to_static(self.func) grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0) optimizer = paddle.optimizer.SGD( learning_rate=0.01, grad_clip=grad_clip, - parameters=self.func.parameters(), + parameters=static_fn.parameters(), ) for i in range(10): - out = self.func(self.x) + out = static_fn(self.x) avg_loss = paddle.mean(paddle.abs(out - 1)) avg_loss.backward() optimizer.minimize(avg_loss) - self.func.clear_gradients() + static_fn.clear_gradients() - paddle.jit.save(self.func, self.train_model_path) + # TODO(pir-save-load): Fix this after we support save/load in PIR + if use_pir_api(): + return + paddle.jit.save(static_fn, self.train_model_path) load_func = paddle.jit.load(self.train_model_path) - origin_res = self.func(self.x).numpy() + origin_res = static_fn(self.x).numpy() load_res = load_func(self.x).numpy() np.testing.assert_allclose(origin_res, load_res, rtol=1e-05) class TestNoGradLinear(TestGradLinear): def setUp(self): - self.func = paddle.jit.to_static(NoGradLinearLayer()) + self.func = NoGradLinearLayer() self.x = paddle.ones(shape=[10, 2, 5], dtype='float32') self.x.stop_gradient = False
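
Note: a minimal usage sketch of the helper renaming above (an illustration only, assuming a Paddle build that includes this patch and exports the helpers from `paddle.pir` as shown in the `python/paddle/pir/__init__.py` hunk):

    from paddle import pir

    # `fake_value()` replaces the old `fake_op_result()`; per the pybind change
    # it returns a null placeholder Value, used in this patch to mark slots that
    # carry no gradient (e.g. stop_gradient outputs in pir_partial_program.py).
    placeholder = pir.fake_value()

    # `is_fake_value()` replaces `is_fake_op_result()` and reports whether a
    # Value is such a placeholder (null impl or missing type).
    assert pir.is_fake_value(placeholder)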