[NewIR] No.10 Migrate silu into pir (PaddlePaddle#57157)
GreatV authored Sep 14, 2023
1 parent 983a2e5 commit a1d40f3
Showing 4 changed files with 156 additions and 142 deletions.
@@ -22,6 +22,7 @@
# remove this file and support Vjp methods
# code gen.


vjp_interface_declare_gen_op_list = [
"tanh",
"mean",
@@ -44,6 +45,7 @@
'layer_norm',
'reshape',
'cast',
'silu',
]
vjp_interface_implementation_gen_op_list = [
"tanh",
@@ -67,4 +69,5 @@
'layer_norm',
'reshape',
'cast',
'silu',
]
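
The two lists above register silu with the VJP (vector-Jacobian product) interface declaration and implementation code generation. The snippet below is not part of the change; it is a small NumPy sanity check of the silu derivative that the generated backward interface ultimately dispatches to, d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    # silu(x) = x * sigmoid(x)
    return x * sigmoid(x)

def silu_grad(x):
    # analytic derivative: sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

x = np.linspace(-3.0, 3.0, 7)
eps = 1e-6
numeric = (silu(x + eps) - silu(x - eps)) / (2.0 * eps)  # central difference
assert np.allclose(numeric, silu_grad(x), atol=1e-6)
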
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/activation.py
@@ -14,7 +14,7 @@

import paddle
from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
from paddle.framework import core
from paddle.framework import core, in_dynamic_or_pir_mode
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

from ...base.data_feeder import check_dtype, check_variable_and_dtype
@@ -1053,7 +1053,7 @@ def silu(x, name=None):
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.silu(x)
else:
check_variable_and_dtype(
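
With this change, silu dispatches through _C_ops.silu in both dynamic-graph mode and the new PIR mode, falling back to the legacy static path otherwise. A minimal usage sketch, assuming an installed Paddle build running in dynamic mode; the expected values match the docstring example above:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
out = F.silu(x)            # approx. [0.7311, 1.7616, 2.8577, 3.9281]
manual = x * F.sigmoid(x)  # same formula: silu(x) = x * sigmoid(x)
print(out.numpy(), manual.numpy())
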
284 changes: 146 additions & 138 deletions test/legacy_test/op_test.py
@@ -40,9 +40,9 @@
import paddle
from paddle import base
from paddle.autograd.ir_backward import grad as ir_grad
from paddle.base import core, unique_name
from paddle.base import Scope, core, unique_name
from paddle.base.backward import append_backward
from paddle.base.executor import Executor
from paddle.base.executor import Executor, scope_guard
from paddle.base.framework import (
OpProtoHolder,
Program,
@@ -1311,51 +1311,54 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
kernel_sig = self.get_kernel_signature(place)
ir_program = paddle.static.Program()
with paddle.static.program_guard(ir_program):
# prepare inps attributes feed
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=True)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][i].name = self.python_out_sig_sub_name[
key
][i]
fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.

if len(fetch_list) == 0:
for var in result.items():
if no_check_set is not None and var in no_check_set:
continue
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])
with scope_guard(Scope()):
# prepare inps attributes feed
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=True)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.

if len(fetch_list) == 0:
for var in result.items():
if no_check_set is not None and var in no_check_set:
continue
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])

# executor run
executor = Executor(place)
(outs,) = executor.run(
ir_program, feed=feed, fetch_list=[fetch_list]
)
# executor run
executor = Executor(place)
(outs,) = executor.run(
ir_program, feed=feed, fetch_list=[fetch_list]
)
return outs

def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
@@ -3430,104 +3433,109 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
kernel_sig = self.get_kernel_signature(place)
ir_program = paddle.static.Program()
with paddle.static.program_guard(ir_program):
# prepare inps attributes feed
(
static_inputs,
attrs,
inputs_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
grad_outputs = []
if user_defined_grad_outputs is not None:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
for grad_out_value, idx in zip(
user_defined_grad_outputs,
range(len(user_defined_grad_outputs)),
):
grad_val = paddle.static.data(
name='val_grad_%s' % idx,
shape=grad_out_value.shape,
dtype=grad_out_value.dtype,
)
grad_outputs.append(grad_val)
feed.update({'val_grad_%s' % idx: grad_out_value})
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]

ret_tuple = self.python_api(*args)
outputs = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
outputs[key][0][i].name = self.python_out_sig_sub_name[
key
][i]
fetch_list = getattr(self, "fetch_list", [])
with scope_guard(Scope()):
# prepare inps attributes feed
(
static_inputs,
attrs,
inputs_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
grad_outputs = []
if user_defined_grad_outputs is not None:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
for grad_out_value, idx in zip(
user_defined_grad_outputs,
range(len(user_defined_grad_outputs)),
):
grad_val = paddle.static.data(
name='val_grad_%s' % idx,
shape=grad_out_value.shape,
dtype=grad_out_value.dtype,
)
grad_outputs.append(grad_val)
feed.update({'val_grad_%s' % idx: grad_out_value})
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]

ret_tuple = self.python_api(*args)
outputs = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
outputs[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
fetch_list = getattr(self, "fetch_list", [])

# cast outputs
if self.dtype == np.uint16:
for output in outputs:
outputs[output][0] = paddle.cast(
outputs[output][0],
paddle.base.core.DataType.FLOAT32,
)

# cast outputs
if self.dtype == np.uint16:
for output in outputs:
outputs[output][0] = paddle.cast(
outputs[output][0],
paddle.base.core.DataType.FLOAT32,
)
outputs_valid = outputs
loss_inputs = []
for input_name in inputs_to_check:
loss_inputs.append(inputs_dict[input_name])

outputs_valid = outputs
loss_inputs = []
for input_name in inputs_to_check:
loss_inputs.append(inputs_dict[input_name])
if user_defined_grad_outputs is None:
if len(outputs_valid) == 1:
for outputs_valid_key in outputs_valid:
loss = paddle.mean(
outputs_valid[outputs_valid_key][0]
)
else:
avg_sum = []
for cur_loss in outputs_valid:
cur_avg_loss = paddle.mean(
outputs_valid[cur_loss][0]
)
avg_sum.append(cur_avg_loss)
loss_sum = paddle.add_n(avg_sum)
loss = paddle.scale(
loss_sum, scale=1.0 / float(len(avg_sum))
)

if user_defined_grad_outputs is None:
if len(outputs_valid) == 1:
for outputs_valid_key in outputs_valid:
loss = paddle.mean(outputs_valid[outputs_valid_key][0])
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(loss),
inputs=paddle.utils.flatten(loss_inputs),
grad_outputs=None,
)
else:
avg_sum = []
for cur_loss in outputs_valid:
cur_avg_loss = paddle.mean(outputs_valid[cur_loss][0])
avg_sum.append(cur_avg_loss)
loss_sum = paddle.add_n(avg_sum)
loss = paddle.scale(
loss_sum, scale=1.0 / float(len(avg_sum))
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(static_inputs),
grad_outputs=grad_outputs,
)

grad_inputs = ir_grad(
outputs=paddle.utils.flatten(loss),
inputs=paddle.utils.flatten(loss_inputs),
grad_outputs=None,
fetch_list = list(grad_inputs)

# executor run
executor = paddle.static.Executor()
outs = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
)
else:
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(static_inputs),
grad_outputs=grad_outputs,
)
fetch_list = list(grad_inputs)

# executor run
executor = paddle.static.Executor()
outs = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
)
return outs
return outs


class OpTestTool:
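
The structural change in op_test.py is that both new-IR check paths now build and run the program inside a fresh Scope via scope_guard, so variables created by one check do not leak into later runs. A minimal sketch of that isolation pattern, assuming a CPU place and a trivial single-op program (the program and tensor names here are illustrative, not taken from op_test.py):

import numpy as np
import paddle
from paddle.base import Scope
from paddle.base.executor import Executor, scope_guard

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name="x", shape=[4], dtype="float32")
    y = paddle.nn.functional.silu(x)

# every variable created by this run lives in the temporary scope
with scope_guard(Scope()):
    exe = Executor(paddle.CPUPlace())
    (out,) = exe.run(
        prog,
        feed={"x": np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")},
        fetch_list=[y],
    )
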
7 changes: 5 additions & 2 deletions test/legacy_test/test_activation_op.py
@@ -458,12 +458,15 @@ def init_dtype(self):
def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
# TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True)
else:
self.check_grad(['X'], 'Out', check_prim=True)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestSilu_ZeroDim(TestSilu):
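
For context, the test_check_output and test_check_grad overrides above live in TestSilu, an OpTest subclass; check_new_ir=True routes both checks through the new-IR executor paths added in op_test.py. A hedged sketch of what such a subclass typically looks like; the class name and the shape/dtype choices below are illustrative, and the real TestSilu in this file may differ:

import numpy as np
import paddle.nn.functional as F
from op_test import OpTest  # test/legacy_test/op_test.py


class TestSiluSketch(OpTest):
    def setUp(self):
        self.op_type = "silu"
        self.python_api = F.silu  # the new-IR path executes python_api directly
        x = np.random.uniform(-1.0, 1.0, [11, 17]).astype("float64")
        self.inputs = {'X': x}
        self.outputs = {'Out': x / (1.0 + np.exp(-x))}  # x * sigmoid(x)

    def test_check_output(self):
        self.check_output(check_new_ir=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_new_ir=True)
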
