make sure inputs live on CPU for ctc decoder #2289
Conversation
```python
B, T, N = emissions.size()
if lengths is None:
    lengths = torch.full((B,), T)

assert not emissions.is_cuda
assert not lengths.is_cuda
```
The logic itself looks good. A couple of suggestions:

- For a better UX, could you provide an error message that tells users what to do (as opposed to what was wrong)?
- Please perform the input validation as soon as possible, before any other operation is performed. Otherwise, the work done before validation could be wasted.
- This is not a written rule, but `assert` is more for internal assertions (although it is used in L140, which I guess was missed at review time). So can you replace it with `if <condition>: raise <something>`?
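A minimal sketch of the suggested pattern (the helper name and error messages here are hypothetical, not the actual torchaudio code): validate inputs first, before any other work, and raise exceptions that tell the user how to fix the problem rather than asserting.

```python
import torch


def validate_decoder_inputs(emissions, lengths=None):
    # Run all checks before any computation, so no work is wasted on bad inputs.
    if emissions.dtype != torch.float32:
        raise ValueError("emissions must be float32.")
    if emissions.is_cuda:
        # Tell the user what to do, not just what went wrong.
        raise RuntimeError(
            "emissions must be a CPU tensor. Move it with emissions.cpu() first."
        )
    if lengths is not None and lengths.is_cuda:
        raise RuntimeError(
            "lengths must be a CPU tensor. Move it with lengths.cpu() first."
        )


emissions = torch.zeros(2, 10, 5)  # (batch, frame, num_tokens) on CPU
validate_decoder_inputs(emissions)  # passes silently for valid CPU input
```

Unlike `assert`, these checks survive `python -O` and give an actionable message.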
```diff
@@ -129,20 +129,24 @@ def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor]

     Args:
         emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
-            probability distribution over labels; output of acoustic model
+            probability distribution over labels; output of acoustic model. It must lives on CPU.
```
Suggested change:

```diff
-            probability distribution over labels; output of acoustic model. It must lives on CPU.
+            probability distribution over labels; output of acoustic model. It must live on CPU.
```
```diff
         lengths (Tensor or None, optional): tensor of shape `(batch, )` storing the valid length of
-            in time axis of the output Tensor in each batch
+            in time axis of the output Tensor in each batch. It must lives on CPU.
```
Suggested change:

```diff
-            in time axis of the output Tensor in each batch. It must lives on CPU.
+            in time axis of the output Tensor in each batch. It must live on CPU.
```
thanks @carolineechen and @mthrok for the thorough review of my first PR! |
```python
raise ValueError('emissions must be float32.')

if emissions.is_cuda:
    raise Exception('emissions must live on CPU.')
```
it'd be better to raise a more specific exception, e.g. `RuntimeError` at a minimum
```diff
@@ -129,20 +129,28 @@ def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor]

     Args:
         emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
-            probability distribution over labels; output of acoustic model
+            probability distribution over labels; output of acoustic model. It must live on CPU.
```
perhaps just "CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of probability distribution over labels; output of acoustic model."? sounds a little more concrete than "live on CPU", and so forth below
```python
B, T, N = emissions.size()
if lengths is None:
    lengths = torch.full((B,), T)
```
can remove extra whitespace
```python
raise ValueError("emissions must be float32.")

if emissions.is_cuda:
    raise RuntimeError("emissions must live on CPU.")
```
similar nit as above
Suggested change:

```diff
-    raise RuntimeError("emissions must live on CPU.")
+    raise RuntimeError("emissions must be a CPU tensor.")
```
```python
raise RuntimeError("emissions must live on CPU.")

if lengths is not None and lengths.is_cuda:
    raise RuntimeError("lengths must live on CPU.")
```
and here
Suggested change:

```diff
-    raise RuntimeError("lengths must live on CPU.")
+    raise RuntimeError("lengths must be a CPU tensor.")
```
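Putting the review suggestions together, the validated prologue might look roughly like the following. This is a sketch assuming plain PyTorch; the function name is illustrative, not the actual decoder code.

```python
from typing import Optional

import torch


def prepare_inputs(
    emissions: torch.Tensor, lengths: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Specific exception types, raised before any other operation.
    if emissions.dtype != torch.float32:
        raise ValueError("emissions must be float32.")
    if emissions.is_cuda:
        raise RuntimeError("emissions must be a CPU tensor.")
    if lengths is not None and lengths.is_cuda:
        raise RuntimeError("lengths must be a CPU tensor.")

    B, T, N = emissions.size()
    if lengths is None:
        # Default: every sequence in the batch spans all T frames.
        lengths = torch.full((B,), T)
    return lengths
```

With this ordering, the shape unpacking and `torch.full` default never run on invalid input.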
@xiaohui-zhang has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Hey @xiaohui-zhang. |
Summary: Addressing the issue pytorch#2274: raise `RuntimeError` when the input tensors to the CTC decoder are GPU tensors, since the CTC decoder only runs on CPU. Also update the data type check to use "raise" rather than "assert".

Pull Request resolved: pytorch#2289
Reviewed By: mthrok
Differential Revision: D35255630
Pulled By: xiaohui-zhang
fbshipit-source-id: d6c6e88d9ad4b9690bb741557fa9a9504e60872e
GitHub Author: xiaohui-zhang xiaohuizhang@fb.com