Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.34】为 Paddle 新增 bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ API -part #58092

Merged
merged 31 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
813c09d
test
cocoshe Dec 25, 2023
42b3972
fix
cocoshe Dec 25, 2023
198d875
fix
cocoshe Dec 25, 2023
6ab394d
test
cocoshe Dec 27, 2023
78d8fec
update
cocoshe Dec 27, 2023
5d627e1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Dec 27, 2023
994d832
update
cocoshe Dec 27, 2023
7587891
update
cocoshe Dec 27, 2023
aac04c8
test split op
cocoshe Dec 28, 2023
99aef13
codestyle
cocoshe Dec 28, 2023
d1e942d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Dec 28, 2023
af86d96
add register
cocoshe Dec 29, 2023
db05b52
split on win
cocoshe Dec 29, 2023
126e5a5
test bigger shape
cocoshe Dec 30, 2023
f5f51b8
add note
cocoshe Jan 2, 2024
0fd2774
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Jan 2, 2024
d28481d
fix
cocoshe Jan 2, 2024
899d1ff
Update bitwise_functors.h
cocoshe Jan 2, 2024
04c80b2
fix
cocoshe Jan 3, 2024
98b392d
Merge branch 'develop' into bitwise_shift_coco_dev
cocoshe Jan 5, 2024
60f0e92
fix doc
cocoshe Jan 5, 2024
3b1e7e4
fix doc
cocoshe Jan 5, 2024
d21eea0
refactor
cocoshe Jan 8, 2024
28e0e6a
enhence doc
cocoshe Jan 8, 2024
6214243
fix doc
cocoshe Jan 8, 2024
9d50a52
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
c74106d
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
fb5ad3c
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
b4e8edd
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
0ecabc8
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
809fb46
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,16 @@
backend : x
inplace: (x -> out)

- op : bitwise_left_shift
args : (Tensor x, Tensor y, bool is_arithmetic = true)
output : Tensor(out)
infer_meta :
func : BitwiseShiftInferMeta
kernel :
func : bitwise_left_shift
backend : x
inplace: (x -> out)

- op : bitwise_not
args : (Tensor x)
output : Tensor(out)
Expand All @@ -364,6 +374,16 @@
backend : x
inplace: (x -> out)

- op : bitwise_right_shift
args : (Tensor x, Tensor y, bool is_arithmetic = true)
output : Tensor(out)
infer_meta :
func : BitwiseShiftInferMeta
kernel :
func : bitwise_right_shift
backend : x
inplace: (x -> out)

- op : bitwise_xor
args : (Tensor x, Tensor y)
output : Tensor(out)
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,13 @@ void ElementwiseInferMeta(const MetaTensor& x,
return ElementwiseRawInferMeta(x, y, -1, out);
}

void BitwiseShiftInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool is_arithmetic,
MetaTensor* out) {
return ElementwiseRawInferMeta(x, y, -1, out);
}

void ElementwiseRawInferMeta(const MetaTensor& x,
const MetaTensor& y,
int axis,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta,
MetaTensor* out,
MetaConfig config = MetaConfig());

void BitwiseShiftInferMeta(const MetaTensor& x,
const MetaTensor& y,
bool is_arithmetic,
MetaTensor* out);

void EmbeddingInferMeta(const MetaTensor& x,
const MetaTensor& weight,
int64_t padding_idx,
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/kernels/bitwise_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,18 @@ void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseLeftShiftKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool is_arithmetic,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseRightShiftKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool is_arithmetic,
DenseTensor* out);

} // namespace phi
59 changes: 59 additions & 0 deletions paddle/phi/kernels/cpu/bitwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,45 @@ DEFINE_BITWISE_KERNEL(Or)
DEFINE_BITWISE_KERNEL(Xor)
#undef DEFINE_BITWISE_KERNEL

#define DEFINE_BITWISE_KERNEL_WITH_INVERSE(op_type) \
template <typename T, typename Context> \
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
bool is_arithmetic, \
DenseTensor* out) { \
auto x_dims = x.dims(); \
auto y_dims = y.dims(); \
if (x_dims.size() >= y_dims.size()) { \
if (is_arithmetic) { \
funcs::Bitwise##op_type##ArithmeticFunctor<T> func; \
funcs::ElementwiseCompute< \
funcs::Bitwise##op_type##ArithmeticFunctor<T>, \
T>(dev_ctx, x, y, func, out); \
} else { \
funcs::Bitwise##op_type##LogicFunctor<T> func; \
funcs::ElementwiseCompute<funcs::Bitwise##op_type##LogicFunctor<T>, \
T>(dev_ctx, x, y, func, out); \
} \
} else { \
if (is_arithmetic) { \
funcs::InverseBitwise##op_type##ArithmeticFunctor<T> inv_func; \
funcs::ElementwiseCompute< \
funcs::InverseBitwise##op_type##ArithmeticFunctor<T>, \
T>(dev_ctx, x, y, inv_func, out); \
} else { \
funcs::InverseBitwise##op_type##LogicFunctor<T> inv_func; \
funcs::ElementwiseCompute< \
funcs::InverseBitwise##op_type##LogicFunctor<T>, \
T>(dev_ctx, x, y, inv_func, out); \
} \
} \
}

DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShift)
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShift)
#undef DEFINE_BITWISE_KERNEL_WITH_INVERSE

template <typename T, typename Context>
void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -97,3 +136,23 @@ PD_REGISTER_KERNEL(bitwise_not,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift,
CPU,
ALL_LAYOUT,
phi::BitwiseLeftShiftKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift,
CPU,
ALL_LAYOUT,
phi::BitwiseRightShiftKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
159 changes: 159 additions & 0 deletions paddle/phi/kernels/funcs/bitwise_functors.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,164 @@ struct BitwiseNotFunctor<bool> {
HOSTDEVICE bool operator()(const bool a) const { return !a; }
};

template <typename T>
struct BitwiseLeftShiftArithmeticFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b >= static_cast<T>(sizeof(T) * 8)) return static_cast<T>(0);
if (b < static_cast<T>(0)) return static_cast<T>(0);
return a << b;
}
};

template <typename T>
struct InverseBitwiseLeftShiftArithmeticFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a >= static_cast<T>(sizeof(T) * 8)) return static_cast<T>(0);
if (a < static_cast<T>(0)) return static_cast<T>(0);
return b << a;
}
};

template <typename T>
struct BitwiseLeftShiftLogicFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(0);
return a << b;
}
};

template <typename T>
struct InverseBitwiseLeftShiftLogicFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a < static_cast<T>(0) || a >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(0);
return b << a;
}
};

template <typename T>
struct BitwiseRightShiftArithmeticFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(-(a >> (sizeof(T) * 8 - 1) & 1));
return a >> b;
}
};

template <typename T>
struct InverseBitwiseRightShiftArithmeticFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a < static_cast<T>(0) || a >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(-(b >> (sizeof(T) * 8 - 1) & 1));
return b >> a;
}
};

template <>
struct BitwiseRightShiftArithmeticFunctor<uint8_t> {
HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const {
if (b >= static_cast<uint8_t>(sizeof(uint8_t) * 8))
return static_cast<uint8_t>(0);
return a >> b;
}
};

template <>
struct InverseBitwiseRightShiftArithmeticFunctor<uint8_t> {
inline HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const {
if (a >= static_cast<uint8_t>(sizeof(uint8_t) * 8))
return static_cast<uint8_t>(0);
return b >> a;
}
};

template <typename T>
struct BitwiseRightShiftLogicFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b >= static_cast<T>(sizeof(T) * 8) || b < static_cast<T>(0))
return static_cast<T>(0);
return a >> b;
}
};

template <typename T>
struct InverseBitwiseRightShiftLogicFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a >= static_cast<T>(sizeof(T) * 8) || a < static_cast<T>(0))
return static_cast<T>(0);
return b >> a;
}
};

template <typename T>
HOSTDEVICE T logic_shift_func(const T a, const T b) {
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(0);
T t = static_cast<T>(sizeof(T) * 8 - 1);
T mask = (((a >> t) << t) >> b) << 1;
return (a >> b) ^ mask;
}

// signed int8
template <>
struct BitwiseRightShiftLogicFunctor<int8_t> {
HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const {
return logic_shift_func<int8_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int8_t> {
inline HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const {
return logic_shift_func<int8_t>(b, a);
}
};

// signed int16
template <>
struct BitwiseRightShiftLogicFunctor<int16_t> {
HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const {
return logic_shift_func<int16_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int16_t> {
inline HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const {
return logic_shift_func<int16_t>(b, a);
}
};

// signed int32
template <>
struct BitwiseRightShiftLogicFunctor<int> {
HOSTDEVICE int operator()(const int a, const int b) const {
return logic_shift_func<int32_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
return logic_shift_func<int32_t>(b, a);
}
};

// signed int64
template <>
struct BitwiseRightShiftLogicFunctor<int64_t> {
HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
return logic_shift_func<int64_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int64_t> {
inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
return logic_shift_func<int64_t>(b, a);
}
};

} // namespace funcs
} // namespace phi
43 changes: 43 additions & 0 deletions paddle/phi/kernels/kps/bitwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,29 @@ DEFINE_BITWISE_KERNEL(Or)
DEFINE_BITWISE_KERNEL(Xor)
#undef DEFINE_BITWISE_KERNEL

#define DEFINE_BITWISE_KERNEL_WITH_BOOL(op_type) \
template <typename T, typename Context> \
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
bool is_arithmetic, \
DenseTensor* out) { \
dev_ctx.template Alloc<T>(out); \
std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \
if (is_arithmetic) { \
funcs::Bitwise##op_type##ArithmeticFunctor<T> func; \
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
} else { \
funcs::Bitwise##op_type##LogicFunctor<T> func; \
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
} \
}

DEFINE_BITWISE_KERNEL_WITH_BOOL(LeftShift)
DEFINE_BITWISE_KERNEL_WITH_BOOL(RightShift)
#undef DEFINE_BITWISE_KERNEL_WITH_BOOL

template <typename T, typename Context>
void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -112,4 +135,24 @@ PD_REGISTER_KERNEL(bitwise_not,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift,
KPS,
ALL_LAYOUT,
phi::BitwiseLeftShiftKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift,
KPS,
ALL_LAYOUT,
phi::BitwiseRightShiftKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

#endif
8 changes: 8 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@
atan_,
atanh,
atanh_,
bitwise_left_shift,
bitwise_left_shift_,
bitwise_right_shift,
bitwise_right_shift_,
broadcast_shape,
ceil,
clip,
Expand Down Expand Up @@ -944,6 +948,10 @@
'i1e',
'polygamma',
'polygamma_',
'bitwise_left_shift',
'bitwise_left_shift_',
'bitwise_right_shift',
'bitwise_right_shift_',
'masked_fill',
'masked_fill_',
'masked_scatter',
Expand Down
Loading