From 7071e9e40e3ac35bea499e2ac17b8521991580f3 Mon Sep 17 00:00:00 2001 From: guo-ran <360112263@qq.com> Date: Fri, 11 Feb 2022 14:37:45 +0800 Subject: [PATCH] fix bce loss half grad functor --- oneflow/user/kernels/binary_cross_entropy_kernel.cu | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/oneflow/user/kernels/binary_cross_entropy_kernel.cu b/oneflow/user/kernels/binary_cross_entropy_kernel.cu index fd1b434d3a5..2262858da20 100644 --- a/oneflow/user/kernels/binary_cross_entropy_kernel.cu +++ b/oneflow/user/kernels/binary_cross_entropy_kernel.cu @@ -90,17 +90,16 @@ struct BinaryCrossEntropyGradFunctor { template<> struct BinaryCrossEntropyGradFunctor { - half eps_; - half one_; - BinaryCrossEntropyGradFunctor() : eps_(__float2half(1e-12)), one_(__float2half(1.f)) {} + BinaryCrossEntropyGradFunctor float_functor; + BinaryCrossEntropyGradFunctor() {} __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const { - half divisor = (one_ - input_val) * input_val; - if (divisor < eps_) { divisor = eps_; } - return dy_val * (input_val - target_val) / divisor; + return __float2half( + float_functor(__half2float(input_val), __half2float(target_val), __half2float(dy_val))); } __device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val, half weight_val) const { - return (*this)(input_val, target_val, dy_val) * weight_val; + return __float2half(float_functor(__half2float(input_val), __half2float(target_val), + __half2float(dy_val), __half2float(weight_val))); } };