Dgalvez/cuda graphs greedy rnnt inference squash (NVIDIA#8191)
* Speed up RNN-T greedy decoding with CUDA graphs

  This uses CUDA 12.3's conditional node support.

  Initialize CUDA tensors lazily on the first call of `__call__` instead of in
  `__init__`. We don't know which device is going to be used at construction
  time, and we can't rely on `torch.nn.Module.to()` to work here; see
  NVIDIA#8436. This fixes an "Expected all tensors to be on the same device,
  but found at least two devices" error that occurs when you call `to()` on
  your `torch.nn.Module` after constructing it. See NVIDIA#8191 (comment).

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>
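The lazy-initialization pattern the commit describes can be sketched as follows. This is a minimal illustration, not the actual NeMo decoder: the class name, tensor names, and shapes are hypothetical. The point is that device-side state is allocated on the first `__call__`, when the input's device is known, rather than in `__init__`.

```python
import torch

class GreedyDecoderState:
    """Hypothetical sketch of lazy device-side state initialization.

    State tensors are allocated on the first __call__, once the input
    tensor's device is known, instead of in __init__ where the target
    device is not yet determined.
    """

    def __init__(self, max_symbols: int):
        self.max_symbols = max_symbols
        self.symbols_added = None  # allocated lazily on first call

    def _lazy_init(self, device: torch.device):
        # Allocate (or re-allocate) state on the same device as the inputs,
        # avoiding "Expected all tensors to be on the same device" errors
        # when the module is moved with .to() after construction.
        if self.symbols_added is None or self.symbols_added.device != device:
            self.symbols_added = torch.zeros(
                self.max_symbols, dtype=torch.int64, device=device
            )

    def __call__(self, encoder_output: torch.Tensor) -> torch.Tensor:
        self._lazy_init(encoder_output.device)
        # ... the greedy decoding loop would run here ...
        return self.symbols_added

state = GreedyDecoderState(max_symbols=10)
out = state(torch.randn(4, 8))  # state is allocated on the input's device
```

Because allocation is keyed off the input's device at call time, moving the surrounding module with `.to()` after construction cannot strand the state tensors on the wrong device.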