From a4a1432cb2a73a352f96ffcc559fa6018acb290c Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Sun, 10 Sep 2023 12:38:35 +0000 Subject: [PATCH 1/7] [NewIR] No.10 Migrate silu into pir --- python/paddle/nn/functional/activation.py | 5 ++--- test/legacy_test/test_activation_op.py | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index f909b03cd0036..2f88bdff9c271 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -14,7 +14,7 @@ import paddle from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode -from paddle.framework import core +from paddle.framework import core, in_dynamic_or_new_ir_mode from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ...base.data_feeder import check_dtype, check_variable_and_dtype @@ -1053,14 +1053,13 @@ def silu(x, name=None): [0.73105860, 1.76159406, 2.85772228, 3.92805505]) """ - if in_dynamic_mode(): + if in_dynamic_or_new_ir_mode(): return _C_ops.silu(x) else: check_variable_and_dtype( x, 'x', [ - 'float16', 'uint16', 'float32', 'float64', diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index a7fe2cf3f602f..09f5eedea7955 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -424,6 +424,14 @@ def init_dtype(self): self.dtype = np.complex128 +class TestSilu_NewIR(TestSilu): + def test_check_output(self): + self.check_output(check_new_ir=True) + + def test_checkout_grad(self): + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + + class TestSiluAPI(unittest.TestCase): # test paddle.nn.Silu, paddle.nn.functional.silu def setUp(self): From 808008014c7b8dac865b94914f4504e6e7a80f32 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Tue, 12 Sep 2023 13:02:05 +0000 Subject: [PATCH 2/7] fix bug --- python/paddle/nn/functional/activation.py | 4 ++-- test/legacy_test/test_activation_op.py | 13 ++++--------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 34812828af0ea..6104a1cf755b7 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -14,7 +14,7 @@ import paddle from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode -from paddle.framework import core, in_dynamic_or_new_ir_mode +from paddle.framework import core, in_dynamic_or_pir_mode from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ...base.data_feeder import check_dtype, check_variable_and_dtype @@ -1053,7 +1053,7 @@ def silu(x, name=None): [0.73105860, 1.76159406, 2.85772228, 3.92805505]) """ - if in_dynamic_or_new_ir_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.silu(x) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index f7f665f66d5ff..3e1729c68a4c1 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -458,12 +458,15 @@ def init_dtype(self): def if_enable_cinn(self): pass + def test_check_output(self): + self.check_output(check_new_ir=True) + def test_check_grad(self): # TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype if self.dtype == np.complex64 or self.dtype == np.complex128: self.check_grad(['X'], 'Out', check_prim=False) else: - self.check_grad(['X'], 'Out', check_prim=True) + 
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestSilu_ZeroDim(TestSilu): @@ -481,14 +484,6 @@ def init_dtype(self): self.dtype = np.complex128 -class TestSilu_NewIR(TestSilu): - def test_check_output(self): - self.check_output(check_new_ir=True) - - def test_checkout_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) - - class TestSiluAPI(unittest.TestCase): # test paddle.nn.Silu, paddle.nn.functional.silu def setUp(self): From e13816bac1b2513d1fc3e1837359373cead22408 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Tue, 12 Sep 2023 14:32:47 +0000 Subject: [PATCH 3/7] fix bug --- python/paddle/nn/functional/activation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 6104a1cf755b7..4ba784f3b2d97 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1060,6 +1060,7 @@ def silu(x, name=None): x, 'x', [ + 'float16', 'uint16', 'float32', 'float64', From 8c448e5de1d9d716a67497cfa10946fc487ccf58 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Wed, 13 Sep 2023 05:27:43 +0000 Subject: [PATCH 4/7] update --- test/legacy_test/test_activation_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 3e1729c68a4c1..a4a6a5e6c51ad 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -464,7 +464,7 @@ def test_check_output(self): def test_check_grad(self): # TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype if self.dtype == np.complex64 or self.dtype == np.complex128: - self.check_grad(['X'], 'Out', check_prim=False) + self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True) else: self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) From aab3cd7de7f2dd934ab9a1156c139ee724d2ce86 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Wed, 13 Sep 2023 07:38:20 +0000 Subject: [PATCH 5/7] Add silu_grad --- .../fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py index 0ea367ea92fc8..004abfc736447 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -44,6 +44,7 @@ 'layer_norm', 'reshape', 'cast', + 'silu' ] vjp_interface_implementation_gen_op_list = [ "tanh", @@ -67,4 +68,5 @@ 'layer_norm', 'reshape', 'cast', + 'silu' ] From 1ecb040c16c58c7b3da000cdc4c9332c514036b3 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Wed, 13 Sep 2023 12:29:38 +0000 Subject: [PATCH 6/7] Use new scope in PIR --- test/legacy_test/op_test.py | 284 ++++++++++++++++++------------------ 1 file changed, 146 insertions(+), 138 deletions(-) diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py index 69f70ee8302b3..ac6c43bab7cdb 100644 --- a/test/legacy_test/op_test.py +++ b/test/legacy_test/op_test.py @@ -40,9 +40,9 @@ import paddle from paddle import base from paddle.autograd.ir_backward import grad as ir_grad -from paddle.base import core, unique_name +from paddle.base import Scope, core, unique_name from paddle.base.backward import append_backward -from paddle.base.executor import Executor +from paddle.base.executor import Executor, scope_guard from 
paddle.base.framework import ( OpProtoHolder, Program, @@ -1311,51 +1311,54 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): kernel_sig = self.get_kernel_signature(place) ir_program = paddle.static.Program() with paddle.static.program_guard(ir_program): - # prepare inps attributes feed - ( - static_inputs, - attrs, - input_dict, - feed, - ) = self.get_ir_input_attr_dict_and_feed(stop_gradient=True) - # prepare args - args = OpTestUtils.prepare_python_api_arguments( - self.python_api, - static_inputs, - attrs, - kernel_sig, - ) - inputs_sig, attrs_sig, outputs_sig = kernel_sig - args = OpTestUtils.assumption_assert_and_transform( - args, len(inputs_sig) - ) - ret_tuple = self.python_api(*args) - result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig) - if hasattr(self, "python_out_sig_sub_name"): - for key in self.python_out_sig_sub_name.keys(): - for i in range(len(self.python_out_sig_sub_name[key])): - result[key][0][i].name = self.python_out_sig_sub_name[ - key - ][i] - fetch_list = getattr(self, "fetch_list", []) - # if the fetch_list is customized by user, we use it directly. - # if not, fill the fetch_list by the user configured outputs in test. - - if len(fetch_list) == 0: - for var in result.items(): - if no_check_set is not None and var in no_check_set: - continue - if isinstance(var[1], list): - for v in var[1]: - fetch_list.append(v) - else: - fetch_list.append(var[1]) + with scope_guard(Scope()): + # prepare inps attributes feed + ( + static_inputs, + attrs, + input_dict, + feed, + ) = self.get_ir_input_attr_dict_and_feed(stop_gradient=True) + # prepare args + args = OpTestUtils.prepare_python_api_arguments( + self.python_api, + static_inputs, + attrs, + kernel_sig, + ) + inputs_sig, attrs_sig, outputs_sig = kernel_sig + args = OpTestUtils.assumption_assert_and_transform( + args, len(inputs_sig) + ) + ret_tuple = self.python_api(*args) + result = construct_output_dict_by_kernel_sig( + ret_tuple, outputs_sig + ) + if hasattr(self, "python_out_sig_sub_name"): + for key in self.python_out_sig_sub_name.keys(): + for i in range(len(self.python_out_sig_sub_name[key])): + result[key][0][ + i + ].name = self.python_out_sig_sub_name[key][i] + fetch_list = getattr(self, "fetch_list", []) + # if the fetch_list is customized by user, we use it directly. + # if not, fill the fetch_list by the user configured outputs in test. 
+ + if len(fetch_list) == 0: + for var in result.items(): + if no_check_set is not None and var in no_check_set: + continue + if isinstance(var[1], list): + for v in var[1]: + fetch_list.append(v) + else: + fetch_list.append(var[1]) - # executor run - executor = Executor(place) - (outs,) = executor.run( - ir_program, feed=feed, fetch_list=[fetch_list] - ) + # executor run + executor = Executor(place) + (outs,) = executor.run( + ir_program, feed=feed, fetch_list=[fetch_list] + ) return outs def _check_ir_output(self, place, program, feed_map, fetch_list, outs): @@ -3430,104 +3433,109 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): kernel_sig = self.get_kernel_signature(place) ir_program = paddle.static.Program() with paddle.static.program_guard(ir_program): - # prepare inps attributes feed - ( - static_inputs, - attrs, - inputs_dict, - feed, - ) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False) - # prepare args - args = OpTestUtils.prepare_python_api_arguments( - self.python_api, - static_inputs, - attrs, - kernel_sig, - ) - inputs_sig, attrs_sig, outputs_sig = kernel_sig - args = OpTestUtils.assumption_assert_and_transform( - args, len(inputs_sig) - ) - grad_outputs = [] - if user_defined_grad_outputs is not None: - # user_defined_grad_outputs here are numpy arrays - if not isinstance(user_defined_grad_outputs, list): - user_defined_grad_outputs = [user_defined_grad_outputs] - for grad_out_value, idx in zip( - user_defined_grad_outputs, - range(len(user_defined_grad_outputs)), - ): - grad_val = paddle.static.data( - name='val_grad_%s' % idx, - shape=grad_out_value.shape, - dtype=grad_out_value.dtype, - ) - grad_outputs.append(grad_val) - feed.update({'val_grad_%s' % idx: grad_out_value}) - # delete the inputs which no need to calculate grad - for no_grad_val in no_grad_set: - del static_inputs[no_grad_val] - - ret_tuple = self.python_api(*args) - outputs = construct_output_dict_by_kernel_sig( - ret_tuple, outputs_sig - ) - if hasattr(self, "python_out_sig_sub_name"): - for key in self.python_out_sig_sub_name.keys(): - for i in range(len(self.python_out_sig_sub_name[key])): - outputs[key][0][i].name = self.python_out_sig_sub_name[ - key - ][i] - fetch_list = getattr(self, "fetch_list", []) + with scope_guard(Scope()): + # prepare inps attributes feed + ( + static_inputs, + attrs, + inputs_dict, + feed, + ) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False) + # prepare args + args = OpTestUtils.prepare_python_api_arguments( + self.python_api, + static_inputs, + attrs, + kernel_sig, + ) + inputs_sig, attrs_sig, outputs_sig = kernel_sig + args = OpTestUtils.assumption_assert_and_transform( + args, len(inputs_sig) + ) + grad_outputs = [] + if user_defined_grad_outputs is not None: + # user_defined_grad_outputs here are numpy arrays + if not isinstance(user_defined_grad_outputs, list): + user_defined_grad_outputs = [user_defined_grad_outputs] + for grad_out_value, idx in zip( + user_defined_grad_outputs, + range(len(user_defined_grad_outputs)), + ): + grad_val = paddle.static.data( + name='val_grad_%s' % idx, + shape=grad_out_value.shape, + dtype=grad_out_value.dtype, + ) + grad_outputs.append(grad_val) + feed.update({'val_grad_%s' % idx: grad_out_value}) + # delete the inputs which no need to calculate grad + for no_grad_val in no_grad_set: + del static_inputs[no_grad_val] + + ret_tuple = self.python_api(*args) + outputs = construct_output_dict_by_kernel_sig( + ret_tuple, outputs_sig + ) + if hasattr(self, "python_out_sig_sub_name"): + for key in 
self.python_out_sig_sub_name.keys(): + for i in range(len(self.python_out_sig_sub_name[key])): + outputs[key][0][ + i + ].name = self.python_out_sig_sub_name[key][i] + fetch_list = getattr(self, "fetch_list", []) + + # cast outputs + if self.dtype == np.uint16: + for output in outputs: + outputs[output][0] = paddle.cast( + outputs[output][0], + paddle.base.core.DataType.FLOAT32, + ) - # cast outputs - if self.dtype == np.uint16: - for output in outputs: - outputs[output][0] = paddle.cast( - outputs[output][0], - paddle.base.core.DataType.FLOAT32, - ) + outputs_valid = outputs + loss_inputs = [] + for input_name in inputs_to_check: + loss_inputs.append(inputs_dict[input_name]) - outputs_valid = outputs - loss_inputs = [] - for input_name in inputs_to_check: - loss_inputs.append(inputs_dict[input_name]) + if user_defined_grad_outputs is None: + if len(outputs_valid) == 1: + for outputs_valid_key in outputs_valid: + loss = paddle.mean( + outputs_valid[outputs_valid_key][0] + ) + else: + avg_sum = [] + for cur_loss in outputs_valid: + cur_avg_loss = paddle.mean( + outputs_valid[cur_loss][0] + ) + avg_sum.append(cur_avg_loss) + loss_sum = paddle.add_n(avg_sum) + loss = paddle.scale( + loss_sum, scale=1.0 / float(len(avg_sum)) + ) - if user_defined_grad_outputs is None: - if len(outputs_valid) == 1: - for outputs_valid_key in outputs_valid: - loss = paddle.mean(outputs_valid[outputs_valid_key][0]) + grad_inputs = ir_grad( + outputs=paddle.utils.flatten(loss), + inputs=paddle.utils.flatten(loss_inputs), + grad_outputs=None, + ) else: - avg_sum = [] - for cur_loss in outputs_valid: - cur_avg_loss = paddle.mean(outputs_valid[cur_loss][0]) - avg_sum.append(cur_avg_loss) - loss_sum = paddle.add_n(avg_sum) - loss = paddle.scale( - loss_sum, scale=1.0 / float(len(avg_sum)) + grad_inputs = ir_grad( + outputs=paddle.utils.flatten(outputs), + inputs=paddle.utils.flatten(static_inputs), + grad_outputs=grad_outputs, ) - - grad_inputs = ir_grad( - outputs=paddle.utils.flatten(loss), - inputs=paddle.utils.flatten(loss_inputs), - grad_outputs=None, + fetch_list = list(grad_inputs) + + # executor run + executor = paddle.static.Executor() + outs = executor.run( + ir_program, + feed=feed, + fetch_list=fetch_list, ) - else: - grad_inputs = ir_grad( - outputs=paddle.utils.flatten(outputs), - inputs=paddle.utils.flatten(static_inputs), - grad_outputs=grad_outputs, - ) - fetch_list = list(grad_inputs) - - # executor run - executor = paddle.static.Executor() - outs = executor.run( - ir_program, - feed=feed, - fetch_list=fetch_list, - ) - return outs + return outs class OpTestTool: From 103b4a1c9a59403b21fe35f0de7f3c33faa0f57f Mon Sep 17 00:00:00 2001 From: 0x45f Date: Wed, 13 Sep 2023 12:50:08 +0000 Subject: [PATCH 7/7] Format code --- .../pir/dialect/op_generator/vjp_interface_gen_op_list.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py index 004abfc736447..02dbe01992860 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -22,6 +22,7 @@ # remove this file and support Vjp methods # code gen. + vjp_interface_declare_gen_op_list = [ "tanh", "mean", @@ -44,7 +45,7 @@ 'layer_norm', 'reshape', 'cast', - 'silu' + 'silu', ] vjp_interface_implementation_gen_op_list = [ "tanh", @@ -68,5 +69,5 @@ 'layer_norm', 'reshape', 'cast', - 'silu' + 'silu', ]
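
Taken together, the series leaves `paddle.nn.functional.silu` dispatching through `in_dynamic_or_pir_mode()` (patches 1-3), registers `silu` in the VJP op lists so `silu_grad` is generated for PIR (patches 5 and 7), and isolates each PIR op-test build-and-run in a fresh scope (patch 6). The snippet below is a minimal sketch of that build-and-run pattern on a toy silu program; it is assembled only from the imports and calls visible in the hunks above (`Scope`, `scope_guard`, `Executor`, `program_guard`), it is not taken verbatim from `op_test.py`, and whether the program actually lowers through PIR or the legacy IR depends on the interpreter flags of the local build.

    import numpy as np
    import paddle
    from paddle.base import Scope
    from paddle.base.executor import Executor, scope_guard

    paddle.enable_static()

    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        # Build and execute against a fresh Scope, mirroring patch 6/7:
        # variables created for this run cannot leak into other tests.
        with scope_guard(Scope()):
            x = paddle.static.data(name='x', shape=[4], dtype='float32')
            # silu's Python API: C++ fast path under dynamic/PIR mode,
            # legacy op builder otherwise (per in_dynamic_or_pir_mode()).
            out = paddle.nn.functional.silu(x)

            exe = Executor(paddle.CPUPlace())
            (res,) = exe.run(
                main,
                feed={'x': np.array([1.0, 2.0, 3.0, 4.0], dtype='float32')},
                fetch_list=[out],
            )

    # Expected values match the docstring shown in patch 1/7:
    # [0.73105860, 1.76159406, 2.85772228, 3.92805505]
    print(res)

Running each program against its own `Scope` is the point of the `scope_guard(Scope())` wrapper that patch 6/7 threads through both the forward-output and gradient-check paths in `op_test.py`: fetches from one test can no longer observe variables left behind by another.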