From d52efaad80e71853d528bacff496c8923620ac71 Mon Sep 17 00:00:00 2001
From: BiynXu <244524405@qq.com>
Date: Tue, 26 Jul 2022 11:05:53 +0800
Subject: [PATCH] [PHI]Add yaml and unittest for bmm op

---
 paddle/phi/api/yaml/legacy_api.yaml                | 10 ++++++++++
 paddle/phi/api/yaml/legacy_backward.yaml           |  9 +++++++++
 python/paddle/fluid/tests/unittests/test_bmm_op.py |  5 +++--
 python/paddle/tensor/linalg.py                     |  3 +++
 4 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index a7d8f5b33889e..a077c511054c1 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -326,6 +326,16 @@
   kernel :
     func : bitwise_xor
 
+# bmm
+- api : bmm
+  args : (Tensor x, Tensor y)
+  output : Tensor
+  infer_meta :
+    func : BmmInferMeta
+  kernel :
+    func : bmm
+  backward : bmm_grad
+
 # brelu
 - api : brelu
   args : (Tensor x, float t_min, float t_max)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 65952fc6806a3..11489c5f6cdbb 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -260,6 +260,15 @@
   kernel :
     func : bilinear_tensor_product_grad
 
+- backward_api : bmm_grad
+  forward : bmm (Tensor x, Tensor y) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : BmmGradInferMeta
+  kernel :
+    func : bmm_grad
+
 - backward_api : brelu_grad
   forward : brelu (Tensor x, float t_min, float t_max) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, float t_min, float t_max)
diff --git a/python/paddle/fluid/tests/unittests/test_bmm_op.py b/python/paddle/fluid/tests/unittests/test_bmm_op.py
index b9a5853c492f5..5e5c41ae88279 100644
--- a/python/paddle/fluid/tests/unittests/test_bmm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bmm_op.py
@@ -27,6 +27,7 @@ class TestBmmOp(OpTest):
 
     def setUp(self):
         self.op_type = "bmm"
+        self.python_api = paddle.tensor.bmm
         X = np.random.random((10, 3, 4)).astype("float64")
         Y = np.random.random((10, 4, 5)).astype("float64")
         self.inputs = {'X': X, 'Y': Y}
@@ -34,10 +35,10 @@ def setUp(self):
         self.outputs = {'Out': Out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_checkout_grad(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
 
 
 class API_TestBmm(unittest.TestCase):
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 7e7f95d17a38f..e7091ff329062 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1521,6 +1521,9 @@ def bmm(x, y, name=None):
             "x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}"
             .format(x_shape, y_shape))
 
+    if in_dygraph_mode():
+        return _C_ops.final_state_bmm(x, y)
+
     if paddle.in_dynamic_mode():
         return _C_ops.bmm(x, y)