Skip to content

Commit 76fe32b

Browse files
committed
bmm api sink into C++
1 parent 9920f83 commit 76fe32b

File tree

4 files changed

+66
-79
lines changed

4 files changed

+66
-79
lines changed

paddle/phi/ops/yaml/ops.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,12 @@
778778

779779
- op : bmm
780780
args : (Tensor x, Tensor y)
781-
output : Tensor
781+
python_api :
782+
name : [paddle.bmm, paddle.Tensor.bmm]
783+
args_alias:
784+
x : [input]
785+
y : [mat2]
786+
output : Tensor(out)
782787
infer_meta :
783788
func : BmmInferMeta
784789
kernel :

python/paddle/_paddle_docs.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,56 @@ def amax(
326326
) -> Tensor
327327
""",
328328
)
329+
330+
331+
add_doc_and_signature(
332+
"bmm",
333+
"""
334+
Applies batched matrix multiplication to two tensors.
335+
336+
Both of the two input tensors must be three-dimensional and share the same batch size.
337+
338+
If x is a (b, m, k) tensor, y is a (b, k, n) tensor, the output will be a (b, m, n) tensor.
339+
340+
Args:
341+
x (Tensor): The input Tensor.
342+
y (Tensor): The input Tensor.
343+
name (str|None): A name for this layer(optional). If set None, the layer
344+
will be named automatically. Default: None.
345+
out(Tensor, optional): The output tensor.
346+
347+
Returns:
348+
Tensor: The product Tensor.
349+
350+
Examples:
351+
.. code-block:: python
352+
353+
>>> import paddle
354+
355+
>>> # In imperative mode:
356+
>>> # size x: (2, 2, 3) and y: (2, 3, 2)
357+
>>> x = paddle.to_tensor([[[1.0, 1.0, 1.0],
358+
... [2.0, 2.0, 2.0]],
359+
... [[3.0, 3.0, 3.0],
360+
... [4.0, 4.0, 4.0]]])
361+
>>> y = paddle.to_tensor([[[1.0, 1.0],[2.0, 2.0],[3.0, 3.0]],
362+
... [[4.0, 4.0],[5.0, 5.0],[6.0, 6.0]]])
363+
>>> out = paddle.bmm(x, y)
364+
>>> print(out)
365+
Tensor(shape=[2, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
366+
[[[6. , 6. ],
367+
[12., 12.]],
368+
[[45., 45.],
369+
[60., 60.]]])
370+
371+
""",
372+
"""
373+
def bmm(
374+
x: Tensor,
375+
y: Tensor,
376+
name: str | None = None,
377+
*,
378+
out: Tensor | None = None,
379+
) -> Tensor
380+
""",
381+
)

python/paddle/tensor/linalg.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import paddle
2323
from paddle import _C_ops
24+
from paddle._C_ops import bmm # noqa: F401
2425
from paddle.base.libpaddle import DataType
2526
from paddle.common_ops_import import VarDesc
2627
from paddle.tensor.math import broadcast_shape
@@ -2539,78 +2540,6 @@ def matrix_rank(
25392540
return out
25402541

25412542

2542-
@ParamAliasDecorator({"x": ["input"], "y": ["mat2"]})
2543-
def bmm(
2544-
x: Tensor,
2545-
y: Tensor,
2546-
*,
2547-
out: paddle.Tensor | None = None,
2548-
name: str | None = None,
2549-
) -> Tensor:
2550-
"""
2551-
Applies batched matrix multiplication to two tensors.
2552-
2553-
Both of the two input tensors must be three-dimensional and share the same batch size.
2554-
2555-
If x is a (b, m, k) tensor, y is a (b, k, n) tensor, the output will be a (b, m, n) tensor.
2556-
2557-
Args:
2558-
x (Tensor): The input Tensor.
2559-
y (Tensor): The input Tensor.
2560-
out(Tensor, optional): The output tensor.
2561-
name (str|None): A name for this layer(optional). If set None, the layer
2562-
will be named automatically. Default: None.
2563-
2564-
Returns:
2565-
Tensor: The product Tensor.
2566-
2567-
Examples:
2568-
.. code-block:: python
2569-
2570-
>>> import paddle
2571-
2572-
>>> # In imperative mode:
2573-
>>> # size x: (2, 2, 3) and y: (2, 3, 2)
2574-
>>> x = paddle.to_tensor([[[1.0, 1.0, 1.0],
2575-
... [2.0, 2.0, 2.0]],
2576-
... [[3.0, 3.0, 3.0],
2577-
... [4.0, 4.0, 4.0]]])
2578-
>>> y = paddle.to_tensor([[[1.0, 1.0],[2.0, 2.0],[3.0, 3.0]],
2579-
... [[4.0, 4.0],[5.0, 5.0],[6.0, 6.0]]])
2580-
>>> out = paddle.bmm(x, y)
2581-
>>> print(out)
2582-
Tensor(shape=[2, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
2583-
[[[6. , 6. ],
2584-
[12., 12.]],
2585-
[[45., 45.],
2586-
[60., 60.]]])
2587-
2588-
"""
2589-
if in_dynamic_or_pir_mode():
2590-
return _C_ops.bmm(x, y, out=out)
2591-
else:
2592-
x_shape = x.shape
2593-
y_shape = y.shape
2594-
if not len(x_shape) == len(y_shape) == 3:
2595-
raise ValueError(
2596-
f"x and y should be 3-dimensional. But received x's dimension: {x_shape}, y's dimension: {y_shape}"
2597-
)
2598-
if x_shape[2] != -1 and y_shape[1] != -1 and x_shape[2] != y_shape[1]:
2599-
raise ValueError(
2600-
f"x's width must be equal with y's height. But received x's shape: {x_shape}, y's shape: {y_shape}"
2601-
)
2602-
if x_shape[0] != -1 and y_shape[0] != -1 and x_shape[0] != y_shape[0]:
2603-
raise ValueError(
2604-
f"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {x_shape}, y's shape: {y_shape}"
2605-
)
2606-
helper = LayerHelper('bmm', **locals())
2607-
out = helper.create_variable_for_type_inference(dtype=x.dtype)
2608-
helper.append_op(
2609-
type='bmm', inputs={'X': x, 'Y': y}, outputs={'Out': out}
2610-
)
2611-
return out
2612-
2613-
26142543
def histogram(
26152544
input: Tensor,
26162545
bins: int = 100,

test/legacy_test/test_bmm_op.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class TestBmmOp(OpTest):
2626
def setUp(self):
2727
self.op_type = "bmm"
2828
self.prim_op_type = "comp"
29-
self.python_api = paddle.tensor.bmm
30-
self.public_python_api = paddle.tensor.bmm
29+
self.python_api = paddle.Tensor.bmm
30+
self.public_python_api = paddle.Tensor.bmm
3131
X = np.random.random((10, 3, 4)).astype("float64")
3232
Y = np.random.random((10, 4, 5)).astype("float64")
3333
self.inputs = {'X': X, 'Y': Y}
@@ -46,8 +46,8 @@ def setUp(self):
4646
self.op_type = "bmm"
4747
self.prim_op_type = "comp"
4848
self.dtype = np.float16
49-
self.python_api = paddle.tensor.bmm
50-
self.public_python_api = paddle.tensor.bmm
49+
self.python_api = paddle.Tensor.bmm
50+
self.public_python_api = paddle.Tensor.bmm
5151
X = np.random.random((10, 3, 4)).astype("float16")
5252
Y = np.random.random((10, 4, 5)).astype("float16")
5353
self.inputs = {'X': X, 'Y': Y}
@@ -71,8 +71,8 @@ def setUp(self):
7171
self.op_type = "bmm"
7272
self.prim_op_type = "comp"
7373
self.dtype = np.uint16
74-
self.python_api = paddle.tensor.bmm
75-
self.public_python_api = paddle.tensor.bmm
74+
self.python_api = paddle.Tensor.bmm
75+
self.public_python_api = paddle.Tensor.bmm
7676
X = np.random.random((10, 3, 4)).astype("float32")
7777
Y = np.random.random((10, 4, 5)).astype("float32")
7878
self.inputs = {'X': X, 'Y': Y}

0 commit comments

Comments
 (0)