[NewIR] No.10 Migrate silu into pir #57157

Merged · 11 commits · Sep 14, 2023
@@ -22,6 +22,7 @@
# remove this file and support Vjp methods
# code gen.


vjp_interface_declare_gen_op_list = [
"tanh",
"mean",
@@ -44,6 +45,7 @@
'layer_norm',
'reshape',
'cast',
'silu',
]
vjp_interface_implementation_gen_op_list = [
"tanh",
@@ -67,4 +69,5 @@
'layer_norm',
'reshape',
'cast',
'silu',
]
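For context, here is a minimal, purely hypothetical Python sketch (not the actual Paddle code generator) of how a membership list like the ones above typically gates code generation: only ops named in the list get a generated Vjp interface declaration, which is why this PR appends `'silu'` to both lists.

```python
# Hypothetical sketch only: illustrates how a code-gen script might consult
# an op list such as the one above. The real generator logic in Paddle may differ.
vjp_interface_declare_gen_op_list = ["tanh", "mean", "silu"]  # abridged

def should_generate_vjp_decl(op_name: str) -> bool:
    # Ops not in the list are skipped by the Vjp declaration generator.
    return op_name in vjp_interface_declare_gen_op_list

print(should_generate_vjp_decl("silu"))    # True (added by this PR)
print(should_generate_vjp_decl("conv2d"))  # False
```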
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/activation.py
@@ -14,7 +14,7 @@

import paddle
from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
from paddle.framework import core
from paddle.framework import core, in_dynamic_or_pir_mode
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

from ...base.data_feeder import check_dtype, check_variable_and_dtype
@@ -1053,7 +1053,7 @@ def silu(x, name=None):
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.silu(x)
else:
check_variable_and_dtype(
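As a quick sanity check on the values quoted in the docstring above, a minimal NumPy sketch of the SiLU definition, silu(x) = x · sigmoid(x) (nothing Paddle-specific assumed):

```python
import numpy as np

def silu_ref(x):
    # silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return x / (1.0 + np.exp(-x))

x = np.array([1.0, 2.0, 3.0, 4.0])
print(silu_ref(x))  # ≈ [0.7311, 1.7616, 2.8577, 3.9281], matching the docstring
```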
284 changes: 146 additions & 138 deletions test/legacy_test/op_test.py
@@ -40,9 +40,9 @@
import paddle
from paddle import base
from paddle.autograd.ir_backward import grad as ir_grad
from paddle.base import core, unique_name
from paddle.base import Scope, core, unique_name
from paddle.base.backward import append_backward
from paddle.base.executor import Executor
from paddle.base.executor import Executor, scope_guard
from paddle.base.framework import (
OpProtoHolder,
Program,
@@ -1311,51 +1311,54 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
kernel_sig = self.get_kernel_signature(place)
ir_program = paddle.static.Program()
with paddle.static.program_guard(ir_program):
# prepare inps attributes feed
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=True)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][i].name = self.python_out_sig_sub_name[
key
][i]
fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.

if len(fetch_list) == 0:
for var in result.items():
if no_check_set is not None and var in no_check_set:
continue
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])
with scope_guard(Scope()):
# prepare inps attributes feed
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=True)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.

if len(fetch_list) == 0:
for var in result.items():
if no_check_set is not None and var in no_check_set:
continue
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])

# executor run
executor = Executor(place)
(outs,) = executor.run(
ir_program, feed=feed, fetch_list=[fetch_list]
)
# executor run
executor = Executor(place)
(outs,) = executor.run(
ir_program, feed=feed, fetch_list=[fetch_list]
)
return outs
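The key change in this hunk is wrapping program construction and execution in `with scope_guard(Scope()):`, so each forward check runs in a fresh scope instead of leaking variables into the global one. A minimal sketch of the same pattern using the public `paddle.static` API (a CPU place is assumed here; the test harness picks the place dynamically):

```python
import numpy as np
import paddle

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name="x", shape=[4], dtype="float32")
    y = paddle.nn.functional.silu(x)

exe = paddle.static.Executor(paddle.CPUPlace())
# A fresh Scope keeps this run's variables isolated from other test cases.
with paddle.static.scope_guard(paddle.static.Scope()):
    (out,) = exe.run(
        prog,
        feed={"x": np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")},
        fetch_list=[y],
    )
print(out)
```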

def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
@@ -3430,104 +3433,109 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
kernel_sig = self.get_kernel_signature(place)
ir_program = paddle.static.Program()
with paddle.static.program_guard(ir_program):
# prepare inps attributes feed
(
static_inputs,
attrs,
inputs_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
grad_outputs = []
if user_defined_grad_outputs is not None:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
for grad_out_value, idx in zip(
user_defined_grad_outputs,
range(len(user_defined_grad_outputs)),
):
grad_val = paddle.static.data(
name='val_grad_%s' % idx,
shape=grad_out_value.shape,
dtype=grad_out_value.dtype,
)
grad_outputs.append(grad_val)
feed.update({'val_grad_%s' % idx: grad_out_value})
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]

ret_tuple = self.python_api(*args)
outputs = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
outputs[key][0][i].name = self.python_out_sig_sub_name[
key
][i]
fetch_list = getattr(self, "fetch_list", [])
with scope_guard(Scope()):
# prepare inps attributes feed
(
static_inputs,
attrs,
inputs_dict,
feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False)
# prepare args
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
static_inputs,
attrs,
kernel_sig,
)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
grad_outputs = []
if user_defined_grad_outputs is not None:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
for grad_out_value, idx in zip(
user_defined_grad_outputs,
range(len(user_defined_grad_outputs)),
):
grad_val = paddle.static.data(
name='val_grad_%s' % idx,
shape=grad_out_value.shape,
dtype=grad_out_value.dtype,
)
grad_outputs.append(grad_val)
feed.update({'val_grad_%s' % idx: grad_out_value})
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]

ret_tuple = self.python_api(*args)
outputs = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
outputs[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
fetch_list = getattr(self, "fetch_list", [])

# cast outputs
if self.dtype == np.uint16:
for output in outputs:
outputs[output][0] = paddle.cast(
outputs[output][0],
paddle.base.core.DataType.FLOAT32,
)

# cast outputs
if self.dtype == np.uint16:
for output in outputs:
outputs[output][0] = paddle.cast(
outputs[output][0],
paddle.base.core.DataType.FLOAT32,
)
outputs_valid = outputs
loss_inputs = []
for input_name in inputs_to_check:
loss_inputs.append(inputs_dict[input_name])

outputs_valid = outputs
loss_inputs = []
for input_name in inputs_to_check:
loss_inputs.append(inputs_dict[input_name])
if user_defined_grad_outputs is None:
if len(outputs_valid) == 1:
for outputs_valid_key in outputs_valid:
loss = paddle.mean(
outputs_valid[outputs_valid_key][0]
)
else:
avg_sum = []
for cur_loss in outputs_valid:
cur_avg_loss = paddle.mean(
outputs_valid[cur_loss][0]
)
avg_sum.append(cur_avg_loss)
loss_sum = paddle.add_n(avg_sum)
loss = paddle.scale(
loss_sum, scale=1.0 / float(len(avg_sum))
)

if user_defined_grad_outputs is None:
if len(outputs_valid) == 1:
for outputs_valid_key in outputs_valid:
loss = paddle.mean(outputs_valid[outputs_valid_key][0])
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(loss),
inputs=paddle.utils.flatten(loss_inputs),
grad_outputs=None,
)
else:
avg_sum = []
for cur_loss in outputs_valid:
cur_avg_loss = paddle.mean(outputs_valid[cur_loss][0])
avg_sum.append(cur_avg_loss)
loss_sum = paddle.add_n(avg_sum)
loss = paddle.scale(
loss_sum, scale=1.0 / float(len(avg_sum))
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(static_inputs),
grad_outputs=grad_outputs,
)

grad_inputs = ir_grad(
outputs=paddle.utils.flatten(loss),
inputs=paddle.utils.flatten(loss_inputs),
grad_outputs=None,
fetch_list = list(grad_inputs)

# executor run
executor = paddle.static.Executor()
outs = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
)
else:
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(static_inputs),
grad_outputs=grad_outputs,
)
fetch_list = list(grad_inputs)

# executor run
executor = paddle.static.Executor()
outs = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
)
return outs
return outs
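In the `user_defined_grad_outputs is None` branch above, the test reduces every checked output to a scalar loss with `mean` (averaging across outputs when there are several) and then differentiates that loss with respect to the inputs. A minimal dynamic-mode sketch of the same reduce-then-grad pattern (not the OpTest machinery itself):

```python
import paddle

paddle.disable_static()  # ensure dynamic mode for paddle.grad

x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0], stop_gradient=False)
out = paddle.nn.functional.silu(x)

loss = paddle.mean(out)  # scalar surrogate loss, as in the branch above
(grad_x,) = paddle.grad(outputs=[loss], inputs=[x])
print(grad_x)  # d(mean(silu(x)))/dx
```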


class OpTestTool:
7 changes: 5 additions & 2 deletions test/legacy_test/test_activation_op.py
@@ -458,12 +458,15 @@ def init_dtype(self):
def if_enable_cinn(self):
pass

def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
# TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True)
else:
self.check_grad(['X'], 'Out', check_prim=True)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestSilu_ZeroDim(TestSilu):
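The `check_grad(['X'], 'Out', ...)` calls above compare the analytic SiLU gradient against a numerical one. A minimal NumPy sketch of that comparison, using d/dx silu(x) = sigmoid(x) · (1 + x · (1 − sigmoid(x))):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

def silu_grad(x):
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

x = np.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-6
numeric = (silu(x + eps) - silu(x - eps)) / (2 * eps)  # central difference
print(np.allclose(numeric, silu_grad(x), atol=1e-5))   # True
```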