Use Label-Looping algorithm for RNN-T decoding by default #8831
Conversation
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Not ready yet
jenkins
state = [torch.ones([batch, self.context_size], dtype=torch.long, device=y.device) * self.blank_idx]
# state contains context_size - 1 elements for each utterance in batch,
# consistent with the state returned from StatelessNet.forward
state = [torch.ones([batch, self.context_size - 1], dtype=torch.long, device=y.device) * self.blank_idx]
@hainan-xv, please confirm that I broke nothing when fixing the state for the Stateless decoder.
We need a state with constant size (to allow replacements when we find the end of an utterance), while forward
returns a state of size [batch_size, context_size - 1].
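A minimal sketch (with hypothetical sizes) of the invariant being fixed here: the initial state must have the same shape as the state returned by StatelessNet.forward, i.e. context_size - 1 symbols per utterance, so finished utterances can be replaced in-place.

```python
import torch

# Hypothetical values; in NeMo these come from the decoder config.
batch, context_size, blank_idx = 4, 2, 0

# Fixed initialization: context_size - 1 blank symbols per utterance,
# matching the shape StatelessNet.forward returns.
state = [torch.ones([batch, context_size - 1], dtype=torch.long) * blank_idx]

assert state[0].shape == (batch, context_size - 1)
```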
LGTM.
FYI, you can also use torch.full instead of torch.ones followed by multiplication. No need to change it though.
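A quick check of the suggested simplification (hypothetical sizes): torch.full produces the same tensor as torch.ones followed by multiplication, in a single call.

```python
import torch

batch, context_size, blank_idx = 4, 2, 1  # hypothetical values

# ones followed by multiplication (current code)
a = torch.ones([batch, context_size - 1], dtype=torch.long) * blank_idx
# single torch.full call, as suggested
b = torch.full([batch, context_size - 1], fill_value=blank_idx, dtype=torch.long)

assert torch.equal(a, b)
```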
Overall looks great, nice work !
@@ -73,7 +73,7 @@ def predict(
return (
output,
[
torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].exand(
torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].expand(
How did this test pass with this error?
I think this case is redundant and never executed in decoding, but we need to implement it to conform to the interface, where y
is optional (see AbstractRNNTDecoder.predict).
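A small sketch (hypothetical vocab_size and batch_size) of what this branch returns when y is None: a distribution that puts all probability mass on the blank symbol at the last index.

```python
import torch

vocab_size, batch_size = 5, 3  # hypothetical sizes

# One-hot on blank (index vocab_size), broadcast across the batch.
t = torch.tensor([0] * vocab_size + [1], dtype=torch.float32)[None, None, :].expand(
    [1, batch_size, -1]
)

assert t.shape == (1, batch_size, vocab_size + 1)
assert float(t[0, 0, -1]) == 1.0  # all mass on the blank index
assert float(t[0, 0, 0]) == 0.0
```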
Late review with a few FYI comments. Do we test the CUDA graphs implementation with stateless transducers yet?
return [
torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :]
.expand([1, batch_size, -1])
.clone()
torch.repeat or torch.repeat_interleave is probably the better way to do it than expand followed by clone.
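A sketch of the equivalence (hypothetical sizes): expand creates a zero-copy view that clone then materializes, while Tensor.repeat allocates and tiles the result in one step; both yield the same tensor.

```python
import torch

vocab_size, batch_size = 5, 3  # hypothetical sizes
base = torch.tensor([0] * vocab_size + [1], dtype=torch.float32)

# expand + clone (current code): view first, then a full copy
a = base[None, None, :].expand([1, batch_size, -1]).clone()

# Tensor.repeat: tiles the data directly into a new tensor
b = base[None, None, :].repeat(1, batch_size, 1)

assert torch.equal(a, b)
```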
Thanks for the comments, @galv! I will make the fixes in follow-up PRs.
* Use Label-Looping algorithm for RNN-T decoding by default
* Fix loop labels + stateless decoding

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
What does this PR do ?

Enable the Label-Looping algorithm introduced in #8286 and #7926 (loop_labels=True) by default for RNN-T greedy decoding.

Collection: [ASR]

Changelog

* Enable Label-Looping algorithm by default for batched RNN-T greedy decoding (loop_labels=true)

Usage

The Label-Looping algorithm is now used by default for batched greedy decoding.
For the Frame-Looping algorithm one can use:
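The example is elided in the PR text; as a sketch (an assumption based on the loop_labels flag described above, not verbatim from this PR), the Frame-Looping algorithm would be selected by disabling that flag in the greedy decoding config:

```python
# Sketch (assumption): fall back to Frame-Looping by disabling the new default.
# The exact config path may differ by NeMo version.
decoding_cfg = {
    "strategy": "greedy_batch",
    "greedy": {"loop_labels": False},  # loop_labels=True is the new default
}
# asr_model.change_decoding_strategy(decoding_cfg)  # hypothetical call site
```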
Jenkins CI

To run Jenkins, a NeMo User with write access must comment jenkins on the PR.

Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.
Additional Information