Skip to content

Commit

Permalink
[Prim] reduce_as op support uint8, int8, complex64 and complex128 (Pad…
Browse files Browse the repository at this point in the history
…dlePaddle#63782)

* fix the include file

* add uint8 and int8

* support complex64 and complex128

* fix docs

* fix conflict

* fix the third_party
  • Loading branch information
zeroRains authored and co63oc committed May 10, 2024
1 parent 661d035 commit d30f2a3
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 31 deletions.
8 changes: 5 additions & 3 deletions paddle/phi/kernels/cpu/reduce_as_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/reduce_as_kernel.h"

#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/reduce_grad.h"

namespace phi {
Expand Down Expand Up @@ -55,6 +55,8 @@ PD_REGISTER_KERNEL(reduce_as_grad,
int,
int64_t,
uint8_t,
int8_t) {
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
8 changes: 5 additions & 3 deletions paddle/phi/kernels/cpu/reduce_as_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
// limitations under the License.

#include "paddle/phi/kernels/reduce_as_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

namespace phi {

Expand Down Expand Up @@ -48,4 +48,6 @@ PD_REGISTER_KERNEL(reduce_as,
int,
int64_t,
uint8_t,
int8_t) {}
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/reduce_as_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/reduce_as_grad_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce_grad.h"

Expand Down Expand Up @@ -65,6 +64,8 @@ PD_REGISTER_KERNEL(reduce_as_grad,
int,
int64_t,
uint8_t,
int8_t) {
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
7 changes: 4 additions & 3 deletions paddle/phi/kernels/gpu/reduce_as_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/reduce_as_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {
Expand Down Expand Up @@ -47,4 +46,6 @@ PD_REGISTER_KERNEL(reduce_as,
int,
int64_t,
uint8_t,
int8_t) {}
int8_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
16 changes: 12 additions & 4 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,8 +1581,8 @@ def reduce_as(x, target, name=None):
Computes the sum of tensor elements so that the shape of the result equals the shape of target.
Args:
x (Tensor): An N-D Tensor, the data type is bool, float16, float32, float64, int32 or int64.
target (Tensor): An N-D Tensor, the length of x shape must be greater than or equal to the length of target shape. The data type is bool, float16, float32, float64, int32 or int64.
x (Tensor): An N-D Tensor, the data type is bool, float16, float32, float64, int8, uint8, int16, uint16, int32, int64, complex64 or complex128.
target (Tensor): An N-D Tensor, the length of x shape must be greater than or equal to the length of target shape. The data type is bool, float16, float32, float64, int8, uint8, int16, uint16, int32, int64, complex64 or complex128.
Returns:
Tensor: The sum of the input tensor x along some axis has the same shape as the shape of the input tensor target, if `x.dtype='bool'`, `x.dtype='int32'`, its data type is `'int64'`, otherwise its data type is the same as `x`.
Expand Down Expand Up @@ -1615,13 +1615,17 @@ def reduce_as(x, target, name=None):
'x',
[
'bool',
'uint16',
'float16',
'float32',
'float64',
'int8',
'uint8',
'int16',
'uint16',
'int32',
'int64',
'complex64',
'complex128',
],
'reduce_as',
)
Expand All @@ -1630,13 +1634,17 @@ def reduce_as(x, target, name=None):
'target',
[
'bool',
'uint16',
'float16',
'float32',
'float64',
'int8',
'uint8',
'int16',
'uint16',
'int32',
'int64',
'complex64',
'complex128',
],
'reduce_as',
)
Expand Down
55 changes: 40 additions & 15 deletions test/deprecated/legacy_test/test_reduce_as_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,20 @@ def apply_to_static(net, use_cinn, input_spec=None):
)


class TestSumAsOp(OpTest):
class TestReduceAsOp(OpTest):
def setUp(self):
self.init_dtype()
self.init_shape()
self.init_input()
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.x = np.random.random(self.shape_x) + 1j * np.random.random(
self.shape_y
)
self.y = np.random.random(self.shape_x) + 1j * np.random.random(
self.shape_y
)
else:
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.y = np.random.random(self.shape_y).astype(self.dtype)
self.init_attrs()
self.calc_output()

Expand All @@ -60,10 +69,6 @@ def init_shape(self):
self.shape_x = [10, 10, 6]
self.shape_y = [10, 6]

def init_input(self):
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.y = np.random.random(self.shape_y).astype(self.dtype)

def init_attrs(self):
self.attrs = {'dim': [0]}

Expand All @@ -84,42 +89,62 @@ def test_check_grad(self):
)


class TestSumAsOp2(TestSumAsOp):
class TestReduceAsOp2(TestReduceAsOp):
def init_type(self):
self.dtype = 'float32'


class TestSumAsOp3(TestSumAsOp):
class TestReduceAsOp3(TestReduceAsOp):
def init_type(self):
self.dtype = 'float16'


class TestSumAsOp4(TestSumAsOp):
class TestReduceAsOp4(TestReduceAsOp):
def init_type(self):
self.dtype = 'uint16'


class TestSumAsOp5(TestSumAsOp):
class TestReduceAsOp5(TestReduceAsOp):
def init_type(self):
self.dtype = 'int16'


class TestSumAsOp6(TestSumAsOp):
class TestReduceAsOp6(TestReduceAsOp):
def init_type(self):
self.dtype = 'int64'


class TestSumAsOp7(TestSumAsOp):
class TestReduceAsOp7(TestReduceAsOp):
def init_type(self):
self.dtype = 'bool'


class TestSumAsOp8(TestSumAsOp):
class TestReduceAsOp8(TestReduceAsOp):
def init_type(self):
self.dtype = 'int32'


class TestSumAsOp9(TestSumAsOp):
class TestReduceAsOp9(TestReduceAsOp):
def init_type(self):
self.dtype = 'int8'


class TestReduceAsOp10(TestReduceAsOp):
def init_type(self):
self.dtype = 'uint8'


class TestReduceAs_Complex64(TestReduceAsOp):
def init_type(self):
self.dtype = np.complex64


class TestReduceAs_Complex128(TestReduceAsOp):
def init_type(self):
self.dtype = np.complex128


class TestReduceAsOp13(TestReduceAsOp):
def init_shape(self):
self.shape_x = [10, 10, 6]
self.shape_y = [6]
Expand All @@ -128,7 +153,7 @@ def init_attrs(self):
self.attrs = {'dim': [0, 1]}


class TestSumAsDynamicShape(unittest.TestCase):
class TestReduceAsDynamicShape(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
self.shape_x = [300, 20, 100]
Expand Down

0 comments on commit d30f2a3

Please sign in to comment.