Dgalvez/cuda graphs greedy rnnt inference squash (NVIDIA#8191)
* Speed up RNN-T greedy decoding with cuda graphs

This uses CUDA 12.3's conditional node support.

Initialize CUDA tensors lazily on the first call of __call__ instead of in __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
NVIDIA#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

NVIDIA#8191 (comment)
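The lazy-initialization fix described above can be sketched as follows. This is an illustrative pattern, not the actual NeMo implementation; the class and attribute names are hypothetical:

```python
import torch


class GreedyBatchedLoop:
    """Hypothetical sketch of deferring tensor creation to the first call."""

    def __init__(self, max_symbols: int):
        # No tensors are created here: the target device is unknown at
        # construction time, and calling .to() on an enclosing
        # torch.nn.Module would not move tensors held by this helper.
        self.max_symbols = max_symbols
        self.labels = None

    def __call__(self, encoder_output: torch.Tensor) -> torch.Tensor:
        if self.labels is None:
            # First call: allocate on whatever device the input lives on,
            # avoiding the "Expected all tensors to be on the same device"
            # error seen when buffers were created eagerly in __init__.
            self.labels = torch.full(
                (self.max_symbols,), fill_value=-1,
                device=encoder_output.device,
            )
        # ... the real greedy decoding loop would run here ...
        return self.labels
```

Because allocation happens inside `__call__`, the helper works regardless of which device the surrounding module was moved to after construction.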

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>
galv authored and zpx01 committed Mar 8, 2024
1 parent 8de5113 commit a5a680d
Showing 6 changed files with 707 additions and 2 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/modules/rnnt_abstract.py
@@ -282,7 +282,7 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List
def batch_replace_states_mask(
cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
):
"""Replace states in dst_states with states from src_states using the mask"""
"""Replace states in dst_states with states from src_states using the mask, in a way that does not synchronize with the CPU"""
raise NotImplementedError()

def batch_split_states(self, batch_states: list[torch.Tensor]) -> list[list[torch.Tensor]]:
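The contract added to the docstring above (masked replacement without CPU synchronization) can be sketched with `torch.where`, which keeps the mask on-device so no host round-trip is forced. This is an illustrative sketch assuming batch-leading state tensors, not the NeMo implementation:

```python
import torch


def batch_replace_states_mask_sketch(src_states, dst_states, mask):
    """Overwrite rows of each dst tensor where mask is True, entirely on-device.

    Hypothetical sketch: assumes each state tensor has the batch as its
    leading dimension and mask has shape (batch,).
    """
    for src, dst in zip(src_states, dst_states):
        # Broadcast the (batch,) mask over the trailing state dimensions.
        m = mask.view(mask.shape + (1,) * (dst.dim() - mask.dim()))
        # torch.where selects per-element without reading the mask on the
        # CPU, so this is safe to capture inside a CUDA graph.
        dst.copy_(torch.where(m, src, dst))
```

Avoiding boolean indexing (e.g. `dst[mask] = src[mask]`) matters here because data-dependent shapes would synchronize with the host and break CUDA graph capture.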
