diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 835046c1e7911..ac3ec47f8297e 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -343,6 +343,16 @@ backend : x inplace: (x -> out) +- op : bitwise_left_shift + args : (Tensor x, Tensor y, bool is_arithmetic = true) + output : Tensor(out) + infer_meta : + func : BitwiseShiftInferMeta + kernel : + func : bitwise_left_shift + backend : x + inplace: (x -> out) + - op : bitwise_not args : (Tensor x) output : Tensor(out) @@ -364,6 +374,16 @@ backend : x inplace: (x -> out) +- op : bitwise_right_shift + args : (Tensor x, Tensor y, bool is_arithmetic = true) + output : Tensor(out) + infer_meta : + func : BitwiseShiftInferMeta + kernel : + func : bitwise_right_shift + backend : x + inplace: (x -> out) + - op : bitwise_xor args : (Tensor x, Tensor y) output : Tensor(out) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index b771fba031317..49aa257286754 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1269,6 +1269,13 @@ void ElementwiseInferMeta(const MetaTensor& x, return ElementwiseRawInferMeta(x, y, -1, out); } +void BitwiseShiftInferMeta(const MetaTensor& x, + const MetaTensor& y, + bool is_arithmetic, + MetaTensor* out) { + return ElementwiseRawInferMeta(x, y, -1, out); +} + void ElementwiseRawInferMeta(const MetaTensor& x, const MetaTensor& y, int axis, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 82f5fc64d57a5..ba93038e3702d 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -231,6 +231,11 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta, MetaTensor* out, MetaConfig config = MetaConfig()); +void BitwiseShiftInferMeta(const MetaTensor& x, + const MetaTensor& y, + bool is_arithmetic, + MetaTensor* out); + void EmbeddingInferMeta(const MetaTensor& x, const MetaTensor& weight, int64_t padding_idx, diff --git a/paddle/phi/kernels/bitwise_kernel.h b/paddle/phi/kernels/bitwise_kernel.h index 17307004f360e..92fc41531da97 100644 --- a/paddle/phi/kernels/bitwise_kernel.h +++ b/paddle/phi/kernels/bitwise_kernel.h @@ -41,4 +41,18 @@ void BitwiseNotKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +template +void BitwiseLeftShiftKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool is_arithmetic, + DenseTensor* out); + +template +void BitwiseRightShiftKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool is_arithmetic, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/bitwise_kernel.cc b/paddle/phi/kernels/cpu/bitwise_kernel.cc index a6297efd9cd3e..4de0d575e96e4 100644 --- a/paddle/phi/kernels/cpu/bitwise_kernel.cc +++ b/paddle/phi/kernels/cpu/bitwise_kernel.cc @@ -40,6 +40,45 @@ DEFINE_BITWISE_KERNEL(Or) DEFINE_BITWISE_KERNEL(Xor) #undef DEFINE_BITWISE_KERNEL +#define DEFINE_BITWISE_KERNEL_WITH_INVERSE(op_type) \ + template \ + void Bitwise##op_type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + bool is_arithmetic, \ + DenseTensor* out) { \ + auto x_dims = x.dims(); \ + auto y_dims = y.dims(); \ + if (x_dims.size() >= y_dims.size()) { \ + if (is_arithmetic) { \ + funcs::Bitwise##op_type##ArithmeticFunctor func; \ + funcs::ElementwiseCompute< \ + funcs::Bitwise##op_type##ArithmeticFunctor, \ + T>(dev_ctx, x, y, func, out); \ + } else { \ + funcs::Bitwise##op_type##LogicFunctor func; \ + funcs::ElementwiseCompute, \ + T>(dev_ctx, x, y, func, out); \ + } \ + } else { \ + if (is_arithmetic) { \ + funcs::InverseBitwise##op_type##ArithmeticFunctor inv_func; \ + funcs::ElementwiseCompute< \ + funcs::InverseBitwise##op_type##ArithmeticFunctor, \ + T>(dev_ctx, x, y, inv_func, out); \ + } else { \ + funcs::InverseBitwise##op_type##LogicFunctor inv_func; \ + funcs::ElementwiseCompute< \ + funcs::InverseBitwise##op_type##LogicFunctor, \ + T>(dev_ctx, x, y, inv_func, out); \ + } \ + } \ + } + +DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShift) +DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShift) +#undef DEFINE_BITWISE_KERNEL_WITH_INVERSE + template void BitwiseNotKernel(const Context& dev_ctx, const DenseTensor& x, @@ -97,3 +136,23 @@ PD_REGISTER_KERNEL(bitwise_not, int16_t, int, int64_t) {} + +PD_REGISTER_KERNEL(bitwise_left_shift, + CPU, + ALL_LAYOUT, + phi::BitwiseLeftShiftKernel, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_right_shift, + CPU, + ALL_LAYOUT, + phi::BitwiseRightShiftKernel, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/bitwise_functors.h b/paddle/phi/kernels/funcs/bitwise_functors.h index db1fc59f534bc..35811f1fca172 100644 --- a/paddle/phi/kernels/funcs/bitwise_functors.h +++ b/paddle/phi/kernels/funcs/bitwise_functors.h @@ -47,5 +47,164 @@ struct BitwiseNotFunctor { HOSTDEVICE bool operator()(const bool a) const { return !a; } }; +template +struct BitwiseLeftShiftArithmeticFunctor { + HOSTDEVICE T operator()(const T a, const T b) const { + if (b >= static_cast(sizeof(T) * 8)) return static_cast(0); + if (b < static_cast(0)) return static_cast(0); + return a << b; + } +}; + +template +struct InverseBitwiseLeftShiftArithmeticFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + if (a >= static_cast(sizeof(T) * 8)) return static_cast(0); + if (a < static_cast(0)) return static_cast(0); + return b << a; + } +}; + +template +struct BitwiseLeftShiftLogicFunctor { + HOSTDEVICE T operator()(const T a, const T b) const { + if (b < static_cast(0) || b >= static_cast(sizeof(T) * 8)) + return static_cast(0); + return a << b; + } +}; + +template +struct InverseBitwiseLeftShiftLogicFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + if (a < static_cast(0) || a >= static_cast(sizeof(T) * 8)) + return static_cast(0); + return b << a; + } +}; + +template +struct BitwiseRightShiftArithmeticFunctor { + HOSTDEVICE T operator()(const T a, const T b) const { + if (b < static_cast(0) || b >= static_cast(sizeof(T) * 8)) + return static_cast(-(a >> (sizeof(T) * 8 - 1) & 1)); + return a >> b; + } +}; + +template +struct InverseBitwiseRightShiftArithmeticFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + if (a < static_cast(0) || a >= static_cast(sizeof(T) * 8)) + return static_cast(-(b >> (sizeof(T) * 8 - 1) & 1)); + return b >> a; + } +}; + +template <> +struct BitwiseRightShiftArithmeticFunctor { + HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const { + if (b >= static_cast(sizeof(uint8_t) * 8)) + return static_cast(0); + return a >> b; + } +}; + +template <> +struct InverseBitwiseRightShiftArithmeticFunctor { + inline HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const { + if (a >= static_cast(sizeof(uint8_t) * 8)) + return static_cast(0); + return b >> a; + } +}; + +template +struct BitwiseRightShiftLogicFunctor { + HOSTDEVICE T operator()(const T a, const T b) const { + if (b >= static_cast(sizeof(T) * 8) || b < static_cast(0)) + return static_cast(0); + return a >> b; + } +}; + +template +struct InverseBitwiseRightShiftLogicFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + if (a >= static_cast(sizeof(T) * 8) || a < static_cast(0)) + return static_cast(0); + return b >> a; + } +}; + +template +HOSTDEVICE T logic_shift_func(const T a, const T b) { + if (b < static_cast(0) || b >= static_cast(sizeof(T) * 8)) + return static_cast(0); + T t = static_cast(sizeof(T) * 8 - 1); + T mask = (((a >> t) << t) >> b) << 1; + return (a >> b) ^ mask; +} + +// signed int8 +template <> +struct BitwiseRightShiftLogicFunctor { + HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const { + return logic_shift_func(a, b); + } +}; + +template <> +struct InverseBitwiseRightShiftLogicFunctor { + inline HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const { + return logic_shift_func(b, a); + } +}; + +// signed int16 +template <> +struct BitwiseRightShiftLogicFunctor { + HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const { + return logic_shift_func(a, b); + } +}; + +template <> +struct InverseBitwiseRightShiftLogicFunctor { + inline HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const { + return logic_shift_func(b, a); + } +}; + +// signed int32 +template <> +struct BitwiseRightShiftLogicFunctor { + HOSTDEVICE int operator()(const int a, const int b) const { + return logic_shift_func(a, b); + } +}; + +template <> +struct InverseBitwiseRightShiftLogicFunctor { + inline HOSTDEVICE int operator()(const int a, const int b) const { + return logic_shift_func(b, a); + } +}; + +// signed int64 +template <> +struct BitwiseRightShiftLogicFunctor { + HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const { + return logic_shift_func(a, b); + } +}; + +template <> +struct InverseBitwiseRightShiftLogicFunctor { + inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const { + return logic_shift_func(b, a); + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/kps/bitwise_kernel.cu b/paddle/phi/kernels/kps/bitwise_kernel.cu index fcdc7c95e9151..c4fa2b49f7330 100644 --- a/paddle/phi/kernels/kps/bitwise_kernel.cu +++ b/paddle/phi/kernels/kps/bitwise_kernel.cu @@ -43,6 +43,29 @@ DEFINE_BITWISE_KERNEL(Or) DEFINE_BITWISE_KERNEL(Xor) #undef DEFINE_BITWISE_KERNEL +#define DEFINE_BITWISE_KERNEL_WITH_BOOL(op_type) \ + template \ + void Bitwise##op_type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + bool is_arithmetic, \ + DenseTensor* out) { \ + dev_ctx.template Alloc(out); \ + std::vector ins = {&x, &y}; \ + std::vector outs = {out}; \ + if (is_arithmetic) { \ + funcs::Bitwise##op_type##ArithmeticFunctor func; \ + funcs::BroadcastKernel(dev_ctx, ins, &outs, func); \ + } else { \ + funcs::Bitwise##op_type##LogicFunctor func; \ + funcs::BroadcastKernel(dev_ctx, ins, &outs, func); \ + } \ + } + +DEFINE_BITWISE_KERNEL_WITH_BOOL(LeftShift) +DEFINE_BITWISE_KERNEL_WITH_BOOL(RightShift) +#undef DEFINE_BITWISE_KERNEL_WITH_BOOL + template void BitwiseNotKernel(const Context& dev_ctx, const DenseTensor& x, @@ -112,4 +135,24 @@ PD_REGISTER_KERNEL(bitwise_not, int, int64_t) {} +PD_REGISTER_KERNEL(bitwise_left_shift, + KPS, + ALL_LAYOUT, + phi::BitwiseLeftShiftKernel, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_right_shift, + KPS, + ALL_LAYOUT, + phi::BitwiseRightShiftKernel, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + #endif diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ae35e62755356..153fb194c1177 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -358,6 +358,10 @@ atan_, atanh, atanh_, + bitwise_left_shift, + bitwise_left_shift_, + bitwise_right_shift, + bitwise_right_shift_, broadcast_shape, ceil, clip, @@ -944,6 +948,10 @@ 'i1e', 'polygamma', 'polygamma_', + 'bitwise_left_shift', + 'bitwise_left_shift_', + 'bitwise_right_shift', + 'bitwise_right_shift_', 'masked_fill', 'masked_fill_', 'masked_scatter', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b718910348d8f..330a5132e06f2 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -235,6 +235,10 @@ atan_, atanh, atanh_, + bitwise_left_shift, + bitwise_left_shift_, + bitwise_right_shift, + bitwise_right_shift_, broadcast_shape, ceil, ceil_, @@ -781,6 +785,11 @@ 'asinh_', 'diag', 'normal_', + 'normal_', + 'bitwise_left_shift', + 'bitwise_left_shift_', + 'bitwise_right_shift', + 'bitwise_right_shift_', 'index_fill', 'index_fill_', 'atleast_1d', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6d75d41b4949c..94d685f299edf 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7101,6 +7101,192 @@ def ldexp_(x, y, name=None): return paddle.multiply_(x, paddle.pow(two, y)) +def _bitwise_op(op_name, x, y, is_arithmetic, out=None, name=None): + check_variable_and_dtype( + x, + "x", + ["uint8", "int8", "int16", "int32", "int64"], + op_name, + ) + if y is not None: + check_variable_and_dtype( + y, + "y", + ["uint8", "int8", "int16", "int32", "int64"], + op_name, + ) + + helper = LayerHelper(op_name, **locals()) + assert x.dtype == y.dtype + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type=op_name, + inputs={"x": x, "y": y}, + outputs={"out": out}, + attrs={'is_arithmetic': is_arithmetic}, + ) + + return out + + +def bitwise_left_shift(x, y, is_arithmetic=True, out=None, name=None): + r""" + Apply ``bitwise_left_shift`` on Tensor ``X`` and ``Y`` . + + .. math:: + + Out = X \ll Y + + .. note:: + + ``paddle.bitwise_left_shift`` supports broadcasting. If you want know more about broadcasting, please refer to please refer to `Introduction to Tensor`_ . + + .. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor + + Args: + x (Tensor): Input Tensor of ``bitwise_left_shift`` . It is a N-D Tensor of uint8, int8, int16, int32, int64. + y (Tensor): Input Tensor of ``bitwise_left_shift`` . It is a N-D Tensor of uint8, int8, int16, int32, int64. + is_arithmetic (bool, optional): A boolean indicating whether to choose arithmetic shift, if False, means logic shift. Default True. + out (Tensor, optional): Result of ``bitwise_left_shift`` . It is a N-D Tensor with the same data type of input Tensor. Default: None. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: Result of ``bitwise_left_shift`` . It is a N-D Tensor with the same data type of input Tensor. + + Examples: + .. code-block:: python + :name: bitwise_left_shift_example1 + + >>> import paddle + >>> x = paddle.to_tensor([[1,2,4,8],[16,17,32,65]]) + >>> y = paddle.to_tensor([[1,2,3,4,], [2,3,2,1]]) + >>> paddle.bitwise_left_shift(x, y, is_arithmetic=True) + Tensor(shape=[2, 4], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[2 , 8 , 32 , 128], + [64 , 136, 128, 130]]) + + .. code-block:: python + :name: bitwise_left_shift_example2 + + >>> import paddle + >>> x = paddle.to_tensor([[1,2,4,8],[16,17,32,65]]) + >>> y = paddle.to_tensor([[1,2,3,4,], [2,3,2,1]]) + >>> paddle.bitwise_left_shift(x, y, is_arithmetic=False) + Tensor(shape=[2, 4], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[2 , 8 , 32 , 128], + [64 , 136, 128, 130]]) + """ + if in_dynamic_mode() and out is None: + return _C_ops.bitwise_left_shift(x, y, is_arithmetic) + return _bitwise_op( + op_name="bitwise_left_shift", + x=x, + y=y, + is_arithmetic=is_arithmetic, + name=name, + out=out, + ) + + +@inplace_apis_in_dygraph_only +def bitwise_left_shift_(x, y, is_arithmetic=True, out=None, name=None): + r""" + Inplace version of ``bitwise_left_shift`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_bitwise_left_shift`. + """ + out_shape = broadcast_shape(x.shape, y.shape) + if out_shape != x.shape: + raise ValueError( + "The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format( + out_shape, x.shape + ) + ) + if in_dynamic_or_pir_mode(): + return _C_ops.bitwise_left_shift_(x, y, is_arithmetic) + + +def bitwise_right_shift(x, y, is_arithmetic=True, out=None, name=None): + r""" + Apply ``bitwise_right_shift`` on Tensor ``X`` and ``Y`` . + + .. math:: + + Out = X \gg Y + + .. note:: + + ``paddle.bitwise_right_shift`` supports broadcasting. If you want know more about broadcasting, please refer to please refer to `Introduction to Tensor`_ . + + .. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor + + Args: + x (Tensor): Input Tensor of ``bitwise_right_shift`` . It is a N-D Tensor of uint8, int8, int16, int32, int64. + y (Tensor): Input Tensor of ``bitwise_right_shift`` . It is a N-D Tensor of uint8, int8, int16, int32, int64. + is_arithmetic (bool, optional): A boolean indicating whether to choose arithmetic shift, if False, means logic shift. Default True. + out (Tensor, optional): Result of ``bitwise_right_shift`` . It is a N-D Tensor with the same data type of input Tensor. Default: None. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: Result of ``bitwise_right_shift`` . It is a N-D Tensor with the same data type of input Tensor. + + Examples: + .. code-block:: python + :name: bitwise_right_shift_example1 + + >>> import paddle + >>> x = paddle.to_tensor([[10,20,40,80],[16,17,32,65]]) + >>> y = paddle.to_tensor([[1,2,3,4,], [2,3,2,1]]) + >>> paddle.bitwise_right_shift(x, y, is_arithmetic=True) + Tensor(shape=[2, 4], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[5 , 5 , 5 , 5 ], + [4 , 2 , 8 , 32]]) + + .. code-block:: python + :name: bitwise_right_shift_example2 + + >>> import paddle + >>> x = paddle.to_tensor([[-10,-20,-40,-80],[-16,-17,-32,-65]], dtype=paddle.int8) + >>> y = paddle.to_tensor([[1,2,3,4,], [2,3,2,1]], dtype=paddle.int8) + >>> paddle.bitwise_right_shift(x, y, is_arithmetic=False) # logic shift + Tensor(shape=[2, 4], dtype=int8, place=Place(gpu:0), stop_gradient=True, + [[123, 59 , 27 , 11 ], + [60 , 29 , 56 , 95 ]]) + """ + if in_dynamic_mode() and out is None: + return _C_ops.bitwise_right_shift(x, y, is_arithmetic) + + return _bitwise_op( + op_name="bitwise_right_shift", + x=x, + y=y, + is_arithmetic=is_arithmetic, + name=name, + out=out, + ) + + +@inplace_apis_in_dygraph_only +def bitwise_right_shift_(x, y, is_arithmetic=True, out=None, name=None): + r""" + Inplace version of ``bitwise_right_shift`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_bitwise_left_shift`. + """ + out_shape = broadcast_shape(x.shape, y.shape) + if out_shape != x.shape: + raise ValueError( + "The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format( + out_shape, x.shape + ) + ) + + if in_dynamic_or_pir_mode(): + return _C_ops.bitwise_right_shift_(x, y, is_arithmetic) + + def hypot(x, y, name=None): """ Calculate the length of the hypotenuse of a right-angle triangle. The equation is: diff --git a/test/legacy_test/test_bitwise_shift_op.py b/test/legacy_test/test_bitwise_shift_op.py new file mode 100644 index 0000000000000..14ddbeb63ea1b --- /dev/null +++ b/test/legacy_test/test_bitwise_shift_op.py @@ -0,0 +1,462 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle + +_SIGNED_TO_UNSIGNED_TABLE = { + "int8": "uint8", + "int16": "uint16", + "int32": "uint32", + "int64": "uint64", +} + +_UNSIGNED_TO_SIGNED_TABLE = { + "uint8": "int8", + "uint16": "int16", + "uint32": "int32", + "uint64": "int64", +} + +_UNSIGNED_LIST = ['uint8', 'uint16', 'uint32', 'uint64'] + + +def ref_left_shift_arithmetic(x, y): + out = np.left_shift(x, y) + return out + + +def ref_left_shift_logical(x, y): + out = np.left_shift(x, y) + return out + + +def ref_right_shift_arithmetic(x, y): + return np.right_shift(x, y) + + +def ref_right_shift_logical(x, y): + if str(x.dtype) in _UNSIGNED_LIST: + return np.right_shift(x, y) + else: + orig_dtype = x.dtype + unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[str(orig_dtype)] + x = x.astype(unsigned_dtype) + y = y.astype(unsigned_dtype) + res = np.right_shift(x, y) + return res.astype(orig_dtype) + + +class TestBitwiseLeftShiftAPI(unittest.TestCase): + def setUp(self): + self.init_input() + self.place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def init_input(self): + self.x = np.random.randint(0, 256, [200, 300]).astype('uint8') + self.y = np.random.randint(0, 256, [200, 300]).astype('uint8') + + def test_static_api_arithmetic(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype) + y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype) + out = paddle.bitwise_left_shift( + x, + y, + ) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out]) + out_ref = ref_left_shift_arithmetic(self.x, self.y) + np.testing.assert_allclose(out_ref, res[0]) + + def test_dygraph_api_arithmetic(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + out = paddle.bitwise_left_shift( + x, + y, + ) + out_ref = ref_left_shift_arithmetic(self.x, self.y) + np.testing.assert_allclose(out_ref, out.numpy()) + paddle.enable_static() + + def test_static_api_logical(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype) + y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype) + out = paddle.bitwise_left_shift(x, y, False) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out]) + out_ref = ref_left_shift_logical(self.x, self.y) + np.testing.assert_allclose(out_ref, res[0]) + + def test_dygraph_api_logical(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + out = paddle.bitwise_left_shift(x, y, False) + out_ref = ref_left_shift_logical(self.x, self.y) + np.testing.assert_allclose(out_ref, out.numpy()) + paddle.enable_static() + + +class TestBitwiseLeftShiftAPI_UINT8(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(0, 256, [200, 300]).astype('uint8') + self.y = np.random.randint(0, 256, [200, 300]).astype('uint8') + + +class TestBitwiseLeftShiftAPI_UINT8_broadcast1(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(0, 256, [200, 300]).astype('uint8') + self.y = np.random.randint(0, 256, [300]).astype('uint8') + + +class TestBitwiseLeftShiftAPI_UINT8_broadcast2(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(0, 256, [300]).astype('uint8') + self.y = np.random.randint(0, 256, [200, 300]).astype('uint8') + + +class TestBitwiseLeftShiftAPI_INT8(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + self.y = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + + +class TestBitwiseLeftShiftAPI_INT8_broadcast1(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + self.y = np.random.randint(-(2**7), 2**7, [300]).astype('int8') + + +class TestBitwiseLeftShiftAPI_INT8_broadcast2(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**7), 2**7, [300]).astype('int8') + self.y = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + + +class TestBitwiseLeftShiftAPI_INT16(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + self.y = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + + +class TestBitwiseLeftShiftAPI_INT16_broadcast1(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + self.y = np.random.randint(-(2**15), 2**15, [300]).astype('int16') + + +class TestBitwiseLeftShiftAPI_INT16_broadcast2(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**15), 2**15, [300]).astype('int16') + self.y = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + + +class TestBitwiseLeftShiftAPI_INT32(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + self.y = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + + +class TestBitwiseLeftShiftAPI_INT32_broadcast1(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + self.y = np.random.randint(-(2**31), 2**31, [300]).astype('int32') + + +class TestBitwiseLeftShiftAPI_INT32_broadcast2(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**31), 2**31, [300]).astype('int32') + self.y = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + + +class TestBitwiseLeftShiftAPI_INT64(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + self.y = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + + +class TestBitwiseLeftShiftAPI_INT64_broadcast1(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + self.y = np.random.randint(-(2**63), 2**63, [300], dtype=np.int64) + + +class TestBitwiseLeftShiftAPI_INT64_broadcast2(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**63), 2**63, [300], dtype=np.int64) + self.y = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + + +class TestBitwiseLeftShiftAPI_special_case1(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.array([0b11111111], dtype='int8') + self.y = np.array([1], dtype='int8') + + +class TestBitwiseLeftShiftAPI_special_case2(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.array([0b11111111], dtype='int8') + self.y = np.array([10], dtype='int8') + + +class TestBitwiseLeftShiftAPI_special_case3(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.array([0b11111111], dtype='uint8') + self.y = np.array([1], dtype='uint8') + + +class TestBitwiseLeftShiftAPI_special_case4(TestBitwiseLeftShiftAPI): + def init_input(self): + self.x = np.array([0b11111111], dtype='uint8') + self.y = np.array([10], dtype='uint8') + + +class TestBitwiseRightShiftAPI(unittest.TestCase): + def setUp(self): + self.init_input() + self.place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def init_input(self): + self.x = np.random.randint(0, 256, [200, 300]).astype('uint8') + self.y = np.random.randint(0, 256, [200, 300]).astype('uint8') + + def test_static_api_arithmetic(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype) + y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype) + out = paddle.bitwise_right_shift( + x, + y, + ) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out]) + out_ref = ref_right_shift_arithmetic(self.x, self.y) + np.testing.assert_allclose(out_ref, res[0]) + + def test_dygraph_api_arithmetic(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + out = paddle.bitwise_right_shift( + x, + y, + ) + out_ref = ref_right_shift_arithmetic(self.x, self.y) + np.testing.assert_allclose(out_ref, out.numpy()) + paddle.enable_static() + + def test_static_api_logical(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype) + y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype) + out = paddle.bitwise_right_shift(x, y, False) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out]) + out_ref = ref_right_shift_logical(self.x, self.y) + np.testing.assert_allclose(out_ref, res[0]) + + def test_dygraph_api_logical(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + out = paddle.bitwise_right_shift(x, y, False) + out_ref = ref_right_shift_logical(self.x, self.y) + np.testing.assert_allclose(out_ref, out.numpy()) + paddle.enable_static() + + +class TestBitwiseRightShiftAPI_UINT8(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(0, 256, [200, 300]).astype('uint8') + self.y = np.random.randint(0, 256, [200, 300]).astype('uint8') + + +class TestBitwiseRightShiftAPI_UINT8_broadcast1(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(0, 256, [200, 300]).astype('uint8') + self.y = np.random.randint(0, 256, [300]).astype('uint8') + + +class TestBitwiseRightShiftAPI_UINT8_broadcast2(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(0, 256, [300]).astype('uint8') + self.y = np.random.randint(0, 256, [200, 300]).astype('uint8') + + +class TestBitwiseRightShiftAPI_INT8(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + self.y = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + + +class TestBitwiseRightShiftAPI_INT8_broadcast1(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + self.y = np.random.randint(-(2**7), 2**7, [300]).astype('int8') + + +class TestBitwiseRightShiftAPI_INT8_broadcast2(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**7), 2**7, [300]).astype('int8') + self.y = np.random.randint(-(2**7), 2**7, [200, 300]).astype('int8') + + +class TestBitwiseRightShiftAPI_INT16(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + self.y = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + + +class TestBitwiseRightShiftAPI_INT16_broadcast1(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + self.y = np.random.randint(-(2**15), 2**15, [300]).astype('int16') + + +class TestBitwiseRightShiftAPI_INT16_broadcast2(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**15), 2**15, [300]).astype('int16') + self.y = np.random.randint(-(2**15), 2**15, [200, 300]).astype( + 'int16' + ) + + +class TestBitwiseRightShiftAPI_INT32(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + self.y = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + + +class TestBitwiseRightShiftAPI_INT32_broadcast1(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + self.y = np.random.randint(-(2**31), 2**31, [300]).astype('int32') + + +class TestBitwiseRightShiftAPI_INT32_broadcast2(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**31), 2**31, [300]).astype('int32') + self.y = np.random.randint(-(2**31), 2**31, [200, 300]).astype( + 'int32' + ) + + +class TestBitwiseRightShiftAPI_INT64(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + self.y = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + + +class TestBitwiseRightShiftAPI_INT64_broadcast1(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + self.y = np.random.randint(-(2**63), 2**63, [300], dtype=np.int64) + + +class TestBitwiseRightShiftAPI_INT64_broadcast2(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.random.randint(-(2**63), 2**63, [300], dtype=np.int64) + self.y = np.random.randint( + -(2**63), 2**63, [200, 300], dtype=np.int64 + ) + + +class TestBitwiseRightShiftAPI_special_case1(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.array([0b11111111]).astype('int8') + self.y = np.array([1]).astype('int8') + + +class TestBitwiseRightShiftAPI_special_case2(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.array([0b11111111]).astype('int8') + self.y = np.array([10]).astype('int8') + + +class TestBitwiseRightShiftAPI_special_case3(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.array([0b11111111], dtype='uint8') + self.y = np.array([1], dtype='uint8') + + +class TestBitwiseRightShiftAPI_special_case4(TestBitwiseRightShiftAPI): + def init_input(self): + self.x = np.array([0b11111111], dtype='uint8') + self.y = np.array([10], dtype='uint8') + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 38fbac0357d6d..ce2526638e450 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -1641,6 +1641,108 @@ def test_forward_version(self): self.assertEqual(var.inplace_version, 3) +class TestDygraphInplaceBitwiseLeftShift_arithmetic(TestDygraphInplaceLogicAnd): + def init_data(self): + self.input_var_numpy = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.input_var_numpy = paddle.to_tensor(self.input_var_numpy) + self.dtype = "int32" + self.y = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.y = paddle.to_tensor(self.y) + self.is_arithmetic = True + + def inplace_api_processing(self, var): + return paddle.bitwise_left_shift_(var, self.y, self.is_arithmetic) + + def non_inplace_api_processing(self, var): + return paddle.bitwise_left_shift(var, self.y, self.is_arithmetic) + + def test_broadcast_error(self): + broadcast_input = paddle.randn([4, 5]) + with self.assertRaises(ValueError): + self.inplace_api_processing(broadcast_input) + + +class TestDygraphInplaceBitwiseRightShift_arithmetic( + TestDygraphInplaceLogicAnd +): + def init_data(self): + self.input_var_numpy = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.input_var_numpy = paddle.to_tensor(self.input_var_numpy) + self.dtype = "int32" + self.y = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.y = paddle.to_tensor(self.y) + self.is_arithmetic = True + + def inplace_api_processing(self, var): + return paddle.bitwise_right_shift_(var, self.y, self.is_arithmetic) + + def non_inplace_api_processing(self, var): + return paddle.bitwise_right_shift_(var, self.y, self.is_arithmetic) + + def test_broadcast_error(self): + broadcast_input = paddle.randn([4, 5]) + with self.assertRaises(ValueError): + self.inplace_api_processing(broadcast_input) + + +class TestDygraphInplaceBitwiseLeftShift_logic(TestDygraphInplaceLogicAnd): + def init_data(self): + self.input_var_numpy = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.input_var_numpy = paddle.to_tensor(self.input_var_numpy) + self.dtype = "int32" + self.y = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.y = paddle.to_tensor(self.y) + self.is_arithmetic = False + + def inplace_api_processing(self, var): + return paddle.bitwise_left_shift_(var, self.y, self.is_arithmetic) + + def non_inplace_api_processing(self, var): + return paddle.bitwise_left_shift(var, self.y, self.is_arithmetic) + + def test_broadcast_error(self): + broadcast_input = paddle.randn([4, 5]) + with self.assertRaises(ValueError): + self.inplace_api_processing(broadcast_input) + + +class TestDygraphInplaceBitwiseRightShift_logic(TestDygraphInplaceLogicAnd): + def init_data(self): + self.input_var_numpy = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.input_var_numpy = paddle.to_tensor(self.input_var_numpy) + self.dtype = "int32" + self.y = np.random.randint( + low=-(2**31), high=2**31, size=[3, 4, 5], dtype="int32" + ) + self.y = paddle.to_tensor(self.y) + self.is_arithmetic = False + + def inplace_api_processing(self, var): + return paddle.bitwise_right_shift_(var, self.y, self.is_arithmetic) + + def non_inplace_api_processing(self, var): + return paddle.bitwise_right_shift_(var, self.y, self.is_arithmetic) + + def test_broadcast_error(self): + broadcast_input = paddle.randn([4, 5]) + with self.assertRaises(ValueError): + self.inplace_api_processing(broadcast_input) + + class TestDygraphInplaceIndexFill(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.random.random((20, 40))