Skip to content

Commit

Permalink
【PIR API adaptor No.148】outer (PaddlePaddle#58927)
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoshe authored and SecretXV committed Nov 28, 2023
1 parent b192568 commit 4c4a737
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2557,7 +2557,7 @@ def outer(x, y, name=None):
nx = x.reshape((-1, 1))
ny = y.reshape((1, -1))

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.matmul(nx, ny, False, False)
else:

Expand Down
17 changes: 12 additions & 5 deletions test/legacy_test/test_outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import numpy as np

import paddle
from paddle.static import Program, program_guard
from paddle.pir_utils import test_with_pir_api


class TestMultiplyApi(unittest.TestCase):
def _run_static_graph_case(self, x_data, y_data):
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
paddle.enable_static()
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype
Expand Down Expand Up @@ -53,7 +55,8 @@ def _run_dynamic_graph_case(self, x_data, y_data):
res = paddle.outer(x, y)
return res.numpy()

def test_multiply(self):
@test_with_pir_api
def test_multiply_static(self):
np.random.seed(7)

# test static computation graph: 3-d array
Expand Down Expand Up @@ -86,6 +89,7 @@ def test_multiply(self):
res = self._run_static_graph_case(x_data, y_data)
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)

def test_multiply_dynamic(self):
# test dynamic computation graph: 3-d array
x_data = np.random.rand(5, 10, 10).astype(np.float64)
y_data = np.random.rand(2, 10).astype(np.float64)
Expand Down Expand Up @@ -138,14 +142,17 @@ def test_multiply(self):


class TestMultiplyError(unittest.TestCase):
def test_errors(self):
def test_errors_static(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
self.assertRaises(TypeError, paddle.outer, x, y)

def test_errors_dynamic(self):
np.random.seed(7)

# test dynamic computation graph: dtype must be Tensor type
Expand Down

0 comments on commit 4c4a737

Please sign in to comment.