-
Notifications
You must be signed in to change notification settings - Fork 690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix bce loss half grad functor #7476
Changes from 2 commits
7071e9e
6c5057b
5907f79
ff2a40a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,17 +90,16 @@ struct BinaryCrossEntropyGradFunctor { | |
|
||
template<> | ||
struct BinaryCrossEntropyGradFunctor<half> { | ||
half eps_; | ||
half one_; | ||
BinaryCrossEntropyGradFunctor() : eps_(__float2half(1e-12)), one_(__float2half(1.f)) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里的 1e-12 超出了 half 的表示范围（小于 half 能表示的最小正数，约 6e-8），转换后 eps_ 会变成 0 |
||
BinaryCrossEntropyGradFunctor<float> float_functor; | ||
BinaryCrossEntropyGradFunctor() {} | ||
__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const { | ||
half divisor = (one_ - input_val) * input_val; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 当 divisor 非常小时，在 half 精度下可能会下溢为 0 |
||
if (divisor < eps_) { divisor = eps_; } | ||
return dy_val * (input_val - target_val) / divisor; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 当 divisor 为 0 时，除法结果会是 inf；若分子同时为 0，则会出现 nan。 |
||
return __float2half( | ||
float_functor(__half2float(input_val), __half2float(target_val), __half2float(dy_val))); | ||
} | ||
__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val, | ||
half weight_val) const { | ||
return (*this)(input_val, target_val, dy_val) * weight_val; | ||
return __float2half(float_functor(__half2float(input_val), __half2float(target_val), | ||
__half2float(dy_val), __half2float(weight_val))); | ||
} | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BinaryCrossEntropyGradFunctor 的 half 特化应在内部先转成 float 进行计算，再将结果转换回 half