diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 9830d7ae3a7a4..3e85e17c2efd7 100755
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2288,6 +2288,16 @@
   inplace : (x -> out)
   backward : reciprocal_grad
 
+- op : reduce_as
+  args : (Tensor x, Tensor target)
+  output : Tensor(out)
+  infer_meta :
+    func : ReduceAsInferMeta
+  kernel :
+    func : reduce_as
+    data_type : x
+  backward : reduce_as_grad
+
 - op : reindex_graph
   args : (Tensor x, Tensor neighbors, Tensor count, Tensor hashtable_value, Tensor hashtable_index)
   output : Tensor(reindex_src), Tensor(reindex_dst), Tensor(out_nodes)
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 63d1d1c9b32d0..fac05b3f608c2 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -3047,6 +3047,20 @@ void SequenceMaskInferMeta(const MetaTensor& x,
   y->set_dtype(out_dtype);
 }
 
+void ReduceAsInferMeta(const MetaTensor& x,
+                       const MetaTensor& target,
+                       MetaTensor* out) {
+  DataType out_dtype;
+  if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32) {
+    out_dtype = DataType::INT64;
+  } else {
+    out_dtype = x.dtype();
+  }
+  out->set_dtype(out_dtype);
+  out->set_dims(target.dims());
+  out->set_layout(x.layout());
+}
+
 void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
                               const MetaTensor& mask,
                               MetaTensor* out) {
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 77bc925197013..e7c3c87de8098 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -524,6 +524,10 @@ void ShuffleBatchInferMeta(const MetaTensor& x,
                            );
 
+void ReduceAsInferMeta(const MetaTensor& x,
+                       const MetaTensor& target,
+                       MetaTensor* out);
+
 void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
                               const MetaTensor& mask,
                               MetaTensor* out);
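Note: ReduceAsInferMeta promotes bool and int32 inputs to an int64 output, matching the accumulation behavior of paddle.sum. A minimal Python mirror of the rule (hypothetical helper name, for illustration only):

    def infer_reduce_as_dtype(x_dtype):
        # bool/int32 sums accumulate in int64, as paddle.sum does
        if x_dtype in ('bool', 'int32'):
            return 'int64'
        return x_dtype

    assert infer_reduce_as_dtype('int32') == 'int64'
    assert infer_reduce_as_dtype('float32') == 'float32'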
diff --git a/paddle/phi/kernels/cpu/reduce_as_kernel.cc b/paddle/phi/kernels/cpu/reduce_as_kernel.cc
new file mode 100644
index 0000000000000..25661bd829a20
--- /dev/null
+++ b/paddle/phi/kernels/cpu/reduce_as_kernel.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_as_kernel.h"
+
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cpu/reduce.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceAsKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& target,
+                    DenseTensor* out) {
+  auto reduce_dim = phi::funcs::GetReduceDims(x, target);
+  bool reduce_all = recompute_reduce_all(x, reduce_dim);
+  phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
+      dev_ctx, x, reduce_all, reduce_dim, false, out->type(), out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(reduce_as,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ReduceAsKernel,
+                   bool,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   int16_t,
+                   int,
+                   int64_t,
+                   uint8_t,
+                   int8_t) {}
diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h
index 45a1024339ba3..c998b7f484fa4 100644
--- a/paddle/phi/kernels/funcs/common_shape.h
+++ b/paddle/phi/kernels/funcs/common_shape.h
@@ -295,5 +295,37 @@ inline void FCOutputSize(const DDim &in_dims,
   out_dims.push_back(w_dims1);
 }
 
+inline std::vector<int64_t> GetReduceDims(const DenseTensor &in,
+                                          const DenseTensor &out) {
+  std::vector<int64_t> reduce_dims;
+  auto in_dims = in.dims();
+  auto out_dims = out.dims();
+  int diff = in_dims.size() - out_dims.size();
+  for (int i = 0; i < diff; ++i) {
+    reduce_dims.push_back(i);
+  }
+  for (int i = 0; i < out_dims.size(); ++i) {
+    if (out_dims[i] == 1 && in_dims[i + diff] != 1) {
+      reduce_dims.push_back(i + diff);
+    } else {
+      PADDLE_ENFORCE_EQ(
+          in_dims[i + diff],
+          out_dims[i],
+          phi::errors::InvalidArgument(
+              "ReduceDims dimension mismatch. Operands could not be "
+              "broadcast together with the shape of in_dims = [%s] and "
+              "the shape of out_dims = [%s]. Received [%d] in X is not "
+              "equal to [%d] in Y at i:%d.",
+              in_dims,
+              out_dims,
+              in_dims[i + diff],
+              out_dims[i],
+              i));
+    }
+  }
+  return reduce_dims;
+}
+
 }  // namespace funcs
 }  // namespace phi
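For reference, the axis-selection logic in GetReduceDims is easy to restate in Python — a minimal sketch with a hypothetical helper name, keeping the same two passes (leading axes absent from the target shape, then size-1 axes that must be collapsed):

    def get_reduce_dims(in_shape, out_shape):
        diff = len(in_shape) - len(out_shape)
        reduce_dims = list(range(diff))  # leading axes absent from out_shape
        for i, d in enumerate(out_shape):
            if d == 1 and in_shape[i + diff] != 1:
                reduce_dims.append(i + diff)  # size-1 axis: reduce it away
            else:
                # remaining axes must match exactly (the PADDLE_ENFORCE_EQ above)
                assert in_shape[i + diff] == d, "dimension mismatch"
        return reduce_dims

    assert get_reduce_dims([10, 10, 6], [10, 6]) == [0]
    assert get_reduce_dims([2, 3, 4], [2, 1, 4]) == [1]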
diff --git a/paddle/phi/kernels/gpu/reduce_as_kernel.cu b/paddle/phi/kernels/gpu/reduce_as_kernel.cu
new file mode 100644
index 0000000000000..1555d2b59b7c4
--- /dev/null
+++ b/paddle/phi/kernels/gpu/reduce_as_kernel.cu
@@ -0,0 +1,48 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_as_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceAsKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& target,
+                    DenseTensor* out) {
+  auto reduce_dim = phi::funcs::GetReduceDims(x, target);
+  dev_ctx.template Alloc<T>(out);
+  phi::SumKernel<T, Context>(dev_ctx, x, reduce_dim, out->type(), false, out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(reduce_as,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceAsKernel,
+                   bool,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   int16_t,
+                   int,
+                   int64_t,
+                   uint8_t,
+                   int8_t) {}
diff --git a/paddle/phi/kernels/reduce_as_kernel.h b/paddle/phi/kernels/reduce_as_kernel.h
new file mode 100644
index 0000000000000..ad62ddb6e0674
--- /dev/null
+++ b/paddle/phi/kernels/reduce_as_kernel.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/reduce_functor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceAsKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& target,
+                    DenseTensor* out);
+
+}  // namespace phi
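Both device kernels compute a plain sum over the axes GetReduceDims selects — the CPU path through phi::Reduce with SumFunctor, the GPU path through SumKernel — so the result should agree with paddle.sum over those axes. A hedged sanity check (assumes the op behaves as the kernels above suggest):

    import paddle

    x = paddle.rand([10, 10, 6])
    target = paddle.rand([10, 6])
    out = paddle.reduce_as(x, target)  # GetReduceDims picks axis 0 here
    ref = paddle.sum(x, axis=0)
    assert bool(paddle.allclose(out, ref))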
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index ab4d932278093..02b666aabefbc 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -475,6 +475,7 @@
     prod,
     rad2deg,
     reciprocal,
+    reduce_as,
     remainder,
     remainder_,
     renorm,
@@ -846,6 +847,7 @@
     'all',
     'ones',
     'not_equal',
+    'reduce_as',
     'sum',
     'nansum',
     'nanmean',
diff --git a/python/paddle/pir/core.py b/python/paddle/pir/core.py
index 1c5c12c94a6ae..543091f102548 100644
--- a/python/paddle/pir/core.py
+++ b/python/paddle/pir/core.py
@@ -86,7 +86,7 @@ def convert_np_dtype_to_dtype_(np_dtype):
     """
     # Convert the data type string to numpy data type.
-    if np_dtype == "bfloat16":
+    if isinstance(np_dtype, str) and np_dtype == "bfloat16":
         # since there is still no support for bfloat16 in NumPy,
         # uint16 is used for casting bfloat16
         dtype = np.dtype("uint16")
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 3afdca0fb21ce..d505a891f9c53 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -357,6 +357,7 @@
     rad2deg,
     reciprocal,
     reciprocal_,
+    reduce_as,
     remainder,
     remainder_,
     renorm,
@@ -524,6 +525,7 @@
     'sqrt_',
     'square',
     'stanh',
+    'reduce_as',
     'sum',
     'multigammaln',
     'multigammaln_',
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index bcee27d687c73..e9e065570fe2a 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1576,6 +1576,76 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
     return out
 
 
+def reduce_as(x, target, name=None):
+    """
+    Sums the elements of ``x`` along the necessary axes so that the result has the same shape as ``target``.
+
+    Args:
+        x (Tensor): An N-D Tensor, the data type is bool, float16, float32, float64, int32 or int64.
+        target (Tensor): An N-D Tensor whose rank must be less than or equal to the rank of ``x``. The data type is bool, float16, float32, float64, int32 or int64.
+        name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: The sum of ``x`` along the reduced axes, with the same shape as ``target``. If ``x.dtype`` is ``bool`` or ``int32``, the result's data type is ``int64``; otherwise it is the same as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> x = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
+            >>> x
+            Tensor(shape=[2, 4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
+            [[1, 2, 3, 4],
+             [5, 6, 7, 8]])
+            >>> target = paddle.to_tensor([1, 2, 3, 4])
+            >>> target
+            Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
+            [1, 2, 3, 4])
+            >>> res = paddle.reduce_as(x, target)
+            >>> res
+            Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
+            [6 , 8 , 10, 12])
+    """
+
+    if in_dynamic_or_pir_mode():
+        return _C_ops.reduce_as(x, target)
+    else:
+        check_variable_and_dtype(
+            x,
+            'x',
+            [
+                'bool',
+                'uint16',
+                'float16',
+                'float32',
+                'float64',
+                'int16',
+                'int32',
+                'int64',
+            ],
+            'reduce_as',
+        )
+        check_variable_and_dtype(
+            target,
+            'target',
+            [
+                'bool',
+                'uint16',
+                'float16',
+                'float32',
+                'float64',
+                'int16',
+                'int32',
+                'int64',
+            ],
+            'reduce_as',
+        )
+
+        helper = LayerHelper('reduce_as', **locals())
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        helper.append_op(
+            type='reduce_as',
+            inputs={'x': x, 'target': target},
+            outputs={'out': out},
+        )
+        return out
+
+
 def nan_to_num(x, nan=0.0, posinf=None, neginf=None, name=None):
     """
     Replaces NaN, positive infinity, and negative infinity values in input tensor.
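The docstring example only covers the rank-mismatch case; GetReduceDims also collapses size-1 axes of target. A hedged illustration (assuming the output takes target's shape, as ReduceAsInferMeta sets it):

    import paddle

    x = paddle.ones([2, 3, 4])
    target = paddle.ones([2, 1, 4])
    out = paddle.reduce_as(x, target)  # sums over axis 1
    print(out.shape)  # expected: [2, 1, 4], every element equal to 3.0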
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 1e6a577901b48..4d026ac680097 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -1442,3 +1442,4 @@ set_pit_tests_properties()
 
 set_tests_properties(test_fractional_max_pool2d_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_fractional_max_pool3d_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_reduce_as_op PROPERTIES TIMEOUT 30)
diff --git a/test/legacy_test/test_reduce_as_op.py b/test/legacy_test/test_reduce_as_op.py
new file mode 100644
index 0000000000000..68bcbec5b0984
--- /dev/null
+++ b/test/legacy_test/test_reduce_as_op.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from op_test import OpTest
+
+import paddle
+
+np.random.seed(100)
+paddle.seed(100)
+
+
+def reduce_as_net(x, target):
+    return paddle.reduce_as(x, target)
+
+
+def apply_to_static(net, use_cinn, input_spec=None):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(
+        net,
+        input_spec=input_spec,
+        build_strategy=build_strategy,
+        full_graph=True,
+    )
+
+
+class TestReduceAsOp(OpTest):
+    def setUp(self):
+        self.init_dtype()
+        self.init_shape()
+        self.init_input()
+        self.init_attrs()
+        self.calc_output()
+
+        self.python_api = paddle.reduce_as
+        self.op_type = "reduce_as"
+        self.inputs = {'x': self.x, 'target': self.y}
+        self.outputs = {'out': self.out}
+        self.if_enable_cinn()
+
+    def init_dtype(self):
+        self.dtype = np.float64
+
+    def init_shape(self):
+        self.shape_x = [10, 10, 6]
+        self.shape_y = [10, 6]
+
+    def init_input(self):
+        self.x = np.random.random(self.shape_x).astype(self.dtype)
+        self.y = np.random.random(self.shape_y).astype(self.dtype)
+
+    def init_attrs(self):
+        self.attrs = {'dim': [0]}
+
+    def if_enable_cinn(self):
+        pass
+
+    def calc_output(self):
+        self.out = self.x.sum(axis=tuple(self.attrs['dim']))
+
+    def test_check_output(self):
+        self.check_output(check_pir=True)
+
+    def test_check_grad(self):
+        pass
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
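The test leaves test_check_grad as a stub. Since reduce_as is a plain sum, the gradient with respect to x is broadcast ones; a quick dygraph check one could run by hand (hypothetical, not part of the PR's test file):

    import paddle

    x = paddle.rand([10, 10, 6])
    x.stop_gradient = False
    target = paddle.rand([10, 6])
    out = paddle.reduce_as(x, target)
    out.sum().backward()
    # d(out.sum())/dx == 1 everywhere, because reduce_as only sums x
    assert bool((x.grad == 1.0).all())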