diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index ae368a7ec6cd0..8634e208e15b9 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1876,6 +1876,16 @@
     func : reciprocal_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : reduce_as_grad
+  forward : reduce_as(Tensor x, Tensor target) -> Tensor(out)
+  args : (Tensor x, Tensor target, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : reduce_as_grad
+
 - backward_op : relu6_grad
   forward : relu6 (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -2407,16 +2417,6 @@
   kernel :
     func : stanh_grad
 
-- backward_op : sum_as_grad
-  forward : sum_as(Tensor x, Tensor target) -> Tensor(out)
-  args : (Tensor x, Tensor target, Tensor out_grad)
-  output : Tensor(x_grad)
-  infer_meta :
-    func : UnchangedInferMeta
-    param : [x]
-  kernel :
-    func : sum_as_grad
-
 - backward_op : svd_grad
   forward : svd (Tensor x, bool full_matrices = false) -> Tensor(u), Tensor(s), Tensor(vh)
   args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index c9606cbeb4ab7..6ed188b88318e 100755
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2287,6 +2287,16 @@
   inplace : (x -> out)
   backward : reciprocal_grad
 
+- op : reduce_as
+  args : (Tensor x, Tensor target)
+  output : Tensor(out)
+  infer_meta :
+    func : ReduceAsInferMeta
+  kernel :
+    func : reduce_as
+    data_type : x
+  backward : reduce_as_grad
+
 - op : reindex_graph
   args : (Tensor x, Tensor neighbors, Tensor count, Tensor hashtable_value, Tensor hashtable_index)
   output : Tensor(reindex_src), Tensor(reindex_dst), Tensor(out_nodes)
@@ -2769,16 +2779,6 @@
     func : stanh
   backward : stanh_grad
 
-- op : sum_as
-  args : (Tensor x, Tensor target)
-  output : Tensor(out)
-  infer_meta :
-    func : SumAsInferMeta
-  kernel :
-    func : sum_as
-    data_type : x
-  backward : sum_as_grad
-
 - op : svd
   args : (Tensor x, bool full_matrices = false)
   output : Tensor(u), Tensor(s), Tensor(vh)
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index c24468742b558..fac05b3f608c2 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -3047,9 +3047,9 @@ void SequenceMaskInferMeta(const MetaTensor& x,
   y->set_dtype(out_dtype);
 }
 
-void SumAsInferMeta(const MetaTensor& x,
-                    const MetaTensor& target,
-                    MetaTensor* out) {
+void ReduceAsInferMeta(const MetaTensor& x,
+                       const MetaTensor& target,
+                       MetaTensor* out) {
   DataType out_dtype;
   if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32) {
     out_dtype = DataType::INT64;
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index cd0bce6c61ad8..e7c3c87de8098 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -524,9 +524,9 @@ void ShuffleBatchInferMeta(const MetaTensor& x,
 );
 
-void SumAsInferMeta(const MetaTensor& x,
-                    const MetaTensor& target,
-                    MetaTensor* out);
+void ReduceAsInferMeta(const MetaTensor& x,
+                       const MetaTensor& target,
+                       MetaTensor* out);
 
 void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
                               const MetaTensor& mask,
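Before the kernels, it helps to pin down the contract that ReduceAsInferMeta and the docstring below imply: x is summed over its leading axes until the result's shape equals target's shape. A minimal NumPy sketch of that semantics (a hypothetical reference, assuming target's shape must be a trailing suffix of x's shape, which is what phi::funcs::GetReduceDims in the kernels below appears to compute):

import numpy as np

def reduce_as_reference(x: np.ndarray, target: np.ndarray) -> np.ndarray:
    # Sum x over its leading axes so the result's shape equals target's shape.
    # Assumption: target.shape == x.shape[x.ndim - target.ndim:].
    assert x.shape[x.ndim - target.ndim:] == target.shape
    return x.sum(axis=tuple(range(x.ndim - target.ndim)))

x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
target = np.array([1, 2, 3, 4])
print(reduce_as_reference(x, target))  # [ 6  8 10 12], matching the docstring example below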
diff --git a/paddle/phi/kernels/cpu/sum_as_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_as_grad_kernel.cc
similarity index 83%
rename from paddle/phi/kernels/cpu/sum_as_grad_kernel.cc
rename to paddle/phi/kernels/cpu/reduce_as_grad_kernel.cc
index 871921727e0a5..8789a76cfd077 100644
--- a/paddle/phi/kernels/cpu/sum_as_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_as_grad_kernel.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/sum_as_kernel.h"
+#include "paddle/phi/kernels/reduce_as_kernel.h"
 
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -21,11 +21,11 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsGradKernel(const Context& dev_ctx,
-                     const DenseTensor& x,
-                     const DenseTensor& target,
-                     const DenseTensor& out_grad,
-                     DenseTensor* x_grad) {
+void ReduceAsGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& target,
+                        const DenseTensor& out_grad,
+                        DenseTensor* x_grad) {
   auto reduce_dim = phi::funcs::GetReduceDims(x, target);
   bool reduce_all = recompute_reduce_all(x, reduce_dim);
   ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
@@ -40,10 +40,10 @@
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sum_as_grad,
+PD_REGISTER_KERNEL(reduce_as_grad,
                    CPU,
                    ALL_LAYOUT,
-                   phi::SumAsGradKernel,
+                   phi::ReduceAsGradKernel,
                    bool,
                    float,
                    double,
diff --git a/paddle/phi/kernels/cpu/sum_as_kernel.cc b/paddle/phi/kernels/cpu/reduce_as_kernel.cc
similarity index 83%
rename from paddle/phi/kernels/cpu/sum_as_kernel.cc
rename to paddle/phi/kernels/cpu/reduce_as_kernel.cc
index 562e4d9a1f394..25661bd829a20 100644
--- a/paddle/phi/kernels/cpu/sum_as_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_as_kernel.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/sum_as_kernel.h"
+#include "paddle/phi/kernels/reduce_as_kernel.h"
 
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -21,10 +21,10 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SumAsKernel(const Context& dev_ctx,
-                 const DenseTensor& x,
-                 const DenseTensor& target,
-                 DenseTensor* out) {
+void ReduceAsKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& target,
+                    DenseTensor* out) {
   auto reduce_dim = phi::funcs::GetReduceDims(x, target);
   bool reduce_all = recompute_reduce_all(x, reduce_dim);
   phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
@@ -33,10 +33,10 @@
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(sum_as,
+PD_REGISTER_KERNEL(reduce_as,
                    CPU,
                    ALL_LAYOUT,
-                   phi::SumAsKernel,
+                   phi::ReduceAsKernel,
                    bool,
                    float,
                    double,
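The grad kernel is the mirror image of the forward pass: since the forward is a plain sum over the reduced axes, the backward simply broadcasts out_grad back to x's shape, which is what the sum-gradient reduce-grad call above does. A NumPy illustration of that behavior (a sketch of the math, not of the phi internals):

import numpy as np

def reduce_as_grad_reference(x: np.ndarray, out_grad: np.ndarray) -> np.ndarray:
    # d(sum)/dx copies each upstream gradient element to every input
    # position that was summed into it, i.e. a broadcast back to x.shape.
    return np.broadcast_to(out_grad, x.shape).copy()

x = np.ones((2, 4))
out_grad = np.array([1.0, 2.0, 3.0, 4.0])  # gradient w.r.t. the [4]-shaped output
print(reduce_as_grad_reference(x, out_grad))  # every row is [1. 2. 3. 4.]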
-#include "paddle/phi/kernels/sum_as_grad_kernel.h" +#include "paddle/phi/kernels/reduce_as_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -22,11 +22,11 @@ namespace phi { template -void SumAsGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& target, - const DenseTensor& out_grad, - DenseTensor* x_grad) { +void ReduceAsGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& target, + const DenseTensor& out_grad, + DenseTensor* x_grad) { auto reduce_dim = phi::funcs::GetReduceDims(x, target); bool reduce_all = recompute_reduce_all(x, reduce_dim); auto update_dims = common::vectorize(x.dims()); @@ -50,10 +50,10 @@ void SumAsGradKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(sum_as_grad, +PD_REGISTER_KERNEL(reduce_as_grad, GPU, ALL_LAYOUT, - phi::SumAsGradKernel, + phi::ReduceAsGradKernel, bool, float, double, diff --git a/paddle/phi/kernels/gpu/sum_as_kernel.cu b/paddle/phi/kernels/gpu/reduce_as_kernel.cu similarity index 82% rename from paddle/phi/kernels/gpu/sum_as_kernel.cu rename to paddle/phi/kernels/gpu/reduce_as_kernel.cu index cdab8602301e5..1555d2b59b7c4 100644 --- a/paddle/phi/kernels/gpu/sum_as_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_as_kernel.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/sum_as_kernel.h" +#include "paddle/phi/kernels/reduce_as_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -21,10 +21,10 @@ namespace phi { template -void SumAsKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& target, - DenseTensor* out) { +void ReduceAsKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& target, + DenseTensor* out) { auto reduce_dim = phi::funcs::GetReduceDims(x, target); dev_ctx.template Alloc(out); phi::SumKernel(dev_ctx, x, reduce_dim, out->type(), false, out); @@ -32,10 +32,10 @@ void SumAsKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(sum_as, +PD_REGISTER_KERNEL(reduce_as, GPU, ALL_LAYOUT, - phi::SumAsKernel, + phi::ReduceAsKernel, bool, float, double, diff --git a/paddle/phi/kernels/sum_as_kernel.h b/paddle/phi/kernels/reduce_as_as_kernel.h similarity index 84% rename from paddle/phi/kernels/sum_as_kernel.h rename to paddle/phi/kernels/reduce_as_as_kernel.h index cbc8813b95c46..ad62ddb6e0674 100644 --- a/paddle/phi/kernels/sum_as_kernel.h +++ b/paddle/phi/kernels/reduce_as_as_kernel.h @@ -22,9 +22,9 @@ namespace phi { template -void SumAsKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& target, - DenseTensor* out); +void ReduceAsKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& target, + DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/sum_as_grad_kernel.h b/paddle/phi/kernels/reduce_as_grad_kernel.h similarity index 78% rename from paddle/phi/kernels/sum_as_grad_kernel.h rename to paddle/phi/kernels/reduce_as_grad_kernel.h index 5ae7126371c24..577af8ffb7eb9 100644 --- a/paddle/phi/kernels/sum_as_grad_kernel.h +++ b/paddle/phi/kernels/reduce_as_grad_kernel.h @@ -22,10 +22,10 @@ namespace phi { template -void SumAsGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& target, - const DenseTensor& out_grad, - DenseTensor* x_grad); +void ReduceAsGradKernel(const Context& dev_ctx, + const DenseTensor& 
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 7adcff2ac86ec..9b7cd40224bab 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -474,6 +474,7 @@
     prod,
     rad2deg,
     reciprocal,
+    reduce_as,
     remainder,
     remainder_,
     renorm,
@@ -494,7 +495,6 @@
     stanh,
     subtract,
     sum,
-    sum_as,
     take,
     tan,
     tan_,
@@ -847,7 +847,7 @@
     'ones',
     'not_equal',
     'sum',
-    'sum_as',
+    'reduce_as',
     'nansum',
     'nanmean',
     'count_nonzero',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 7daa8fa530d4e..936edb9c428fb 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -357,6 +357,7 @@
     rad2deg,
     reciprocal,
     reciprocal_,
+    reduce_as,
     remainder,
     remainder_,
     renorm,
@@ -383,7 +384,6 @@
     subtract,
     subtract_,
     sum,
-    sum_as,
     take,
     tan,
     tan_,
@@ -526,7 +526,7 @@
     'square',
     'stanh',
     'sum',
-    'sum_as',
+    'reduce_as',
     'multigammaln',
     'multigammaln_',
     'nan_to_num',
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index b38b6659b288d..a3f2c087f63ea 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1576,7 +1576,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     return out
 
 
-def sum_as(x, target, name=None):
+def reduce_as(x, target, name=None):
     """
     Computes the sum of tensor elements so that the shape of the result equals the shape of ``target``.
@@ -1601,14 +1601,14 @@
             >>> target
             Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
             [1, 2, 3, 4])
 
-            >>> res = paddle.sum_as(x, target)
+            >>> res = paddle.reduce_as(x, target)
             >>> res
             Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
             [6 , 8 , 10, 12])
     """
 
     if in_dynamic_or_pir_mode():
-        return _C_ops.sum_as(x, target)
+        return _C_ops.reduce_as(x, target)
     else:
         check_variable_and_dtype(
             x,
@@ -1623,7 +1623,7 @@
                 'int32',
                 'int64',
             ],
-            'sum_as',
+            'reduce_as',
         )
         check_variable_and_dtype(
             target,
@@ -1638,13 +1638,13 @@
                 'int32',
                 'int64',
             ],
-            'sum_as',
+            'reduce_as',
         )
 
-        helper = LayerHelper('sum_as', **locals())
+        helper = LayerHelper('reduce_as', **locals())
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(
-            type='sum_as',
+            type='reduce_as',
            inputs={'x': x, 'target': target},
            outputs={'out': out},
        )
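With the Python binding renamed, call sites change one-for-one from paddle.sum_as to paddle.reduce_as; for example, reproducing the docstring example on a build that includes this patch:

import paddle

x = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
target = paddle.to_tensor([1, 2, 3, 4])
res = paddle.reduce_as(x, target)  # sums over axis 0 so res.shape == target.shape
print(res)  # Tensor([6, 8, 10, 12])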
diff --git a/test/legacy_test/test_sum_as_op.py b/test/legacy_test/test_reduce_as_op.py
similarity index 94%
rename from test/legacy_test/test_sum_as_op.py
rename to test/legacy_test/test_reduce_as_op.py
index 361625bb88fc2..53f35eba9a08b 100644
--- a/test/legacy_test/test_sum_as_op.py
+++ b/test/legacy_test/test_reduce_as_op.py
@@ -24,8 +24,8 @@
 paddle.seed(100)
 
 
-def sum_as_net(x, target):
-    return paddle.sum_as(x, target)
+def reduce_as_net(x, target):
+    return paddle.reduce_as(x, target)
 
 
 def apply_to_static(net, use_cinn, input_spec=None):
@@ -47,8 +47,8 @@ def setUp(self):
         self.init_attrs()
         self.calc_output()
 
-        self.python_api = paddle.sum_as
-        self.op_type = "sum_as"
+        self.python_api = paddle.reduce_as
+        self.op_type = "reduce_as"
         self.inputs = {'x': self.x, 'target': self.y}
         self.outputs = {'out': self.out}
         self.if_enable_cinn()
@@ -131,15 +131,15 @@ def init_attrs(self):
 
 class TestSumAsDynamicShape(unittest.TestCase):
     def setUp(self):
         np.random.seed(2023)
-        self.shape_x = [300, 200, 100]
-        self.shape_y = [200, 100]
+        self.shape_x = [300, 20, 100]
+        self.shape_y = [20, 100]
         self.dtype_x = "float32"
         self.dtype_y = "float32"
         self.init_x_shape = [None, None, 100]
         self.init_y_shape = [None, 100]
         self.x = np.random.random(self.shape_x).astype(self.dtype_x)
         self.y = np.random.random(self.shape_y).astype(self.dtype_y)
-        self.net = sum_as_net
+        self.net = reduce_as_net
         self.enable_cinn = False
         self.tol = 1e-6
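For the dynamic-shape case above, reduce_as should agree with an explicit paddle.sum over the leading axis, since the forward kernel dispatches to the sum kernel with the inferred reduce dims. A quick sanity check in that spirit (a sketch mirroring the test's shapes; the rtol here is an assumption loose enough for a float32 sum over 300 elements):

import numpy as np
import paddle

x = paddle.rand([300, 20, 100])
target = paddle.rand([20, 100])
out = paddle.reduce_as(x, target)  # [300, 20, 100] -> [20, 100], reducing axis 0
ref = paddle.sum(x, axis=0)
np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=1e-5)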