training code for hybrid-autoregressive inference model #10841

Merged 2 commits on Oct 14, 2024
nemo/collections/asr/modules/rnnt.py: 73 changes (47 additions & 26 deletions)
@@ -74,8 +74,7 @@ class StatelessTransducerDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return {
"targets": NeuralType(('B', 'T'), LabelsType()),
"target_length": NeuralType(tuple('B'), LengthsType()),
@@ -84,8 +83,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
"outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
"prednet_lengths": NeuralType(tuple('B'), LengthsType()),
@@ -382,15 +380,20 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:

@classmethod
def batch_replace_states_mask(
- cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
+ cls,
+ src_states: list[torch.Tensor],
+ dst_states: list[torch.Tensor],
+ mask: torch.Tensor,
):
"""Replace states in dst_states with states from src_states using the mask"""
# same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking
torch.where(mask.unsqueeze(-1), src_states[0], dst_states[0], out=dst_states[0])

@classmethod
def batch_replace_states_all(
- cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor],
+ cls,
+ src_states: list[torch.Tensor],
+ dst_states: list[torch.Tensor],
):
"""Replace states in dst_states with states from src_states"""
dst_states[0].copy_(src_states[0])
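
A note on the torch.where pattern in batch_replace_states_mask above: passing out=dst_states[0] writes the selected rows straight into the destination tensor, so the replacement stays on-device and non-blocking, whereas the boolean-indexing form named in the in-code comment forces a synchronizing copy. A minimal self-contained sketch of the equivalence, with hypothetical sizes that are not part of this PR:

import torch

B, H = 4, 8  # hypothetical batch and hidden sizes
src = torch.randn(B, H)
dst = torch.randn(B, H)
mask = torch.tensor([True, False, True, False])  # which rows to replace

# Non-blocking masked replace, as in batch_replace_states_mask:
# broadcast the (B,) mask over the hidden dim and take src rows where True.
torch.where(mask.unsqueeze(-1), src, dst, out=dst)

# Sync-prone equivalent the in-code comment refers to:
# dst[mask] = src[mask]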
@@ -591,8 +594,7 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMixin):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return {
"targets": NeuralType(('B', 'T'), LabelsType()),
"target_length": NeuralType(tuple('B'), LengthsType()),
@@ -601,8 +603,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
"outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
"prednet_lengths": NeuralType(tuple('B'), LengthsType()),
@@ -1018,19 +1019,19 @@ def batch_score_hypothesis(

def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
"""
Create batch of decoder states.

Args:
    batch_states (list): batch of decoder states
        ([L x (B, H)], [L x (B, H)])

    decoder_states (list of list): list of decoder states
        [B x ([L x (1, H)], [L x (1, H)])]

Returns:
    batch_states (tuple): batch of decoder states
        ([L x (B, H)], [L x (B, H)])
"""
# LSTM has 2 states
new_states = [[] for _ in range(len(decoder_states[0]))]
for layer in range(self.pred_rnn_layers):
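
The shape notation in this docstring is compact, so here is a minimal sketch, not taken from NeMo itself, of what batching B per-hypothesis LSTM states into the ([L x (B, H)], [L x (B, H)]) layout amounts to (hypothetical sizes):

import torch

L_layers, B, H = 2, 4, 8  # hypothetical layer count, batch size, hidden size

# B per-hypothesis states: each a (h, c) pair of (L, 1, H) tensors,
# i.e. the [B x ([L x (1, H)], [L x (1, H)])] layout from the docstring.
decoder_states = [
    (torch.randn(L_layers, 1, H), torch.randn(L_layers, 1, H)) for _ in range(B)
]

# Concatenate along the batch dimension for each of the two LSTM states.
h = torch.cat([s[0] for s in decoder_states], dim=1)  # (L, B, H)
c = torch.cat([s[1] for s in decoder_states], dim=1)  # (L, B, H)
batch_states = (h, c)
assert h.shape == (L_layers, B, H)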
@@ -1109,7 +1110,9 @@ def batch_replace_states_mask(

@classmethod
def batch_replace_states_all(
- cls, src_states: Tuple[torch.Tensor, torch.Tensor], dst_states: Tuple[torch.Tensor, torch.Tensor],
+ cls,
+ src_states: Tuple[torch.Tensor, torch.Tensor],
+ dst_states: Tuple[torch.Tensor, torch.Tensor],
):
"""Replace states in dst_states with states from src_states"""
dst_states[0].copy_(src_states[0])
@@ -1249,12 +1252,15 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin):

fused_batch_size: Optional int, required if `fuse_loss_wer` flag is set. Determines the size of the
sub-batches. Should be any value below the actual batch size per GPU.
+ masking_prob: Optional float, indicating the probability of masking out decoder output in the HAINAN
+ (Hybrid Autoregressive Inference Transducer) model, described in https://arxiv.org/pdf/2410.02597.
+ Defaults to -1.0, which runs the standard Joint network computation; if > 0, the decoder output is
+ masked out with the specified probability.
"""

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return {
"encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
"decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
@@ -1266,8 +1272,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
if not self._fuse_loss_wer:
return {
"outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
@@ -1313,6 +1318,7 @@ def __init__(
fuse_loss_wer: bool = False,
fused_batch_size: Optional[int] = None,
experimental_fuse_loss_wer: Any = None,
+ masking_prob: float = -1.0,
):
super().__init__()

@@ -1322,6 +1328,10 @@
self._num_extra_outputs = num_extra_outputs
self._num_classes = num_classes + 1 + num_extra_outputs # 1 is for blank

+ self.masking_prob = masking_prob
+ if self.masking_prob > 0.0:
+     assert self.masking_prob < 1.0, "masking_prob must be between 0 and 1"

if experimental_fuse_loss_wer is not None:
# Override fuse_loss_wer from deprecated argument
fuse_loss_wer = experimental_fuse_loss_wer
@@ -1578,6 +1588,13 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
"""
f = f.unsqueeze(dim=2) # (B, T, 1, H)
g = g.unsqueeze(dim=1) # (B, 1, U, H)

+ if self.training and self.masking_prob > 0:
+     [B, _, U, _] = g.shape
+     rand = torch.rand([B, 1, U, 1]).to(g.device)
+     rand = torch.gt(rand, self.masking_prob)
+     g = g * rand

inp = f + g # [B, T, U, H]

del f, g
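
This hunk is the heart of the PR: when masking_prob > 0, each (batch, label-position) slot of the decoder output g is dropped with that probability during training, before the joint combines it with the encoder output, so the joint can also learn to score frames without autoregressive decoder context (see the linked paper for the HAINAN formulation). A standalone sketch of the same masking step, with hypothetical sizes and masking_prob as a local stand-in for the new module attribute:

import torch

B, T, U, H = 2, 5, 3, 8   # hypothetical batch/time/label/hidden sizes
masking_prob = 0.3        # stand-in for self.masking_prob

f = torch.randn(B, T, 1, H)  # projected encoder output, unsqueezed at dim 2
g = torch.randn(B, 1, U, H)  # projected decoder output, unsqueezed at dim 1

# One keep/drop draw per (batch, label) slot, broadcast over the time and
# hidden dims; gt() keeps a slot with probability 1 - masking_prob.
keep = torch.gt(torch.rand(B, 1, U, 1), masking_prob)
g = g * keep                 # dropped slots contribute nothing to the joint

inp = f + g                  # (B, T, U, H), as in joint_after_projection

Note the draw is shared across the time axis, so a masked label position is hidden from the joint at every frame, not just at scattered (t, u) cells.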
@@ -2047,7 +2064,11 @@ def forward(
return losses, wer, wer_num, wer_denom

def sampled_joint(
- self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor,
+ self,
+ f: torch.Tensor,
+ g: torch.Tensor,
+ transcript: torch.Tensor,
+ transcript_lengths: torch.Tensor,
) -> torch.Tensor:
"""
Compute the sampled joint step of the network.