Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] support input 0D Tensor for some binary api #46909

Merged
merged 1 commit into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/operators/common_infer_shape_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_npu.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/common_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/funcs/elementwise_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx,
phi::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
phi::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down Expand Up @@ -394,7 +394,7 @@ void ElementwiseCompute(const CPUContext &dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down Expand Up @@ -1725,7 +1725,7 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/xpu/elementwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void XPUElementwise(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down Expand Up @@ -121,7 +121,7 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
PADDLE_ENFORCE_LE(axis,
max_dim,
errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
Expand Down
60 changes: 59 additions & 1 deletion python/paddle/fluid/tests/unittests/test_bitwise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ def init_bound(self):
self.high = 100


class TestBitwiseAnd_ZeroDim1(TestBitwiseAnd):
def init_shape(self):
self.x_shape = []
self.y_shape = []


class TestBitwiseAnd_ZeroDim2(TestBitwiseAnd):
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = []


class TestBitwiseAnd_ZeroDim3(TestBitwiseAnd):
def init_shape(self):
self.x_shape = []
self.y_shape = [2, 3, 4, 5]


class TestBitwiseAndUInt8(TestBitwiseAnd):
def init_dtype(self):
self.dtype = np.uint8
Expand Down Expand Up @@ -143,6 +161,24 @@ def init_bound(self):
self.high = 100


class TestBitwiseOr_ZeroDim1(TestBitwiseOr):
def init_shape(self):
self.x_shape = []
self.y_shape = []


class TestBitwiseOr_ZeroDim2(TestBitwiseOr):
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = []


class TestBitwiseOr_ZeroDim3(TestBitwiseOr):
def init_shape(self):
self.x_shape = []
self.y_shape = [2, 3, 4, 5]


class TestBitwiseOrUInt8(TestBitwiseOr):
def init_dtype(self):
self.dtype = np.uint8
Expand Down Expand Up @@ -229,6 +265,24 @@ def init_bound(self):
self.high = 100


class TestBitwiseXor_ZeroDim1(TestBitwiseXor):
def init_shape(self):
self.x_shape = []
self.y_shape = []


class TestBitwiseXor_ZeroDim2(TestBitwiseXor):
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = []


class TestBitwiseXor_ZeroDim3(TestBitwiseXor):
def init_shape(self):
self.x_shape = []
self.y_shape = [2, 3, 4, 5]


class TestBitwiseXorUInt8(TestBitwiseXor):
def init_dtype(self):
self.dtype = np.uint8
Expand Down Expand Up @@ -311,6 +365,11 @@ def init_bound(self):
self.high = 100


class TestBitwiseNot_ZeroDim(TestBitwiseNot):
def init_shape(self):
self.x_shape = []


class TestBitwiseNotUInt8(TestBitwiseNot):
def init_dtype(self):
self.dtype = np.uint8
Expand All @@ -334,7 +393,6 @@ def init_dtype(self):

def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]


class TestBitwiseNotInt64(TestBitwiseNot):
Expand Down
48 changes: 48 additions & 0 deletions python/paddle/fluid/tests/unittests/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,54 @@ def test_dynamic_api_bool(self):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()

def test_zero_dim_api_1(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[], dtype='int32')
y = paddle.randint(-3, 3, shape=[], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)

def test_zero_dim_api_2(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32')
y = paddle.randint(-3, 3, shape=[], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)

def test_zero_dim_api_3(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.randint(-3, 3, shape=[], dtype='int32')
y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
(
x_np,
y_np,
res,
) = exe.run(fetch_list=[x, y, out])
real_result = callback(x_np, y_np)
self.assertEqual((res == real_result).all(), True)

def test_broadcast_api_1(self):
paddle.enable_static()
with program_guard(Program(), Program()):
Expand Down
21 changes: 21 additions & 0 deletions python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,27 @@ def init_axis(self):
self.axis = -1


class TestElementwiseAddOp_ZeroDim1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.out = np.add(self.x, self.y)


class TestElementwiseAddOp_ZeroDim2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.out = np.add(self.x, self.y)


class TestElementwiseAddOp_ZeroDim3(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, []).astype(self.dtype)
self.out = np.add(self.x, self.y)


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
Expand Down
36 changes: 36 additions & 0 deletions python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,42 @@ def test_check_gradient(self):
self.check_grad_with_place(*check_args, **check_kwargs)


class TestElementwiseDivOp_ZeroDim1(ElementwiseDivOp):
def init_shape(self):
self.x_shape = []
self.y_shape = []


class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp):
def init_shape(self):
self.x_shape = [13, 17]
self.y_shape = []

def compute_output(self, x, y):
return x / y.reshape([1, 1])

def compute_gradient_x(self, grad_out, y):
return grad_out / y.reshape([1, 1])

def compute_gradient_y(self, grad_out, out, y):
return np.sum(-1 * grad_out * out / y.reshape([1, 1]))


class TestElementwiseDivOp_ZeroDim3(ElementwiseDivOp):
def init_shape(self):
self.x_shape = []
self.y_shape = [13, 17]

def compute_output(self, x, y):
return x.reshape([1, 1]) / y

def compute_gradient_x(self, grad_out, y):
return np.sum(grad_out / y)

def compute_gradient_y(self, grad_out, out, y):
return -1 * grad_out * out / y


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ def init_axis(self):
pass


class TestElementwiseFloorDivOp_ZeroDim1(TestElementwiseModOp):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, []).astype(self.dtype)
self.y = np.random.uniform(0, 1000, []).astype(self.dtype)
self.out = np.floor_divide(self.x, self.y)


class TestElementwiseFloorDivOp_ZeroDim2(TestElementwiseModOp):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(0, 1000, []).astype(self.dtype)
self.out = np.floor_divide(self.x, self.y)


class TestElementwiseFloorDivOp_ZeroDim3(TestElementwiseModOp):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, []).astype(self.dtype)
self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype)
self.out = np.floor_divide(self.x, self.y)


class TestElementwiseModOp_scalar(TestElementwiseModOp):
def init_input_output(self):
scale_x = random.randint(0, 100000000)
Expand Down
30 changes: 30 additions & 0 deletions python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,36 @@ def test_check_grad_ingore_y(self):
)


class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_max"
self.python_api = paddle.maximum
x = np.random.uniform(0.1, 1, []).astype("float64")
y = np.random.uniform(0.1, 1, []).astype("float64")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_max"
self.python_api = paddle.maximum
x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
y = np.random.uniform(0.1, 1, []).astype("float64")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_max"
self.python_api = paddle.maximum
x = np.random.uniform(0.1, 1, []).astype("float64")
y = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


@unittest.skipIf(
core.is_compiled_with_cuda()
and (
Expand Down
Loading