Fixing torch.ctc err #1485

Merged 2 commits from ctc_error into k2-fsa:master on Feb 2, 2024

Conversation

@teowenshen (Contributor) commented Feb 1, 2024

F.ctc_loss throws an error when the batch size N is 1. Using torch.long, as recommended by the documentation, fixes this error.

import torch

N = 1   # batch size
T = 50  # input (time) length
S = 31  # target length
V = 20  # number of classes, including the blank

log_probs = torch.randn(T, N, V, device="cuda:0").log_softmax(2)
targets = torch.randint(1, 20, (S,), device="cuda:0", dtype=torch.int32)
input_lengths = torch.full((N,), T, device="cuda:0", dtype=torch.int32)
target_lengths = torch.tensor([S,], device="cuda:0", dtype=torch.int32)

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

throws:

RuntimeError: Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for cudnn_ctc_loss)
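For reference, a minimal sketch of the fix described above (not the exact PR diff), using the same repro shapes: switching the integer tensors to torch.long means the cudnn kernel, which expects int32 targets on the CPU, is not selected, so the native CUDA implementation handles targets on the GPU without error.

# Sketch only: same repro as above, but with torch.long integer tensors so
# that PyTorch uses its regular (non-cudnn) CTC implementation.
import torch

N, T, S, V = 1, 50, 31, 20
log_probs = torch.randn(T, N, V, device="cuda:0").log_softmax(2)
targets = torch.randint(1, V, (S,), device="cuda:0", dtype=torch.long)
input_lengths = torch.full((N,), T, device="cuda:0", dtype=torch.long)
target_lengths = torch.tensor([S], device="cuda:0", dtype=torch.long)

loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss)  # runs without the cudnn backend error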

@csukuangfj (Collaborator)

The error comes from the following line:
https://github.com/pytorch/pytorch/blob/2b48891e62e5c4b57c8cac92cee5eb71228a203a/aten/src/ATen/native/cudnn/LossCTC.cpp#L146

  checkBackend(c, {*targets}, Backend::CPU);

You can see that it expects targets to be on the CPU.


As for the lengths, see https://github.com/pytorch/pytorch/blob/2b48891e62e5c4b57c8cac92cee5eb71228a203a/aten/src/ATen/native/cudnn/LossCTC.cpp#L221:

  Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
  Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();

They are converted to kLong and moved to the CPU internally.

So the only thing that needs to change is to move targets to the CPU.
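A hedged sketch of that suggestion, applied to the repro from the PR description (again, not the actual PR diff): keep targets as int32 but create it on the CPU, so the checkBackend call quoted above is satisfied. The length tensors can stay on the GPU since, per the snippet above, they are copied to the CPU as kLong internally; this repro also meets the other cudnn requirements (blank=0, all input lengths equal to T), so the cudnn kernel can still be used.

# Sketch only: move `targets` to the CPU so checkBackend() in cudnn's
# LossCTC.cpp passes; everything else is unchanged from the repro.
import torch

N, T, S, V = 1, 50, 31, 20
log_probs = torch.randn(T, N, V, device="cuda:0").log_softmax(2)
targets = torch.randint(1, V, (S,), dtype=torch.int32)  # CPU, int32 as cudnn expects
input_lengths = torch.full((N,), T, device="cuda:0", dtype=torch.int32)
target_lengths = torch.tensor([S], device="cuda:0", dtype=torch.int32)

loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)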

@yfyeung (Collaborator) commented Feb 1, 2024

Disabling cudnn also works for me:

# Fall back to PyTorch's native CTC implementation by disabling cudnn.
with torch.backends.cudnn.flags(enabled=False):
    ctc_loss = torch.nn.functional.ctc_loss(
        log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
        targets=targets,
        input_lengths=encoder_out_lens,
        target_lengths=target_lengths,
        reduction="sum",
    )

@teowenshen (Contributor, Author)

I have followed @csukuangfj's advice and moved the targets and lengths to the CPU. There were some comments saying that cudnn is the only deterministic CTC implementation in PyTorch.

Surprisingly, in my case the error wasn't thrown until the last batch of the validation loop, which had only 1 sample. But it doesn't seem to be related to the batch size after all.

@csukuangfj (Collaborator)

Thanks!

@csukuangfj merged commit b9e6327 into k2-fsa:master on Feb 2, 2024.
53 of 54 checks passed.
@teowenshen deleted the ctc_error branch on February 8, 2024.