[NewIR] No.14 Migrate paddle.pow into pir #57297

Merged: 5 commits, Sep 18, 2023
8 changes: 5 additions & 3 deletions python/paddle/tensor/math.py
@@ -498,15 +498,17 @@ def pow(x, y, name=None):
             [1., 4., 9.])

     """
+
     # in dynamic graph mode
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         if isinstance(y, (int, float)):
             return _C_ops.pow(x, y)
-        elif isinstance(y, (paddle.Tensor, Variable)):
+        elif isinstance(y, (paddle.Tensor, Variable, paddle.ir.OpResult)):
             return _C_ops.elementwise_pow(x, y)
         else:
             raise TypeError(
-                'y must be scalar or tensor type, but received: %s ' % (y.dtype)
+                'y must be scalar, Tensor (in dygraph mode) or OpResult (in pir mode), but received: %s'
+                % (y.dtype)
             )
     else:
         # in static graph mode
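
For reference, a minimal usage sketch of the two dispatch branches above, written
for dygraph mode (under pir mode the same calls are expected to route through the
new IR, which is what this PR enables):

    import paddle

    x = paddle.to_tensor([1.0, 2.0, 3.0])

    # int/float exponent dispatches to _C_ops.pow
    print(paddle.pow(x, 2.0))   # [1., 4., 9.]

    # Tensor exponent dispatches to _C_ops.elementwise_pow
    y = paddle.to_tensor([2.0, 2.0, 2.0])
    print(paddle.pow(x, y))     # [1., 4., 9.]
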
4 changes: 2 additions & 2 deletions test/legacy_test/test_activation_op.py
@@ -3534,12 +3534,12 @@ def if_enable_cinn(self):
         pass

     def test_check_output(self):
-        self.check_output(check_prim=True)
+        self.check_output(check_prim=True, check_new_ir=True)

     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


 class TestPow_ZeroDim(TestPow):
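
The pattern behind these test changes, sketched as a standalone example: check_output
and check_grad gain a check_new_ir=True flag so each check also runs under the new IR
(pir) executor. The OpTest import path and the setUp details below vary across Paddle
versions and are illustrative, not verbatim from this PR:

    import numpy as np
    import paddle
    from op_test import OpTest  # import path differs across Paddle versions


    class TestPowNewIR(OpTest):  # hypothetical subclass for illustration
        def setUp(self):
            self.op_type = 'pow'
            self.python_api = paddle.pow
            x = np.random.uniform(1, 2, [11, 17]).astype('float64')
            self.inputs = {'X': x}
            self.attrs = {'factor': 3.0}
            self.outputs = {'Out': np.power(x, 3.0)}

        def test_check_output(self):
            # check_new_ir=True re-runs the output check under pir
            self.check_output(check_prim=True, check_new_ir=True)

        def test_check_grad(self):
            self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)
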
17 changes: 9 additions & 8 deletions test/legacy_test/test_elementwise_pow_op.py
@@ -44,15 +44,17 @@ def test_check_output(self):
         if hasattr(self, 'attrs'):
             self.check_output(check_dygraph=False)
         else:
-            self.check_output()
+            self.check_output(check_new_ir=True)

[Inline review comment from a Contributor, on the check_output change above]: This
changes the unit-test base class, which may surface errors for complex dtypes; you
could skip the complex-dtype unit tests and see whether the rest of the suite passes.
(A sketch of such a guard follows this hunk.)


     def test_check_grad_normal(self):
         if hasattr(self, 'attrs'):
             self.check_grad(
                 ['X', 'Y'], 'Out', check_prim=True, check_dygraph=False
             )
         else:
-            self.check_grad(['X', 'Y'], 'Out', check_prim=True)
+            self.check_grad(
+                ['X', 'Y'], 'Out', check_prim=True, check_new_ir=True
+            )


 class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp):
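
A minimal sketch of the guard the reviewer suggests; the subclass name is
hypothetical, and it assumes the base test records its dtype in self.dtype:

    import unittest

    import numpy as np


    class TestElementwisePowOpComplexGuard(TestElementwisePowOp):  # hypothetical
        def test_check_output(self):
            # Skip the new-IR check for complex dtypes until the base-class
            # change is verified; otherwise fall back to the normal check.
            if np.issubdtype(np.dtype(self.dtype), np.complexfloating):
                raise unittest.SkipTest('complex dtypes not yet verified under pir')
            super().test_check_output()
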
@@ -196,7 +198,7 @@ def test_check_output(self):
         if hasattr(self, 'attrs'):
             self.check_output(check_dygraph=False)
         else:
-            self.check_output()
+            self.check_output(check_new_ir=True)


 class TestElementwisePowGradOpInt(unittest.TestCase):
@@ -252,7 +254,7 @@ def test_check_output(self):
         if hasattr(self, 'attrs'):
             self.check_output(check_dygraph=False)
         else:
-            self.check_output()
+            self.check_output(check_new_ir=True)

     def test_check_grad(self):
         self.check_grad(
@@ -262,6 +264,7 @@ def test_check_grad(self):
                 self.inputs['X'], self.inputs['Y'], 1 / self.inputs['X'].size
             ),
             check_prim=True,
+            check_new_ir=True,
         )


@@ -287,10 +290,7 @@ def setUp(self):
         self.outputs = {'Out': convert_float_to_uint16(out)}

     def test_check_output(self):
-        if hasattr(self, 'attrs'):
-            self.check_output()
-        else:
-            self.check_output()
+        self.check_output(check_new_ir=True)

     def test_check_grad(self):
         self.check_grad(['X', 'Y'], 'Out')
@@ -301,6 +301,7 @@ def test_check_grad(self):
             'Out',
             check_prim=True,
             only_check_prim=True,
+            check_new_ir=True,
         )