Skip to content

Commit

Permalink
【PIR API adaptor No.65, 69】Migrate some ops into pir (PaddlePaddle#58698
Browse files Browse the repository at this point in the history
)
  • Loading branch information
longranger2 authored and SecretXV committed Nov 28, 2023
1 parent bf9f8b3 commit 2399d86
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
13 changes: 12 additions & 1 deletion python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,9 @@ def eye(num_rows, num_columns=None, dtype=None, name=None):
"""

def _check_attr(attr, message):
if isinstance(attr, ((Variable, core.eager.Tensor))):
if isinstance(
attr, ((Variable, core.eager.Tensor, paddle.pir.OpResult))
):
assert len(attr.shape) == 1 and attr.shape[0] in [1, -1]
elif not isinstance(attr, int) or attr < 0:
raise TypeError(f"{message} should be a non-negative int.")
Expand Down Expand Up @@ -2198,6 +2200,15 @@ def empty_like(x, dtype=None, name=None):
)
out.stop_gradient = True
return out
elif in_pir_mode():
shape = paddle.shape(x)
out = _C_ops.empty(
shape,
convert_np_dtype_to_dtype_(dtype),
_current_expected_place(),
)
out.stop_gradient = True
return out
else:
helper = LayerHelper("empty_like", **locals())
check_variable_and_dtype(
Expand Down
33 changes: 18 additions & 15 deletions test/legacy_test/test_empty_like_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import paddle
from paddle.base import core
from paddle.base.data_feeder import convert_dtype
from paddle.static import Program, program_guard
from paddle.pir_utils import test_with_pir_api


class TestEmptyLikeAPICommon(unittest.TestCase):
Expand Down Expand Up @@ -163,32 +163,33 @@ class TestEmptyLikeAPI_Static(TestEmptyLikeAPICommon):
def setUp(self):
self.init_config()

@test_with_pir_api
def test_static_graph(self):
paddle.enable_static()
train_program = Program()
startup_program = Program()
train_program = paddle.static.Program()
startup_program = paddle.static.Program()

with program_guard(train_program, startup_program):
with paddle.static.program_guard(train_program, startup_program):
x = np.random.random(self.x_shape).astype(self.dtype)
data_x = paddle.static.data(
'x', shape=self.data_x_shape, dtype=self.dtype
)

out = paddle.empty_like(data_x)

place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)
exe = paddle.static.Executor(place)
res = exe.run(train_program, feed={'x': x}, fetch_list=[out])
place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)
exe = paddle.static.Executor(place)
res = exe.run(train_program, feed={'x': x}, fetch_list=[out])

self.dst_dtype = self.dtype
self.dst_shape = x.shape
self.__check_out__(res[0])
self.dst_dtype = self.dtype
self.dst_shape = x.shape
self.__check_out__(res[0])

paddle.disable_static()
paddle.disable_static()

def init_config(self):
self.x_shape = (200, 3)
Expand All @@ -212,6 +213,7 @@ def init_config(self):
self.data_x_shape = [200, 3]
self.dtype = 'float16'

@test_with_pir_api
def test_static_graph(self):
paddle.enable_static()
if paddle.base.core.is_compiled_with_cuda():
Expand Down Expand Up @@ -245,6 +247,7 @@ def init_config(self):
self.data_x_shape = [200, 3]
self.dtype = 'uint16'

@test_with_pir_api
def test_static_graph(self):
paddle.enable_static()
if paddle.base.core.is_compiled_with_cuda():
Expand Down
13 changes: 8 additions & 5 deletions test/legacy_test/test_eye_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle import base
from paddle.base import core, framework
from paddle.base.framework import Program, program_guard
from paddle.pir_utils import test_with_pir_api


class TestEyeOp(OpTest):
Expand All @@ -46,7 +47,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def init_dtype(self):
self.dtype = np.int32
Expand All @@ -69,7 +70,7 @@ def setUp(self):
self.outputs = {'Out': np.eye(50, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestEyeOp2(OpTest):
Expand All @@ -85,11 +86,12 @@ def setUp(self):
self.outputs = {'Out': np.eye(99, 1, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class API_TestTensorEye(unittest.TestCase):
def test_out(self):
@test_with_pir_api
def test_static_out(self):
with paddle.static.program_guard(paddle.static.Program()):
data = paddle.eye(10)
place = base.CPUPlace()
Expand All @@ -114,6 +116,7 @@ def test_out(self):
expected_result = np.eye(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True)

def test_dynamic_out(self):
paddle.disable_static()
out = paddle.eye(10, dtype="int64")
expected_result = np.eye(10, dtype="int64")
Expand Down Expand Up @@ -215,7 +218,7 @@ def setUp(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)


if __name__ == "__main__":
Expand Down

0 comments on commit 2399d86

Please sign in to comment.