Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bce loss half grad functor #7476

Merged
merged 4 commits into from
Feb 11, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions oneflow/user/kernels/binary_cross_entropy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,16 @@ struct BinaryCrossEntropyGradFunctor {

template<>
struct BinaryCrossEntropyGradFunctor<half> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BinaryCrossEntropyGradFunctor的half的内部计算应该转成float进行计算

half eps_;
half one_;
BinaryCrossEntropyGradFunctor() : eps_(__float2half(1e-12)), one_(__float2half(1.f)) {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里1e-12超出了half的表示范围

BinaryCrossEntropyGradFunctor<float> float_functor;
BinaryCrossEntropyGradFunctor() {}
__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val) const {
half divisor = (one_ - input_val) * input_val;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当divisor非常小时可能会出现0

if (divisor < eps_) { divisor = eps_; }
return dy_val * (input_val - target_val) / divisor;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当divisor为0时会出现nan。

return __float2half(
float_functor(__half2float(input_val), __half2float(target_val), __half2float(dy_val)));
}
__device__ __forceinline__ half operator()(half input_val, half target_val, half dy_val,
half weight_val) const {
return (*this)(input_val, target_val, dy_val) * weight_val;
return __float2half(float_functor(__half2float(input_val), __half2float(target_val),
__half2float(dy_val), __half2float(weight_val)));
}
};

Expand Down