Skip to content

Commit

Permalink
Cherry pick the support of bfloat16 for several operators. (#52608)
Browse files Browse the repository at this point in the history
* Register exp/expm1/logit bf16 activation op kernels (#48702)

* register more bf16 ops

* update to register coresponding backward ops

* Addition of bf16 type support for Compare OP  (#46413)

* first commit

* clarify the quotes

* change code style format

* support bfloat16

* add bfloat16 support for more ops (#48272)

* [Bfloat16]register bfloat16 datatype for squared l2 norm (#50908)

* Sync the pull request #51903.

* Add some header files back.

* modify cmake file for cuda11.8 compile (#49020)

* modify cmake file for cuda11.8 compile

* add op_library(fused_embedding_eltwise_layernorm_op DEPS bert_encoder_functor)

* Fix compling error.

* Cherry-pick pull request #51396.

---------

Co-authored-by: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Co-authored-by: limingshu <61349199+JamesLim-sy@users.noreply.github.com>
Co-authored-by: Shaojie WANG <wsjmessi@163.com>
Co-authored-by: zqw_1997 <118182234+zhengqiwen1997@users.noreply.github.com>
  • Loading branch information
5 people authored Apr 9, 2023
1 parent 73473ac commit 95c3d61
Show file tree
Hide file tree
Showing 19 changed files with 138 additions and 32 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ __device__ __inline__ void load_data_upper_tri(plat::float16* dst,
*(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
}

__device__ __inline__ void load_data_upper_tri(plat::bfloat16* dst,
const plat::bfloat16* src) {
*(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
}

__device__ __inline__ void load_data_upper_tri(float* dst, const float* src) {
*(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
}
Expand All @@ -76,6 +81,10 @@ __device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) {
*(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
}

__device__ __inline__ void load_zero_vector_upper_tri(plat::bfloat16* dst) {
*(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
}

__device__ __inline__ void load_zero_vector_upper_tri(float* dst) {
*(reinterpret_cast<float4*>(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
}
Expand Down Expand Up @@ -595,8 +604,11 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_upper_triangle,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::bfloat16>,
ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL(
fused_softmax_mask_upper_triangle_grad,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext,
plat::bfloat16>,
ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>);
5 changes: 5 additions & 0 deletions paddle/fluid/operators/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "math.h" // NOLINT
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/hostdevice.h"

Expand All @@ -33,6 +34,10 @@ inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
return static_cast<platform::float16>(::logf(static_cast<float>(x)));
}

inline HOSTDEVICE phi::dtype::bfloat16 real_log(phi::dtype::bfloat16 x) {
return static_cast<phi::dtype::bfloat16>(::logf(static_cast<float>(x)));
}

inline HOSTDEVICE float real_log(float x) { return ::logf(x); }

inline HOSTDEVICE double real_log(double x) { return ::log(x); }
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/operators/math/cross_entropy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/cross_entropy.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

Expand Down Expand Up @@ -152,7 +154,10 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(

template class CrossEntropyFunctor<phi::GPUContext, float>;
template class CrossEntropyFunctor<phi::GPUContext, double>;
template class CrossEntropyFunctor<phi::GPUContext, platform::float16>;
template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::float16>;
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(8, 1, 0)
template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::bfloat16>;
#endif

} // namespace math
} // namespace operators
Expand Down
33 changes: 25 additions & 8 deletions paddle/fluid/operators/math/cross_entropy.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/hostdevice.h"

namespace paddle {
Expand Down Expand Up @@ -46,14 +47,30 @@ struct TolerableValue {
// Also. In standard implementation of cross entropy, other
// framework not has the ValueClipping.
template <>
struct TolerableValue<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& x) const {
if (platform::isfinite(x))
struct TolerableValue<phi::dtype::float16> {
HOSTDEVICE phi::dtype::float16 operator()(
const phi::dtype::float16& x) const {
if (phi::dtype::isfinite(x)) {
return x;
else if (x > static_cast<platform::float16>(0))
return std::numeric_limits<platform::float16>::max();
else
return std::numeric_limits<platform::float16>::min();
} else if (x > static_cast<phi::dtype::float16>(0)) {
return std::numeric_limits<phi::dtype::float16>::max();
} else {
return std::numeric_limits<phi::dtype::float16>::min();
}
}
};

template <>
struct TolerableValue<phi::dtype::bfloat16> {
HOSTDEVICE phi::dtype::bfloat16 operator()(
const phi::dtype::bfloat16& x) const {
if (phi::dtype::isfinite(x)) {
return x;
} else if (x > static_cast<phi::dtype::bfloat16>(0)) {
return std::numeric_limits<phi::dtype::bfloat16>::max();
} else {
return std::numeric_limits<phi::dtype::bfloat16>::min();
}
}
};

Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/full_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ PD_REGISTER_KERNEL(full_batch_size_like,
int,
int64_t,
bool,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
#endif
9 changes: 6 additions & 3 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ PD_REGISTER_KERNEL(exp_grad,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
Expand All @@ -385,15 +386,17 @@ PD_REGISTER_KERNEL(expm1_grad,
phi::Expm1GradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(logit_grad,
GPU,
ALL_LAYOUT,
phi::LogitGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(square_grad,
GPU,
Expand Down
9 changes: 6 additions & 3 deletions paddle/phi/kernels/gpu/activation_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -212,21 +212,24 @@ PD_REGISTER_KERNEL(exp,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(expm1,
GPU,
ALL_LAYOUT,
phi::Expm1Kernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(logit,
GPU,
ALL_LAYOUT,
phi::LogitKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square,
GPU,
ALL_LAYOUT,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/arg_min_max_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ PD_REGISTER_KERNEL(arg_min,
ALL_LAYOUT,
phi::ArgMinKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int32_t,
Expand All @@ -267,6 +268,7 @@ PD_REGISTER_KERNEL(arg_max,
ALL_LAYOUT,
phi::ArgMaxKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int32_t,
Expand Down
21 changes: 21 additions & 0 deletions paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,31 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,

} // namespace phi

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
#endif
17 changes: 14 additions & 3 deletions paddle/phi/kernels/gpu/cross_entropy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
input -= offset;
size += offset;
if (tid >= offset) {
val = reducer(val, input[tid]);
val = reducer(val, static_cast<AccT>(input[tid]));
}
size -= blockDim.x;
input += blockDim.x;
Expand All @@ -268,14 +268,14 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,

#pragma unroll
for (int i = 0; i < VecSize; ++i) {
val = reducer(val, ins[i]);
val = reducer(val, static_cast<AccT>(ins[i]));
}
}

// scalar part
tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) {
val = reducer(val, input[tid]);
val = reducer(val, static_cast<AccT>(input[tid]));
}
return val;
}
Expand Down Expand Up @@ -1470,6 +1470,16 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
float,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(cross_entropy_with_softmax,
GPU,
ALL_LAYOUT,
Expand All @@ -1478,3 +1488,4 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
double,
phi::dtype::float16) {}
#endif
#endif
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/phi/kernels/gather_nd_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
Expand Down Expand Up @@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad,
double,
int64_t,
int,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/gather_nd_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/phi/kernels/gather_nd_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
Expand Down Expand Up @@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd,
int,
int16_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ PD_REGISTER_KERNEL(index_sample_grad,
GPU,
ALL_LAYOUT,
phi::IndexSampleGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/index_sample_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ PD_REGISTER_KERNEL(index_sample,
GPU,
ALL_LAYOUT,
phi::IndexSampleKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu_grad,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/tril_triu_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
Loading

0 comments on commit 95c3d61

Please sign in to comment.