-
Notifications
You must be signed in to change notification settings - Fork 441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement ctc loss function #1049
Conversation
Hi, |
I believe the result is now the same as PyTorch. But the performance of this implementation seems to be less than ideal. |
The implementation doesn't work on the NdArray backend because of #1053. It also doesn't work on the LibTorch backend because of #1055. I believe the current performance bottleneck lies in creating the one-hot. This is because the burn/burn-core/src/nn/loss/ctc.rs Lines 330 to 341 in 7121852
|
Hi @wcshds |
@louisfd Thank you! The current implementation still significantly consumes graphics memory. I believe that separately calculating the alpha values for blanks and letters can significantly reduce the graphics memory usage, but I don't know how to implement it. |
@wcshds Please tell me if this is better now |
@louisfd Thank you! Now |
I tried to use this implementation of ctc loss in the CRNN model, but after the first iteration loss became NaN. I don't know what went wrong. wcshds/crnn-cjk |
Just noticed 1-e15 magic number. Please refactor to a constant and explain how this number is derived. It would also be preferable if float number precision independent (we use half and full precisions) |
@antimora I just need a small value to prevent log(0), so now I think it may not be necessary to use 1e-15; 1e-5 should be small enough. |
burn-core/src/nn/loss/ctc.rs
Outdated
@@ -7,6 +7,8 @@ use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Te | |||
use super::Reduction; | |||
|
|||
const NEG_INF: f32 = -1e5; | |||
// a small value used to prevent the occurrence of log(0) | |||
const DELTA: f32 = -1e-5; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean to have this number as negative? The literal number you has is positive. In an unlikely event, (l1 - m.clone()).exp() + (l2 - m.clone()).exp()
expression could be equal to abs(DELTA)
which would still lead to log(0) situation.
Additionally, I suggest we use [https://doc.rust-lang.org/std/primitive.f32.html#associatedconstant.EPSILON](f32's EPSILON) or [f16's EPSILON]https://docs.rs/tract-core/latest/tract_core/prelude/struct.f16.html#associatedconstant.EPSILON constants depending what on Backend's precision settings. @nathanielsimard or @louisfd can suggest on how we can extract this. -1e-5
seems a rather big number for f16 or f32. (probably it may not work for f16 because its epsilon is 4.88e-04
. we need to double check it)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm sorry, it's a typo. DELTA
should be positive.
1e-5
can ensure that the results of the loss are accurate to three decimal places, but 4.88e-4
is a bit large. Perhaps CTC Loss is indeed not suitable for the use of half-precision training.
Closing this ticket and linking to an issue ticket so someone else can pick up: #1536 |
I need ctc loss function in CRNN model. I tried to implement it based on PyTorch implementation, but the results obtained after calling forward() are somewhat different from PyTorch's.
I don't know what went wrong, I'd appreciate it if someone could tell me.
Reference
Checklist
run-checks all
script has been executed.