Skip to content

Commit

Permalink
add diagonal_scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGuge committed Nov 4, 2023
1 parent 348eabe commit 99ac6ee
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@
masked_fill_,
index_fill,
index_fill_,
diagonal_scatter,
)

from .tensor.math import ( # noqa: F401
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
from .manipulation import masked_fill_ # noqa: F401
from .manipulation import index_fill # noqa: F401
from .manipulation import index_fill_ # noqa: F401
from .manipulation import diagonal_scatter # noqa: F401
from .math import abs # noqa: F401
from .math import abs_ # noqa: F401
from .math import acos # noqa: F401
Expand Down Expand Up @@ -729,6 +730,7 @@
'normal_',
'index_fill',
'index_fill_',
'diagonal_scatter',
]

# this list used in math_op_patch.py for magic_method bind
Expand Down
134 changes: 134 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5575,3 +5575,137 @@ def index_fill_(x, index, axis, value, name=None):
"""
return _index_fill_impl(x, index, axis, value, True)


def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None):
"""
Embed the values of Tensor ``y`` into Tensor ``x`` along the diagonal elements
of Tensor ``x``, with respect to ``axis1`` and ``axis2``.
This function returns a tensor with fresh storage.
The argument ``offset`` controls which diagonal to consider:
If ``offset`` = 0, it is the main diagonal.
If ``offset`` > 0, it is above the main diagonal.
If ``offset`` < 0, it is below the main diagonal.
Note:
``y`` should have the same shape as paddle.diagonal(x, offset, axis1, axis2).
Args:
x (Tensor): `x`` is the original Tensor. Must be at least 2-dimensional.
y (Tensor): ``y`` is the Tensor to embed into ``x``
offset (int,optional): which diagonal to consider. Default: 0 (main diagonal).
axis1 (int,optional): first axis with respect to which to take diagonal. Default: 0.
axis2 (int,optional): second axis with respect to which to take diagonal. Default: 1.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, Tensor with diagonal embedeed with y.
Examples:
.. code-block:: python
import paddle
x = paddle.arange(6.0).reshape((2, 3))
y = paddle.ones((2,))
nx = x.diagonal_scatter(y)
print(nx.tolist()) #[[1.0, 1.0, 2.0], [3.0, 1.0, 5.0]]
"""
x_shape = x.shape
assert (
len(x_shape) >= 2
), "Tensor x must be at least 2-dimensional in diagonal_scatter"
assert axis1 < len(x_shape) and axis1 >= -len(
x_shape
), "axis1 is out of range in diagonal_scatter (expected to be in range of [-{}, {}), but got {})".format(
len(x_shape), len(x_shape), axis1
)
assert axis2 < len(x_shape) and axis2 >= -len(
x_shape
), "axis2 is out of range in diagonal_scatter (expected to be in range of [-{}, {}), but got {})".format(
len(x_shape), len(x_shape), axis2
)

axis1 %= len(x_shape)
axis2 %= len(x_shape)
assert (
axis1 != axis2
), "axis1 and axis2 should not be identical in diagonal_scatter, but received axis1 = {}, axis2 = {}".format(
axis1, axis2
)

predshape = []
for i in range(len(x_shape)):
if i != axis1 and i != axis2:
predshape.append(x_shape[i])
diaglen = min(
x_shape[axis1],
x_shape[axis1] + offset,
x_shape[axis2],
x_shape[axis2] - offset,
)
predshape.append(diaglen)
assert tuple(predshape) == tuple(
y.shape
), f"y.shape should be {tuple(predshape)}, but received {tuple(y.shape)}"
if len(y.shape) == 1:
y = y.reshape([1, -1])

if in_dynamic_mode():
return _C_ops.fill_diagonal_tensor(x, y, offset, axis1, axis2)
else:
helper = LayerHelper('diagonal_scatter', **locals())
check_variable_and_dtype(
x,
'X',
[
'float16',
'float32',
'float64',
'bfloat16',
'uint8',
'int8',
'int32',
'int64',
'bool',
'complex64',
'complex128',
],
'paddle.tensor.manipulation.diagonal_scatter',
)
check_variable_and_dtype(
y,
'Y',
[
'float16',
'float32',
'float64',
'bfloat16',
'uint8',
'int8',
'int32',
'int64',
'bool',
'complex64',
'complex128',
],
'paddle.tensor.manipulation.diagonal_scatter',
)
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='fill_diagonal_tensor',
inputs={
'X': x,
'Y': y,
},
outputs={'Out': out},
attrs={
'offset': offset,
'axis1': axis1,
'axis2': axis2,
},
)
return out

0 comments on commit 99ac6ee

Please sign in to comment.