Skip to content

Commit

Permalink
【PIR API adaptor No.105】 Migrate paddle.i0 into pir (PaddlePaddle#58603)
Browse files Browse the repository at this point in the history
* ✨ Refactor: enable new ir op and added new ir test

* Update python/paddle/tensor/math.py

Co-authored-by: Lu Qi <61354321+MarioLulab@users.noreply.github.com>

* ♻️ Refactor: updated test

* 🎨 Fix: updated code style

---------

Co-authored-by: Lu Qi <61354321+MarioLulab@users.noreply.github.com>
  • Loading branch information
2 people authored and zeroRains committed Nov 8, 2023
1 parent a4ec0fe commit 6872314
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -6664,7 +6664,7 @@ def i0(x, name=None):
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.99999994 , 1.26606596 , 2.27958512 , 4.88079262 , 11.30192089])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.i0(x)
else:
check_variable_and_dtype(x, "x", ["float32", "float64"], "i0")
Expand Down
9 changes: 6 additions & 3 deletions test/legacy_test/test_i0_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

np.random.seed(100)
paddle.seed(100)
Expand All @@ -40,10 +41,12 @@ class TestI0API(unittest.TestCase):

def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.out_ref = output_i0(self.x)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

@test_with_pir_api
def test_api_static(self):
def run(place):
paddle.enable_static()
Expand All @@ -58,8 +61,7 @@ def run(place):
feed={"x": self.x},
fetch_list=[out],
)
out_ref = output_i0(self.x)
np.testing.assert_allclose(res[0], out_ref, rtol=1e-5)
np.testing.assert_allclose(res[0], self.out_ref, rtol=1e-5)
paddle.disable_static()

for place in self.place:
Expand Down Expand Up @@ -130,13 +132,14 @@ def init_config(self):
self.target = output_i0(self.inputs['x'])

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['x'],
'out',
user_defined_grads=[ref_i0_grad(self.case, 1 / self.case.size)],
check_pir=True,
)


Expand Down

0 comments on commit 6872314

Please sign in to comment.