
implement ctc loss function #1049

Closed (wants to merge 20 commits)

Conversation

wcshds (Contributor) commented Dec 4, 2023

I need a CTC loss function for a CRNN model. I tried to implement it based on the PyTorch implementation, but the results obtained after calling forward() are somewhat different from PyTorch's.

I don't know what went wrong; I'd appreciate it if someone could tell me.

Reference

Checklist

  • Confirmed that run-checks all script has been executed.

louisfd (Member) commented Dec 4, 2023

Hi,
I can take a look at it later today

wcshds (Contributor, Author) commented Dec 5, 2023

I believe the result now matches PyTorch's, but the performance of this implementation seems less than ideal.

wcshds (Contributor, Author) commented Dec 8, 2023

The implementation doesn't work on the NdArray backend because of #1053. It also doesn't work on the LibTorch backend because of #1055.

I believe the current performance bottleneck lies in creating the one-hot. This is because the repeat() method is very slow on the Wgpu backend.

fn one_hot<B: Backend>(tensor: Tensor<B, 2, Int>, num_classes: usize) -> Tensor<B, 3> {
    let device = tensor.device();
    let shape = tensor.dims();
    // [batch, seq] -> [batch, seq, num_classes]: repeat each label across the class axis.
    let labels: Tensor<B, 3, Int> = tensor.unsqueeze_dim(2).repeat(2, num_classes);
    // Class indices 0..num_classes, expanded to the same [batch, seq, num_classes] shape.
    let indices = Tensor::<B, 1, Int>::arange_device(0..num_classes, &device)
        .reshape([1, 1, num_classes])
        .repeat(1, shape[1])
        .repeat(0, shape[0]);
    // 1.0 where the label matches the class index, 0.0 elsewhere.
    labels.equal(indices).float()
}
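
For reference, here is a framework-free sketch of what these tensor ops compute (the one_hot_reference helper is hypothetical, not code from this PR); it makes the cost explicit: batch * seq * num_classes values are materialized no matter how the intermediate tensors are built.

// Framework-free reference: the same [batch, seq, num_classes] one-hot
// encoding built with plain loops.
fn one_hot_reference(labels: &[Vec<usize>], num_classes: usize) -> Vec<Vec<Vec<f32>>> {
    labels
        .iter()
        .map(|seq| {
            seq.iter()
                .map(|&label| {
                    // 1.0 at the label's class index, 0.0 everywhere else.
                    let mut row = vec![0.0_f32; num_classes];
                    if label < num_classes {
                        row[label] = 1.0;
                    }
                    row
                })
                .collect()
        })
        .collect()
}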

louisfd (Member) commented Dec 12, 2023

Hi @wcshds
I didn't have the time I thought I would last week, and then I was abroad for several days. I'm sorry I said I was going to look at it last week, but I certainly haven't forgotten you! Glad to see you've continued working on it since then. I will definitely take a look very soon.

wcshds (Contributor, Author) commented Dec 12, 2023

@louisfd Thank you! The current implementation still consumes a lot of GPU memory. I believe that calculating the alpha values for blanks and letters separately could significantly reduce the memory usage, but I don't know how to implement it.
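
To make the structure concrete, below is a minimal, framework-free sketch of the standard CTC forward (alpha) recursion in log space for a single sequence (the ctc_alpha function and its signature are hypothetical, not code from this PR, and it assumes log_probs holds log-softmax outputs with at least one time step). In the extended target blank, l1, blank, l2, ..., lN, blank, every other slot is a blank, which is the regularity a separate blank/letter alpha computation would exploit.

// Hypothetical, framework-free sketch of the CTC forward (alpha) recursion in
// log space for one sequence. log_probs[t][c] holds log-softmax outputs,
// targets the label indices, blank the blank class index.
fn ctc_alpha(log_probs: &[Vec<f32>], targets: &[usize], blank: usize) -> f32 {
    const NEG_INF: f32 = f32::NEG_INFINITY;

    // Extended target: blank, l1, blank, l2, ..., lN, blank (blanks at even indices).
    let ext: Vec<usize> = std::iter::once(blank)
        .chain(targets.iter().flat_map(|&l| [l, blank]))
        .collect();
    let s = ext.len();

    // Numerically stable log(exp(a) + exp(b)).
    let logsumexp = |a: f32, b: f32| -> f32 {
        if a == NEG_INF {
            return b;
        }
        if b == NEG_INF {
            return a;
        }
        let m = a.max(b);
        m + ((a - m).exp() + (b - m).exp()).ln()
    };

    // t = 0: paths may start with the leading blank or the first label.
    let mut alpha = vec![NEG_INF; s];
    alpha[0] = log_probs[0][ext[0]];
    if s > 1 {
        alpha[1] = log_probs[0][ext[1]];
    }

    for t in 1..log_probs.len() {
        let prev = alpha.clone();
        for i in 0..s {
            let mut a = prev[i];
            if i >= 1 {
                a = logsumexp(a, prev[i - 1]);
            }
            // Skipping over a blank is only allowed between two distinct labels.
            if i >= 2 && ext[i] != blank && ext[i] != ext[i - 2] {
                a = logsumexp(a, prev[i - 2]);
            }
            alpha[i] = a + log_probs[t][ext[i]];
        }
    }

    // Valid paths end on the final blank or the final label; the loss is the
    // negative log of their total probability.
    let total = if s > 1 {
        logsumexp(alpha[s - 1], alpha[s - 2])
    } else {
        alpha[s - 1]
    };
    -total
}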

louisfd mentioned this pull request Dec 14, 2023
louisfd (Member) commented Dec 14, 2023

@wcshds
I took your word that repeat was the bottleneck in wgpu. That made a lot of sense, because we relied on the default implementation, which launches as many slice_assign kernels as there are repetitions; for a large times argument this is awful.
I wrote a repeat kernel so that only one kernel is launched instead of times: #1068

Please tell me if this is better now
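
Roughly, the speedup comes from filling the whole output in one pass, with each output element reading its source through a modulo on the repeated dimension, instead of issuing one slice_assign per copy. A framework-free illustration of that index mapping (the repeat_dim1 helper is hypothetical and is not the actual #1068 kernel):

// Tiling a [rows, cols] matrix `times` times along dim 1 in one pass:
// output column j simply reads input column j % cols.
fn repeat_dim1(input: &[Vec<f32>], times: usize) -> Vec<Vec<f32>> {
    let cols = input.first().map_or(0, |row| row.len());
    input
        .iter()
        .map(|row| (0..cols * times).map(|j| row[j % cols]).collect())
        .collect()
}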

wcshds (Contributor, Author) commented Dec 14, 2023

@louisfd Thank you! Now repeat() is much faster.

wcshds (Contributor, Author) commented Dec 14, 2023

I tried to use this implementation of CTC loss in the CRNN model, but after the first iteration the loss became NaN. I don't know what went wrong. wcshds/crnn-cjk

antimora (Collaborator) commented

Just noticed the 1e-15 magic number. Please refactor it into a constant and explain how this number is derived. It would also be preferable if it were independent of the floating-point precision (we use both half and full precision).

wcshds (Contributor, Author) commented Dec 27, 2023

@antimora I just need a small value to prevent log(0), so I now think it may not be necessary to use 1e-15; 1e-5 should be small enough.
However, I think CTC loss may not be suitable for half precision: I previously attempted to use mixed-precision training in PyTorch, but PyTorch's CTC loss does not support fp16 ([CTC Loss] CTC Loss not support float16?). Perhaps I need to explore half-precision training in future work to see whether CTC loss can work with it.

@@ -7,6 +7,8 @@ use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Te
use super::Reduction;

const NEG_INF: f32 = -1e5;
// a small value used to prevent the occurrence of log(0)
const DELTA: f32 = -1e-5;
Collaborator commented:

Did you mean to have this number as negative? The value needed here is positive: in the unlikely event that the (l1 - m.clone()).exp() + (l2 - m.clone()).exp() expression equals abs(DELTA), a negative DELTA would still lead to a log(0) situation.

Additionally, I suggest we use the [f32 EPSILON](https://doc.rust-lang.org/std/primitive.f32.html#associatedconstant.EPSILON) or [f16 EPSILON](https://docs.rs/tract-core/latest/tract_core/prelude/struct.f16.html#associatedconstant.EPSILON) constants, depending on the Backend's precision settings. @nathanielsimard or @louisfd can suggest how we can extract this. -1e-5 seems a rather big number for f16 or f32. (It probably won't work for f16 because its epsilon is 4.88e-04; we need to double-check this.)

Contributor (Author) replied:

I'm sorry, it's a typo. DELTA should be positive.

1e-5 can ensure that the loss results are accurate to three decimal places, but 4.88e-4 is a bit too large. Perhaps CTC loss is indeed not suitable for half-precision training.
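
As a scalar illustration of this trade-off (plain f32, not the PR's tensor code): the guard keeps the logarithm finite when its argument underflows to zero, but any value below DELTA is flattened to roughly ln(DELTA), so DELTA both bounds the attainable accuracy and has to stay representable at the working precision.

const DELTA: f32 = 1e-5;

// Guarded logarithm: finite for x == 0.0 instead of -inf.
fn safe_ln(x: f32) -> f32 {
    (x + DELTA).ln()
}

fn main() {
    println!("{}", safe_ln(0.0));  // ~ -11.51 rather than -inf
    println!("{}", safe_ln(0.25)); // ~ ln(0.25); the error is on the order of DELTA
    // f32's epsilon is ~1.19e-7, so a 1e-5 guard is comfortably representable;
    // at half precision the epsilon quoted above (~4.88e-4) is larger than the
    // guard itself, which is why fp16 is problematic here.
    println!("{:e}", f32::EPSILON);
}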

wcshds marked this pull request as draft January 12, 2024
antimora added the feature label Jan 31, 2024
antimora added the stale label Feb 24, 2024
antimora (Collaborator) commented

Closing this ticket and linking it to an issue so someone else can pick it up: #1536

antimora closed this Mar 26, 2024