
make sure inputs live on CPU for ctc decoder #2289

Closed

Conversation

@xiaohui-zhang (Contributor) commented Mar 24, 2022

Addressing issue #2274:
Raise a RuntimeError when the input tensors to the CTC decoder are GPU tensors, since the CTC decoder only runs on CPU. Also update the data type check to use "raise" rather than "assert".


Pull Request resolved: #2289
GitHub Author: xiaohui-zhang xiaohuizhang@fb.com
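
For context, a hypothetical user-side sketch of the behavior this PR targets. The `ctc_decoder` factory and import path follow later torchaudio releases, and the lexicon/token files are placeholders, not the repro from #2274:

import torch
from torchaudio.models.decoder import ctc_decoder  # import path varies by torchaudio version

decoder = ctc_decoder(lexicon="lexicon.txt", tokens="tokens.txt")  # placeholder files

emissions = torch.rand(1, 100, 32, device="cuda")  # acoustic-model output on GPU
# The decoder runs on CPU only; after this PR, passing a GPU tensor raises a
# RuntimeError up front instead of failing deep inside the decoder.
hypotheses = decoder(emissions.cpu())  # moving inputs to CPU is the caller-side fix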

B, T, N = emissions.size()
if lengths is None:
    lengths = torch.full((B,), T)

assert not emissions.is_cuda
assert not lengths.is_cuda
Collaborator:

The logic itself looks good. A couple of suggestions (a sketch combining all three follows the list).

  1. For better UX, could you provide an error message that tells users what to do (as opposed to what was wrong)?
  2. Please perform the input validation as early as possible, before any other operation. Otherwise the work done before validation could be wasteful.
  3. This is not a written rule, but assert is more for internal assertions (although it is used in L140, which I guess was missed at review time). So can you replace it with if <condition>: raise <something>?
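
A minimal sketch of what the three suggestions amount to together; the error wording here is illustrative, not the final text of the PR:

def __call__(self, emissions, lengths=None):
    # (2) validate first, before any other work touches the inputs
    # (3) plain `if ...: raise`, reserving `assert` for internal invariants
    if emissions.dtype != torch.float32:
        raise ValueError("emissions must be float32.")
    # (1) actionable message: tell the user what to do, not just what was wrong
    if emissions.is_cuda:
        raise RuntimeError("emissions must be a CPU tensor; call emissions.cpu() first.")
    if lengths is not None and lengths.is_cuda:
        raise RuntimeError("lengths must be a CPU tensor; call lengths.cpu() first.")

    B, T, N = emissions.size()
    if lengths is None:
        lengths = torch.full((B,), T)
    ...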

@@ -129,20 +129,24 @@ def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor]

Args:
emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model
probability distribution over labels; output of acoustic model. It must lives on CPU.
Contributor:

Suggested change
probability distribution over labels; output of acoustic model. It must lives on CPU.
probability distribution over labels; output of acoustic model. It must live on CPU.

lengths (Tensor or None, optional): tensor of shape `(batch, )` storing the valid length of
in time axis of the output Tensor in each batch
in time axis of the output Tensor in each batch. It must lives on CPU.
Contributor:

Suggested change
in time axis of the output Tensor in each batch. It must lives on CPU.
in time axis of the output Tensor in each batch. It must live on CPU.

@xiaohui-zhang (Contributor, Author) commented:

Thanks @carolineechen and @mthrok for the thorough review of my first PR!

    raise ValueError('emissions must be float32.')

if emissions.is_cuda:
    raise Exception('emissions must live on CPU.')
Contributor:

It'd be better to raise a more specific exception, e.g. RuntimeError at minimum.
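
For illustration, the practical difference on the caller side (hypothetical `decoder` and `emissions`, not part of this diff): a specific exception type can be handled without swallowing unrelated failures, which a bare `Exception` would force.

try:
    hypotheses = decoder(emissions)
except RuntimeError:
    # Catchable without a blanket `except Exception`, which would also
    # hide unrelated bugs; retry with inputs moved to CPU.
    hypotheses = decoder(emissions.cpu())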

@@ -129,20 +129,28 @@ def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor]

Args:
emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model
probability distribution over labels; output of acoustic model. It must live on CPU.
Contributor:

perhaps just "CPU tensor of shape (batch, frame, num_tokens) storing sequences of probability distribution over labels; output of acoustic model."? sounds a little more concrete than "live on CPU"

and so forth below

B, T, N = emissions.size()
if lengths is None:
    lengths = torch.full((B,), T)

Contributor:

can remove extra whitespace

    raise ValueError("emissions must be float32.")

if emissions.is_cuda:
    raise RuntimeError("emissions must live on CPU.")
Contributor:

similar nit as above

Suggested change
raise RuntimeError("emissions must live on CPU.")
raise RuntimeError("emissions must be a CPU tensor.")

    raise RuntimeError("emissions must live on CPU.")

if lengths is not None and lengths.is_cuda:
    raise RuntimeError("lengths must live on CPU.")
Contributor:

and here

Suggested change
raise RuntimeError("lengths must live on CPU.")
raise RuntimeError("lengths must be a CPU tensor.")
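
Putting the review feedback together, the validation block would read roughly like this; a sketch of the end state, not the verbatim merged diff:

if emissions.dtype != torch.float32:
    raise ValueError("emissions must be float32.")

if emissions.is_cuda:
    raise RuntimeError("emissions must be a CPU tensor.")

if lengths is not None and lengths.is_cuda:
    raise RuntimeError("lengths must be a CPU tensor.")

B, T, N = emissions.size()
if lengths is None:
    lengths = torch.full((B,), T)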

@facebook-github-bot (Contributor):

@xiaohui-zhang has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@github-actions:

Hey @xiaohui-zhang.
You merged this PR, but labels were not properly added. Please add a primary and secondary label (See https://github.com/pytorch/audio/blob/main/.github/process_commit.py)

xiaohui-zhang added a commit to xiaohui-zhang/audio that referenced this pull request May 4, 2022
Summary:
Addressing issue pytorch#2274:
Raise a RuntimeError when the input tensors to the CTC decoder are GPU tensors, since the CTC decoder only runs on CPU. Also update the data type check to use "raise" rather than "assert".

 ---
Pull Request resolved: pytorch#2289

Reviewed By: mthrok

Differential Revision: D35255630

Pulled By: xiaohui-zhang

fbshipit-source-id: d6c6e88d9ad4b9690bb741557fa9a9504e60872e