From 99ac6ee12acbcb8ef29e70a7c21dbc2475d61517 Mon Sep 17 00:00:00 2001 From: wucc <77946882+DanGuge@users.noreply.github.com> Date: Sat, 4 Nov 2023 22:21:43 +0800 Subject: [PATCH] add diagonal_scatter --- python/paddle/__init__.py | 1 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/manipulation.py | 134 +++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 9c6484a1d46117..a7b440abe4fade 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -256,6 +256,7 @@ masked_fill_, index_fill, index_fill_, + diagonal_scatter, ) from .tensor.math import ( # noqa: F401 diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 84c28ce58dca89..d6ff86403613c4 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -170,6 +170,7 @@ from .manipulation import masked_fill_ # noqa: F401 from .manipulation import index_fill # noqa: F401 from .manipulation import index_fill_ # noqa: F401 +from .manipulation import diagonal_scatter # noqa: F401 from .math import abs # noqa: F401 from .math import abs_ # noqa: F401 from .math import acos # noqa: F401 @@ -729,6 +730,7 @@ 'normal_', 'index_fill', 'index_fill_', + 'diagonal_scatter', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 40856399238ae2..5d14b18ba7963e 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5575,3 +5575,137 @@ def index_fill_(x, index, axis, value, name=None): """ return _index_fill_impl(x, index, axis, value, True) + + +def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None): + """ + Embed the values of Tensor ``y`` into Tensor ``x`` along the diagonal elements + of Tensor ``x``, with respect to ``axis1`` and ``axis2``. + + This function returns a tensor with fresh storage. + + The argument ``offset`` controls which diagonal to consider: + If ``offset`` = 0, it is the main diagonal. + If ``offset`` > 0, it is above the main diagonal. + If ``offset`` < 0, it is below the main diagonal. + + Note: + ``y`` should have the same shape as paddle.diagonal(x, offset, axis1, axis2). + + Args: + x (Tensor): `x`` is the original Tensor. Must be at least 2-dimensional. + y (Tensor): ``y`` is the Tensor to embed into ``x`` + offset (int,optional): which diagonal to consider. Default: 0 (main diagonal). + axis1 (int,optional): first axis with respect to which to take diagonal. Default: 0. + axis2 (int,optional): second axis with respect to which to take diagonal. Default: 1. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, Tensor with diagonal embedeed with y. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.arange(6.0).reshape((2, 3)) + y = paddle.ones((2,)) + nx = x.diagonal_scatter(y) + print(nx.tolist()) #[[1.0, 1.0, 2.0], [3.0, 1.0, 5.0]] + + """ + x_shape = x.shape + assert ( + len(x_shape) >= 2 + ), "Tensor x must be at least 2-dimensional in diagonal_scatter" + assert axis1 < len(x_shape) and axis1 >= -len( + x_shape + ), "axis1 is out of range in diagonal_scatter (expected to be in range of [-{}, {}), but got {})".format( + len(x_shape), len(x_shape), axis1 + ) + assert axis2 < len(x_shape) and axis2 >= -len( + x_shape + ), "axis2 is out of range in diagonal_scatter (expected to be in range of [-{}, {}), but got {})".format( + len(x_shape), len(x_shape), axis2 + ) + + axis1 %= len(x_shape) + axis2 %= len(x_shape) + assert ( + axis1 != axis2 + ), "axis1 and axis2 should not be identical in diagonal_scatter, but received axis1 = {}, axis2 = {}".format( + axis1, axis2 + ) + + predshape = [] + for i in range(len(x_shape)): + if i != axis1 and i != axis2: + predshape.append(x_shape[i]) + diaglen = min( + x_shape[axis1], + x_shape[axis1] + offset, + x_shape[axis2], + x_shape[axis2] - offset, + ) + predshape.append(diaglen) + assert tuple(predshape) == tuple( + y.shape + ), f"y.shape should be {tuple(predshape)}, but received {tuple(y.shape)}" + if len(y.shape) == 1: + y = y.reshape([1, -1]) + + if in_dynamic_mode(): + return _C_ops.fill_diagonal_tensor(x, y, offset, axis1, axis2) + else: + helper = LayerHelper('diagonal_scatter', **locals()) + check_variable_and_dtype( + x, + 'X', + [ + 'float16', + 'float32', + 'float64', + 'bfloat16', + 'uint8', + 'int8', + 'int32', + 'int64', + 'bool', + 'complex64', + 'complex128', + ], + 'paddle.tensor.manipulation.diagonal_scatter', + ) + check_variable_and_dtype( + y, + 'Y', + [ + 'float16', + 'float32', + 'float64', + 'bfloat16', + 'uint8', + 'int8', + 'int32', + 'int64', + 'bool', + 'complex64', + 'complex128', + ], + 'paddle.tensor.manipulation.diagonal_scatter', + ) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='fill_diagonal_tensor', + inputs={ + 'X': x, + 'Y': y, + }, + outputs={'Out': out}, + attrs={ + 'offset': offset, + 'axis1': axis1, + 'axis2': axis2, + }, + ) + return out