From 5d48a3f27ee1db94cda622ca2423ffadfe1179cc Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 18 Sep 2024 12:57:59 +0400 Subject: [PATCH 01/13] remove stacking operations Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 218 +++++++----------- nemo/collections/asr/modules/rnnt_abstract.py | 5 +- .../parts/submodules/rnnt_beam_decoding.py | 69 +++--- 3 files changed, 113 insertions(+), 179 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 2355cfb7005b..1e5afff8bf33 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -316,21 +316,19 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: torch.full([batch, self.context_size - 1], fill_value=self.blank_idx, dtype=torch.long, device=y.device) ] return state - - def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): + + def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]): """ - Create batch of decoder states. + Creates a stacked decoder states to be passed to prediction network. Args: - batch_states (list): batch of decoder states - ([(B, H)]) - - decoder_states (list of list): list of decoder states - [B x ([(1, C)]] - + decoder_states (list of list of torch.Tensor): list of decoder states + [B, 1, C] + - B: Batch size. + - C: Dimensionality of the hidden state. + Returns: - batch_states (tuple): batch of decoder states - ([(B, C)]) + batch_states (list of torch.Tensor): batch of decoder states [[B x C]] """ new_state = torch.stack([s[0] for s in decoder_states]) @@ -449,87 +447,67 @@ def mask_select_states( return [states[0][mask]] def batch_score_hypothesis( - self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor] - ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: + self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: """ Used for batched beam search algorithms. Similar to score_hypothesis method. Args: hypothesis: List of Hypotheses. Refer to rnnt_utils.Hypothesis. cache: Dict which contains a cache to avoid duplicate computations. - batch_states: List of torch.Tensor which represent the states of the RNN for this batch. - Each state is of shape [L, B, H] Returns: - Returns a tuple (b_y, b_states, lm_tokens) such that: - b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses. - b_state is a list of list of RNN states, each of shape [L, B, H]. - Represented as B x List[states]. - lm_token is a list of the final integer tokens of the hypotheses in the batch. + Returns a tuple (batch_dec_out, batch_dec_states) such that: + batch_dec_out: a list of torch.Tensor [1, H] representing the prediction network outputs for the last tokens in the Hypotheses. + batch_dec_states: a list of list of RNN states, each of shape [L, B, H]. Represented as B x List[states]. 
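            Example:
                A minimal sketch (the names ``decoder``, ``joint``, ``hyps``, ``cache`` and
                ``h_enc`` are assumed to exist) of consuming the new return values, which
                replace the old ``(b_y, b_states, lm_tokens)`` tuple::

                    dec_out, dec_states = decoder.batch_score_hypothesis(hyps, cache)
                    logits = joint.joint(h_enc, torch.stack(dec_out))  # stack [1, H] outputs into [B, 1, H]
                    for i, hyp in enumerate(hyps):
                        hyp.dec_state = dec_states[i]  # per-hypothesis state; no batch_select_state needed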
""" final_batch = len(hypotheses) + if final_batch == 0: raise ValueError("No hypotheses was provided for the batch!") _p = next(self.parameters()) device = _p.device - dtype = _p.dtype tokens = [] - process = [] - done = [None for _ in range(final_batch)] + to_process = [] + final = [None for _ in range(final_batch)] # For each hypothesis, cache the last token of the sequence and the current states - for i, hyp in enumerate(hypotheses): + for final_idx, hyp in enumerate(hypotheses): sequence = tuple(hyp.y_sequence) if sequence in cache: - done[i] = cache[sequence] + final[final_idx] = cache[sequence] else: tokens.append(hyp.y_sequence[-1]) - process.append((sequence, hyp.dec_state)) + to_process.append((sequence, hyp.dec_state)) - if process: - batch = len(process) + if to_process: + batch = len(to_process) # convert list of tokens to torch.Tensor, then reshape. tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) - dec_states = self.initialize_state(tokens) # [B, C] - dec_states = self.batch_initialize_states(dec_states, [d_state for seq, d_state in process]) + dec_states = self.batch_stack_states([d_state for _, d_state in to_process]) - y, dec_states = self.predict( + dec_out, dec_states = self.predict( tokens, state=dec_states, add_sos=False, batch_size=batch - ) # [B, 1, H], List([L, 1, H]) - - dec_states = tuple(state.to(dtype=dtype) for state in dec_states) + ) # [B, 1, H], B x List([L, 1, H]) - # Update done states and cache shared by entire batch. - j = 0 - for i in range(final_batch): - if done[i] is None: + # Update final states and cache shared by entire batch. + processed_idx = 0 + for final_idx in range(final_batch): + if final[final_idx] is None: # Select sample's state from the batch state list - new_state = self.batch_select_state(dec_states, j) + new_state = self.batch_select_state(dec_states, processed_idx) # Cache [1, H] scores of the current y_j, and its corresponding state - done[i] = (y[j], new_state) - cache[process[j][0]] = (y[j], new_state) - - j += 1 - - # Set the incoming batch states with the new states obtained from `done`. - batch_states = self.batch_initialize_states(batch_states, [d_state for y_j, d_state in done]) - - # Create batch of all output scores - # List[1, 1, H] -> [B, 1, H] - batch_y = torch.stack([y_j for y_j, d_state in done]) - - # Extract the last tokens from all hypotheses and convert to a tensor - lm_tokens = torch.tensor([h.y_sequence[-1] for h in hypotheses], device=device, dtype=torch.long).view( - final_batch - ) - - return batch_y, batch_states, lm_tokens + final[final_idx] = (dec_out[processed_idx], new_state) + cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state) + processed_idx += 1 + + return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMixin): """A Recurrent Neural Network Transducer Decoder / Prediction Network (RNN-T Prediction Network). @@ -934,23 +912,19 @@ def score_hypothesis( return y, new_state, lm_token def batch_score_hypothesis( - self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor] - ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: + self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: """ Used for batched beam search algorithms. Similar to score_hypothesis method. Args: hypothesis: List of Hypotheses. 
Refer to rnnt_utils.Hypothesis. cache: Dict which contains a cache to avoid duplicate computations. - batch_states: List of torch.Tensor which represent the states of the RNN for this batch. - Each state is of shape [L, B, H] Returns: - Returns a tuple (b_y, b_states, lm_tokens) such that: - b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses. - b_state is a list of list of RNN states, each of shape [L, B, H]. - Represented as B x List[states]. - lm_token is a list of the final integer tokens of the hypotheses in the batch. + Returns a tuple (batch_dec_out, batch_dec_states) such that: + batch_dec_out: a list of torch.Tensor [1, H] representing the prediction network outputs for the last tokens in the Hypotheses. + batch_dec_states: a list of list of RNN states, each of shape [L, B, H]. Represented as B x List[states]. """ final_batch = len(hypotheses) @@ -959,90 +933,68 @@ def batch_score_hypothesis( _p = next(self.parameters()) device = _p.device - dtype = _p.dtype tokens = [] - process = [] - done = [None for _ in range(final_batch)] + to_process = [] + final = [None for _ in range(final_batch)] # For each hypothesis, cache the last token of the sequence and the current states - for i, hyp in enumerate(hypotheses): + for final_idx, hyp in enumerate(hypotheses): sequence = tuple(hyp.y_sequence) if sequence in cache: - done[i] = cache[sequence] + final[final_idx] = cache[sequence] else: tokens.append(hyp.y_sequence[-1]) - process.append((sequence, hyp.dec_state)) + to_process.append((sequence, hyp.dec_state)) - if process: - batch = len(process) + if to_process: + batch = len(to_process) # convert list of tokens to torch.Tensor, then reshape. tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) - dec_states = self.initialize_state(tokens.to(dtype=dtype)) # [L, B, H] - dec_states = self.batch_initialize_states(dec_states, [d_state for seq, d_state in process]) + dec_states = self.batch_stack_states([d_state for _, d_state in to_process]) - y, dec_states = self.predict( + dec_out, dec_states = self.predict( tokens, state=dec_states, add_sos=False, batch_size=batch - ) # [B, 1, H], List([L, 1, H]) + ) # [B, 1, H], B x List([L, 1, H]) - dec_states = tuple(state.to(dtype=dtype) for state in dec_states) - - # Update done states and cache shared by entire batch. - j = 0 - for i in range(final_batch): - if done[i] is None: + # Update final states and cache shared by entire batch. + processed_idx = 0 + for final_idx in range(final_batch): + if final[final_idx] is None: # Select sample's state from the batch state list - new_state = self.batch_select_state(dec_states, j) + new_state = self.batch_select_state(dec_states, processed_idx) # Cache [1, H] scores of the current y_j, and its corresponding state - done[i] = (y[j], new_state) - cache[process[j][0]] = (y[j], new_state) - - j += 1 - - # Set the incoming batch states with the new states obtained from `done`. 
- batch_states = self.batch_initialize_states(batch_states, [d_state for y_j, d_state in done]) - - # Create batch of all output scores - # List[1, 1, H] -> [B, 1, H] - batch_y = torch.stack([y_j for y_j, d_state in done]) + final[final_idx] = (dec_out[processed_idx], new_state) + cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state) - # Extract the last tokens from all hypotheses and convert to a tensor - lm_tokens = torch.tensor([h.y_sequence[-1] for h in hypotheses], device=device, dtype=torch.long).view( - final_batch - ) - - return batch_y, batch_states, lm_tokens + processed_idx += 1 + + return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] - def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): + def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: """ - Create batch of decoder states. - - Args: - batch_states (list): batch of decoder states - ([L x (B, H)], [L x (B, H)]) - - decoder_states (list of list): list of decoder states - [B x ([L x (1, H)], [L x (1, H)])] + Creates a stacked decoder states to be passed to prediction network - Returns: - batch_states (tuple): batch of decoder states - ([L x (B, H)], [L x (B, H)]) - """ - # LSTM has 2 states - new_states = [[] for _ in range(len(decoder_states[0]))] - for layer in range(self.pred_rnn_layers): - for state_id in range(len(decoder_states[0])): - # batch_states[state_id][layer] = torch.stack([s[state_id][layer] for s in decoder_states]) - new_state_for_layer = torch.stack([s[state_id][layer] for s in decoder_states]) - new_states[state_id].append(new_state_for_layer) - - for state_id in range(len(decoder_states[0])): - new_states[state_id] = torch.stack([state for state in new_states[state_id]]) + Args: + decoder_states (list of list of list of torch.Tensor): list of decoder states + [B, L, 1, H] + - B: Batch size. + - L: Number of layers in prediction RNN (e.g., for LSTM, this is 2: hidden and cell states). + - H: Dimensionality of the hidden state. + + Returns: + batch_states (list of torch.Tensor): batch of decoder states + [L x torch.Tensor[1 x B x H] + """ + # stack decoder states into tensor of shape [B x L x 1 x H] + # permute to the target shape [L x 1 x B x H] + stacked_states = torch.stack([torch.stack(decoder_state) for decoder_state in decoder_states]) + permuted_states = stacked_states.permute(1, 2, 0, 3) - return new_states + return list(permuted_states.contiguous()) def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]: """Get decoder state from batch of states, for given id. @@ -1057,15 +1009,11 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List (tuple): decoder states for given id ([L x (1, H)], [L x (1, H)]) """ + # print("###", len(batch_states), batch_states[0].shape, self.pred_rnn_layers) if batch_states is not None: - state_list = [] - for state_id in range(len(batch_states)): - states = [batch_states[state_id][layer][idx] for layer in range(self.pred_rnn_layers)] - state_list.append(states) - - return state_list - else: - return None + return [state[:, idx] for state in batch_states] + + return None def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: """Concatenate a batch of decoder state to a packed state. 
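The reworked `batch_stack_states` / `batch_select_state` pair above reduces to a stack-and-permute round trip. A self-contained sketch, with sizes invented purely for illustration:

```python
import torch

B, L, H = 3, 2, 8  # beam size, prediction-RNN layers, hidden size (illustrative only)

# One decoder state per hypothesis: [h, c] for an LSTM, each tensor of shape [L, H].
per_hyp_states = [[torch.randn(L, H), torch.randn(L, H)] for _ in range(B)]

# What batch_stack_states computes: stack to [B, 2, L, H], permute to [2, L, B, H],
# then split into a list of 2 tensors of shape [L, B, H].
stacked = torch.stack([torch.stack(s) for s in per_hyp_states]).permute(1, 2, 0, 3)
batch_states = list(stacked.contiguous())
assert batch_states[0].shape == (L, B, H)

# batch_select_state then recovers one hypothesis as a list of 2 tensors of shape [L, H].
idx = 1
selected = [state[:, idx] for state in batch_states]
assert torch.equal(selected[0], per_hyp_states[idx][0])
```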
@@ -2231,4 +2179,4 @@ def sampled_joint( # Add the adapter compatible modules to the registry for cls in [RNNTDecoder, RNNTJoint, SampledRNNTJoint]: if adapter_mixins.get_registered_adapter(cls) is None: - adapter_mixins.register_adapter(cls, cls) # base class is adapter compatible itself + adapter_mixins.register_adapter(cls, cls) # base class is adapter compatible itself \ No newline at end of file diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index d3d9b7cb52d6..707b0b358194 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -246,14 +246,11 @@ def batch_score_hypothesis( """ raise NotImplementedError() - def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): + def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]): """ Create batch of decoder states. Args: - batch_states (list): batch of decoder states - ([L x (B, H)], [L x (B, H)]) - decoder_states (list of list): list of decoder states [B x ([L x (1, H)], [L x (1, H)])] diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index 25becda6fa75..5148d99fad71 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -725,11 +725,11 @@ def time_sync_decoding( D = [] # Decode a batch of beam states and scores - beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(C, cache, beam_state) + beam_y, beam_state = self.decoder.batch_score_hypothesis(C, cache) # Extract the log probabilities and the predicted tokens beam_logp = torch.log_softmax( - self.joint.joint(h_enc, beam_y) / self.softmax_temperature, dim=-1 + self.joint.joint(h_enc, torch.stack(beam_y)) / self.softmax_temperature, dim=-1 ) # [B, 1, 1, V + 1] beam_logp = beam_logp[:, 0, 0, :] # [B, V + 1] beam_topk = beam_logp[:, ids].topk(beam, dim=-1) @@ -776,7 +776,7 @@ def time_sync_decoding( new_hyp = Hypothesis( score=(hyp.score + float(logp)), y_sequence=(hyp.y_sequence + [int(k)]), - dec_state=self.decoder.batch_select_state(beam_state, j), + dec_state=beam_state[j], lm_state=hyp.lm_state, timestep=hyp.timestep[:] + [i], length=encoded_lengths, @@ -859,6 +859,7 @@ def align_length_sync_decoding( beam_state = self.decoder.initialize_state( torch.zeros(beam, device=h.device, dtype=h.dtype) ) # [L, B, H], [L, B, H] for LSTMS + beam_state = [self.decoder.batch_select_state(beam_state, 0)] # compute u_max as either a specific static limit, # or a multiple of current `h_length` dynamically. @@ -872,7 +873,7 @@ def align_length_sync_decoding( Hypothesis( y_sequence=[self.blank], score=0.0, - dec_state=self.decoder.batch_select_state(beam_state, 0), + dec_state=beam_state[0], timestep=[-1], length=0, ) @@ -921,12 +922,10 @@ def align_length_sync_decoding( # extract the states of the sub batch only. 
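# Note: beam_state is now a plain per-hypothesis list of decoder states, so the
# sub-batch below is gathered by indexing hypotheses rather than by slicing a
# packed state tensor along its batch dimension.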
if isinstance(self.decoder, RNNTDecoder): # LSTM decoder, state is [layer x batch x hidden] - beam_state_ = [ - beam_state[state_id][:, sub_batch_ids, :] for state_id in range(len(beam_state)) - ] + beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids) elif isinstance(self.decoder, StatelessTransducerDecoder): # stateless decoder, state is [batch x hidden] - beam_state_ = [beam_state[state_id][sub_batch_ids, :] for state_id in range(len(beam_state))] + beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids) else: raise NotImplementedError("Unknown decoder type.") @@ -935,22 +934,21 @@ def align_length_sync_decoding( beam_state_ = beam_state # Decode a batch/sub-batch of beam states and scores - beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(B_, cache, beam_state_) + beam_y, beam_state_ = self.decoder.batch_score_hypothesis(B_, cache) # If only a subset of batch ids were updated (some were removed) if sub_batch_ids is not None: # For each state in the RNN (2 for LSTM) - for state_id in range(len(beam_state)): - # Update the current batch states with the sub-batch states (in the correct indices) - # These indices are specified by sub_batch_ids, the ids of samples which were updated. - if isinstance(self.decoder, RNNTDecoder): - # LSTM decoder, state is [layer x batch x hidden] - beam_state[state_id][:, sub_batch_ids, :] = beam_state_[state_id][...] - elif isinstance(self.decoder, StatelessTransducerDecoder): - # stateless decoder, state is [batch x hidden] - beam_state[state_id][sub_batch_ids, :] = beam_state_[state_id][...] - else: - raise NotImplementedError("Unknown decoder type.") + # Update the current batch states with the sub-batch states (in the correct indices) + # These indices are specified by sub_batch_ids, the ids of samples which were updated. 
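# Concretely: for each position k, beam_state[sub_batch_ids[k]] is overwritten
# with the freshly decoded state beam_state_[k].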
+ if isinstance(self.decoder, RNNTDecoder) or isinstance(self.decoder, StatelessTransducerDecoder): + # LSTM decoder, state is [layer x batch x hidden] + index=0 + for sub_batch_id in sub_batch_ids: + beam_state[sub_batch_id] = beam_state_[index] + index+=1 + else: + raise NotImplementedError("Unknown decoder type.") else: # If entire batch was updated, simply update all the states beam_state = beam_state_ @@ -963,7 +961,7 @@ def align_length_sync_decoding( # Extract the log probabilities and the predicted tokens beam_logp = torch.log_softmax( - self.joint.joint(h_enc, beam_y) / self.softmax_temperature, dim=-1 + self.joint.joint(h_enc, torch.stack(beam_y)) / self.softmax_temperature, dim=-1 ) # [B=beam, 1, 1, V + 1] beam_logp = beam_logp[:, 0, 0, :] # [B=beam, V + 1] beam_topk = beam_logp[:, ids].topk(beam, dim=-1) @@ -1011,7 +1009,7 @@ def align_length_sync_decoding( new_hyp = Hypothesis( score=(hyp.score + float(logp)), y_sequence=(hyp.y_sequence[:] + [int(k)]), - dec_state=self.decoder.batch_select_state(beam_state, h_states_idx), + dec_state=beam_state[h_states_idx], lm_state=hyp.lm_state, timestep=hyp.timestep[:] + [i], length=i, @@ -1084,7 +1082,7 @@ def modified_adaptive_expansion_search( # prepare the batched beam states beam = min(self.beam_size, self.vocab_size) beam_state = self.decoder.initialize_state( - torch.zeros(beam, device=h.device, dtype=h.dtype) + torch.zeros(1, device=h.device, dtype=h.dtype) ) # [L, B, H], [L, B, H] for LSTMS # Initialize first hypothesis for the beam (blank) @@ -1106,8 +1104,8 @@ def modified_adaptive_expansion_search( hyp.alignments = [[]] # Decode a batch of beam states and scores - beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(init_tokens, cache, beam_state) - state = self.decoder.batch_select_state(beam_state, 0) + beam_dec_out, beam_state = self.decoder.batch_score_hypothesis(init_tokens, cache) + state = beam_state[0] # Setup ngram LM: if self.ngram_lm: @@ -1267,18 +1265,10 @@ def modified_adaptive_expansion_search( break else: - # Initialize the beam states for the hypotheses in the expannsion list - beam_state = self.decoder.batch_initialize_states( - beam_state, - [hyp.dec_state for hyp in list_exp], - # [hyp.y_sequence for hyp in list_exp], # - ) - # Decode a batch of beam states and scores - beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis( + beam_dec_out, beam_state = self.decoder.batch_score_hypothesis( list_exp, cache, - beam_state, # self.language_model is not None, ) @@ -1300,7 +1290,7 @@ def modified_adaptive_expansion_search( for i, hyp in enumerate(list_exp): # Preserve the decoder logits for the current beam hyp.dec_out.append(beam_dec_out[i]) - hyp.dec_state = self.decoder.batch_select_state(beam_state, i) + hyp.dec_state = beam_state[i] # TODO: Setup LM if self.language_model is not None: @@ -1325,7 +1315,7 @@ def modified_adaptive_expansion_search( else: # Extract the log probabilities - beam_logp, _ = self.resolve_joint_output(beam_enc_out, beam_dec_out) + beam_logp, _ = self.resolve_joint_output(beam_enc_out, torch.stack(beam_dec_out)) beam_logp = beam_logp[:, 0, 0, :] # For all expansions, add the score for the blank label @@ -1334,7 +1324,7 @@ def modified_adaptive_expansion_search( # Preserve the decoder's output and state hyp.dec_out.append(beam_dec_out[i]) - hyp.dec_state = self.decoder.batch_select_state(beam_state, i) + hyp.dec_state = beam_state[i] # TODO: Setup LM if self.language_model is not None: @@ -1387,7 +1377,7 @@ def recombine_hypotheses(self, 
hypotheses: List[Hypothesis]) -> List[Hypothesis] else: final.append(hyp) - return hypotheses + return final def resolve_joint_output(self, enc_out: torch.Tensor, dec_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -1478,7 +1468,6 @@ def set_decoding_type(self, decoding_type: str): self.token_offset = DEFAULT_TOKEN_OFFSET - @dataclass class BeamRNNTInferConfig: beam_size: int @@ -1499,4 +1488,4 @@ class BeamRNNTInferConfig: ngram_lm_model: Optional[str] = None ngram_lm_alpha: Optional[float] = 0.0 hat_subtract_ilm: bool = False - hat_ilm_weight: float = 0.0 + hat_ilm_weight: float = 0.0 \ No newline at end of file From c11d09f6c02d2ab848f3d3f84375b7fa97fb82da Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 18 Sep 2024 19:07:12 +0400 Subject: [PATCH 02/13] fixes im base class Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt_abstract.py | 33 +++++++++---------- .../parts/submodules/rnnt_beam_decoding.py | 8 ++--- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index 707b0b358194..899024953d53 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -226,7 +226,7 @@ def score_hypothesis( raise NotImplementedError() def batch_score_hypothesis( - self, hypotheses: List[Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor] + self, hypotheses: List[Hypothesis], cache: Dict[Tuple[int], Any] ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: """ Used for batched beam search algorithms. Similar to score_hypothesis method. @@ -234,30 +234,29 @@ def batch_score_hypothesis( Args: hypothesis: List of Hypotheses. Refer to rnnt_utils.Hypothesis. cache: Dict which contains a cache to avoid duplicate computations. - batch_states: List of torch.Tensor which represent the states of the RNN for this batch. - Each state is of shape [L, B, H] Returns: - Returns a tuple (b_y, b_states, lm_tokens) such that: - b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses. - b_state is a list of list of RNN states, each of shape [L, B, H]. - Represented as B x List[states]. - lm_token is a list of the final integer tokens of the hypotheses in the batch. + Returns a tuple (batch_dec_out, batch_dec_states) such that: + batch_dec_out: a list of torch.Tensor [1, H] representing the prediction network outputs for the last tokens in the Hypotheses. + batch_dec_states: a list of list of RNN states, each of shape [L, B, H]. Represented as B x List[states]. """ raise NotImplementedError() def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]): """ - Create batch of decoder states. + Creates a stacked decoder states to be passed to prediction network - Args: - decoder_states (list of list): list of decoder states - [B x ([L x (1, H)], [L x (1, H)])] - - Returns: - batch_states (tuple): batch of decoder states - ([L x (B, H)], [L x (B, H)]) - """ + Args: + decoder_states (list of list of list of torch.Tensor): list of decoder states + [B, L, 1, H] + - B: Batch size. + - L: Number of layers in prediction RNN (e.g., for LSTM, this is 2: hidden and cell states). + - H: Dimensionality of the hidden state. 
+ + Returns: + batch_states (list of torch.Tensor): batch of decoder states + [L x torch.Tensor[1 x B x H] + """ raise NotImplementedError() def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]: diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index 5148d99fad71..b31894fae25c 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -920,12 +920,8 @@ def align_length_sync_decoding( sub_batch_ids.remove(id) # extract the states of the sub batch only. - if isinstance(self.decoder, RNNTDecoder): - # LSTM decoder, state is [layer x batch x hidden] - beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids) - elif isinstance(self.decoder, StatelessTransducerDecoder): - # stateless decoder, state is [batch x hidden] - beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids) + if isinstance(self.decoder, RNNTDecoder) or isinstance(self.decoder, StatelessTransducerDecoder): + beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids)s else: raise NotImplementedError("Unknown decoder type.") From 271e2a6b0c32a809c758618812daeb0be797c28d Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 18 Sep 2024 19:47:11 +0400 Subject: [PATCH 03/13] clean up Signed-off-by: lilithgrigoryan --- nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index b31894fae25c..d8b352a309e6 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -921,7 +921,7 @@ def align_length_sync_decoding( # extract the states of the sub batch only. if isinstance(self.decoder, RNNTDecoder) or isinstance(self.decoder, StatelessTransducerDecoder): - beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids)s + beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids) else: raise NotImplementedError("Unknown decoder type.") From 51074bf96d8d9e31d8cc2bb4e6604f0e72db532d Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 18 Sep 2024 15:47:48 +0000 Subject: [PATCH 04/13] Apply isort and black reformatting Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 64 +++++++++++-------- nemo/collections/asr/modules/rnnt_abstract.py | 15 +++-- .../parts/submodules/rnnt_beam_decoding.py | 9 +-- 3 files changed, 52 insertions(+), 36 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 1e5afff8bf33..305ca3788809 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -74,8 +74,7 @@ class StatelessTransducerDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "targets": NeuralType(('B', 'T'), LabelsType()), "target_length": NeuralType(tuple('B'), LengthsType()), @@ -84,8 +83,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. 
- """ + """Returns definitions of module output ports.""" return { "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), "prednet_lengths": NeuralType(tuple('B'), LengthsType()), @@ -316,17 +314,17 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: torch.full([batch, self.context_size - 1], fill_value=self.blank_idx, dtype=torch.long, device=y.device) ] return state - + def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]): """ Creates a stacked decoder states to be passed to prediction network. Args: decoder_states (list of list of torch.Tensor): list of decoder states - [B, 1, C] + [B, 1, C] - B: Batch size. - C: Dimensionality of the hidden state. - + Returns: batch_states (list of torch.Tensor): batch of decoder states [[B x C]] """ @@ -380,7 +378,10 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[to @classmethod def batch_replace_states_mask( - cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor, + cls, + src_states: list[torch.Tensor], + dst_states: list[torch.Tensor], + mask: torch.Tensor, ): """Replace states in dst_states with states from src_states using the mask""" # same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking @@ -388,7 +389,9 @@ def batch_replace_states_mask( @classmethod def batch_replace_states_all( - cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], + cls, + src_states: list[torch.Tensor], + dst_states: list[torch.Tensor], ): """Replace states in dst_states with states from src_states""" dst_states[0].copy_(src_states[0]) @@ -447,7 +450,9 @@ def mask_select_states( return [states[0][mask]] def batch_score_hypothesis( - self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], + self, + hypotheses: List[rnnt_utils.Hypothesis], + cache: Dict[Tuple[int], Any], ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: """ Used for batched beam search algorithms. Similar to score_hypothesis method. @@ -506,9 +511,10 @@ def batch_score_hypothesis( cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state) processed_idx += 1 - + return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] + class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMixin): """A Recurrent Neural Network Transducer Decoder / Prediction Network (RNN-T Prediction Network). An RNN-T Decoder/Prediction network, comprised of a stateful LSTM model. @@ -569,8 +575,7 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMi @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "targets": NeuralType(('B', 'T'), LabelsType()), "target_length": NeuralType(tuple('B'), LengthsType()), @@ -579,8 +584,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), "prednet_lengths": NeuralType(tuple('B'), LengthsType()), @@ -912,7 +916,9 @@ def score_hypothesis( return y, new_state, lm_token def batch_score_hypothesis( - self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], + self, + hypotheses: List[rnnt_utils.Hypothesis], + cache: Dict[Tuple[int], Any], ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: """ Used for batched beam search algorithms. 
Similar to score_hypothesis method. @@ -971,7 +977,7 @@ def batch_score_hypothesis( cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state) processed_idx += 1 - + return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: @@ -980,11 +986,11 @@ def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]) -> List[t Args: decoder_states (list of list of list of torch.Tensor): list of decoder states - [B, L, 1, H] + [B, L, 1, H] - B: Batch size. - L: Number of layers in prediction RNN (e.g., for LSTM, this is 2: hidden and cell states). - H: Dimensionality of the hidden state. - + Returns: batch_states (list of torch.Tensor): batch of decoder states [L x torch.Tensor[1 x B x H] @@ -1012,7 +1018,7 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List # print("###", len(batch_states), batch_states[0].shape, self.pred_rnn_layers) if batch_states is not None: return [state[:, idx] for state in batch_states] - + return None def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: @@ -1057,7 +1063,9 @@ def batch_replace_states_mask( @classmethod def batch_replace_states_all( - cls, src_states: Tuple[torch.Tensor, torch.Tensor], dst_states: Tuple[torch.Tensor, torch.Tensor], + cls, + src_states: Tuple[torch.Tensor, torch.Tensor], + dst_states: Tuple[torch.Tensor, torch.Tensor], ): """Replace states in dst_states with states from src_states""" dst_states[0].copy_(src_states[0]) @@ -1201,8 +1209,7 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin) @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), @@ -1214,8 +1221,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" if not self._fuse_loss_wer: return { "outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), @@ -1995,7 +2001,11 @@ def forward( return losses, wer, wer_num, wer_denom def sampled_joint( - self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor, + self, + f: torch.Tensor, + g: torch.Tensor, + transcript: torch.Tensor, + transcript_lengths: torch.Tensor, ) -> torch.Tensor: """ Compute the sampled joint step of the network. @@ -2179,4 +2189,4 @@ def sampled_joint( # Add the adapter compatible modules to the registry for cls in [RNNTDecoder, RNNTJoint, SampledRNNTJoint]: if adapter_mixins.get_registered_adapter(cls) is None: - adapter_mixins.register_adapter(cls, cls) # base class is adapter compatible itself \ No newline at end of file + adapter_mixins.register_adapter(cls, cls) # base class is adapter compatible itself diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index 899024953d53..2dbeb58ded90 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -248,11 +248,11 @@ def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]): Args: decoder_states (list of list of list of torch.Tensor): list of decoder states - [B, L, 1, H] + [B, L, 1, H] - B: Batch size. 
- L: Number of layers in prediction RNN (e.g., for LSTM, this is 2: hidden and cell states). - H: Dimensionality of the hidden state. - + Returns: batch_states (list of torch.Tensor): batch of decoder states [L x torch.Tensor[1 x B x H] @@ -276,14 +276,19 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List @classmethod def batch_replace_states_mask( - cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor, + cls, + src_states: list[torch.Tensor], + dst_states: list[torch.Tensor], + mask: torch.Tensor, ): """Replace states in dst_states with states from src_states using the mask, in a way that does not synchronize with the CPU""" raise NotImplementedError() @classmethod def batch_replace_states_all( - cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], + cls, + src_states: list[torch.Tensor], + dst_states: list[torch.Tensor], ): """Replace states in dst_states with states from src_states""" raise NotImplementedError() @@ -316,7 +321,7 @@ def batch_copy_states( value: Optional[float] = None, ) -> List[torch.Tensor]: """Copy states from new state to old state at certain indices. - + Args: old_states(list): packed decoder states (L x B x H, L x B x H) diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index d8b352a309e6..c01f2363db75 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -921,7 +921,7 @@ def align_length_sync_decoding( # extract the states of the sub batch only. if isinstance(self.decoder, RNNTDecoder) or isinstance(self.decoder, StatelessTransducerDecoder): - beam_state_= (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids) + beam_state_ = (beam_state[sub_batch_id] for sub_batch_id in sub_batch_ids) else: raise NotImplementedError("Unknown decoder type.") @@ -939,10 +939,10 @@ def align_length_sync_decoding( # These indices are specified by sub_batch_ids, the ids of samples which were updated. 
if isinstance(self.decoder, RNNTDecoder) or isinstance(self.decoder, StatelessTransducerDecoder): # LSTM decoder, state is [layer x batch x hidden] - index=0 + index = 0 for sub_batch_id in sub_batch_ids: beam_state[sub_batch_id] = beam_state_[index] - index+=1 + index += 1 else: raise NotImplementedError("Unknown decoder type.") else: @@ -1464,6 +1464,7 @@ def set_decoding_type(self, decoding_type: str): self.token_offset = DEFAULT_TOKEN_OFFSET + @dataclass class BeamRNNTInferConfig: beam_size: int @@ -1484,4 +1485,4 @@ class BeamRNNTInferConfig: ngram_lm_model: Optional[str] = None ngram_lm_alpha: Optional[float] = 0.0 hat_subtract_ilm: bool = False - hat_ilm_weight: float = 0.0 \ No newline at end of file + hat_ilm_weight: float = 0.0 From 05380d0149cc4efc7f634fb6b184a28c64b9b917 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 18 Sep 2024 20:01:56 +0400 Subject: [PATCH 05/13] remove potentially uninitialized local variable Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 1e5afff8bf33..12b314aa89cf 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -490,20 +490,20 @@ def batch_score_hypothesis( tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) dec_states = self.batch_stack_states([d_state for _, d_state in to_process]) - dec_out, dec_states = self.predict( + dec_outputs, dec_states = self.predict( tokens, state=dec_states, add_sos=False, batch_size=batch ) # [B, 1, H], B x List([L, 1, H]) # Update final states and cache shared by entire batch. processed_idx = 0 for final_idx in range(final_batch): - if final[final_idx] is None: + if to_process and final[final_idx] is None: # Select sample's state from the batch state list new_state = self.batch_select_state(dec_states, processed_idx) # Cache [1, H] scores of the current y_j, and its corresponding state - final[final_idx] = (dec_out[processed_idx], new_state) - cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state) + final[final_idx] = (dec_outputs[processed_idx], new_state) + cache[to_process[processed_idx][0]] = (dec_outputs[processed_idx], new_state) processed_idx += 1 From c4404c77c84d242996533338cb34f40728648bf2 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 9 Oct 2024 09:44:29 +0400 Subject: [PATCH 06/13] restore batch_intilize states funcname Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 8 ++++---- nemo/collections/asr/modules/rnnt_abstract.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index f7c80034b7f2..fb1943c3b4c3 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -315,7 +315,7 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: ] return state - def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]): + def batch_initilize_states(self, decoder_states: List[List[torch.Tensor]]): """ Creates a stacked decoder states to be passed to prediction network. @@ -493,7 +493,7 @@ def batch_score_hypothesis( # convert list of tokens to torch.Tensor, then reshape. 
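# e.g., last tokens [5, 2, 9] become tensor([[5], [2], [9]]) of shape [B, 1] = [3, 1]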
tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) - dec_states = self.batch_stack_states([d_state for _, d_state in to_process]) + dec_states = self.batch_initilize_states([d_state for _, d_state in to_process]) dec_outputs, dec_states = self.predict( tokens, state=dec_states, add_sos=False, batch_size=batch @@ -959,7 +959,7 @@ def batch_score_hypothesis( # convert list of tokens to torch.Tensor, then reshape. tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) - dec_states = self.batch_stack_states([d_state for _, d_state in to_process]) + dec_states = self.batch_initilize_states([d_state for _, d_state in to_process]) dec_out, dec_states = self.predict( tokens, state=dec_states, add_sos=False, batch_size=batch @@ -980,7 +980,7 @@ def batch_score_hypothesis( return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] - def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: + def batch_initilize_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: """ Creates a stacked decoder states to be passed to prediction network diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index 2dbeb58ded90..7dba3368d026 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -242,7 +242,7 @@ def batch_score_hypothesis( """ raise NotImplementedError() - def batch_stack_states(self, decoder_states: List[List[torch.Tensor]]): + def batch_initilize_states(self, decoder_states: List[List[torch.Tensor]]): """ Creates a stacked decoder states to be passed to prediction network From 7b9e70040e0896930813e455b633aa2cde11e543 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 9 Oct 2024 09:46:57 +0400 Subject: [PATCH 07/13] fix typo Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 8 ++++---- nemo/collections/asr/modules/rnnt_abstract.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index fb1943c3b4c3..840b8993558e 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -315,7 +315,7 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: ] return state - def batch_initilize_states(self, decoder_states: List[List[torch.Tensor]]): + def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]): """ Creates a stacked decoder states to be passed to prediction network. @@ -493,7 +493,7 @@ def batch_score_hypothesis( # convert list of tokens to torch.Tensor, then reshape. tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) - dec_states = self.batch_initilize_states([d_state for _, d_state in to_process]) + dec_states = self.batch_initialize_states([d_state for _, d_state in to_process]) dec_outputs, dec_states = self.predict( tokens, state=dec_states, add_sos=False, batch_size=batch @@ -959,7 +959,7 @@ def batch_score_hypothesis( # convert list of tokens to torch.Tensor, then reshape. 
tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1) - dec_states = self.batch_initilize_states([d_state for _, d_state in to_process]) + dec_states = self.batch_initialize_states([d_state for _, d_state in to_process]) dec_out, dec_states = self.predict( tokens, state=dec_states, add_sos=False, batch_size=batch @@ -980,7 +980,7 @@ def batch_score_hypothesis( return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] - def batch_initilize_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: + def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]: """ Creates a stacked decoder states to be passed to prediction network diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index 7dba3368d026..fc2e20d1302b 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -242,7 +242,7 @@ def batch_score_hypothesis( """ raise NotImplementedError() - def batch_initilize_states(self, decoder_states: List[List[torch.Tensor]]): + def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]): """ Creates a stacked decoder states to be passed to prediction network From 22d861fac4e39cfffbfce3939b3a085d69ceb935 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 10 Oct 2024 16:07:49 +0400 Subject: [PATCH 08/13] fix potentially uninitialized local variable Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 840b8993558e..345c24d7a274 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -499,18 +499,18 @@ def batch_score_hypothesis( tokens, state=dec_states, add_sos=False, batch_size=batch ) # [B, 1, H], B x List([L, 1, H]) - # Update final states and cache shared by entire batch. - processed_idx = 0 - for final_idx in range(final_batch): - if to_process and final[final_idx] is None: - # Select sample's state from the batch state list - new_state = self.batch_select_state(dec_states, processed_idx) - - # Cache [1, H] scores of the current y_j, and its corresponding state - final[final_idx] = (dec_outputs[processed_idx], new_state) - cache[to_process[processed_idx][0]] = (dec_outputs[processed_idx], new_state) - - processed_idx += 1 + # Update final states and cache shared by entire batch. 
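# (hypotheses served from the cache already hold their (output, state) pair in
# `final`; only the freshly decoded entries are filled in below)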
+ processed_idx = 0 + for final_idx in range(final_batch): + if to_process and final[final_idx] is None: + # Select sample's state from the batch state list + new_state = self.batch_select_state(dec_states, processed_idx) + + # Cache [1, H] scores of the current y_j, and its corresponding state + final[final_idx] = (dec_outputs[processed_idx], new_state) + cache[to_process[processed_idx][0]] = (dec_outputs[processed_idx], new_state) + + processed_idx += 1 return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] From b4ed45a37e723bd6103b7c41e2bcd4bfe204f501 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 10 Oct 2024 16:18:04 +0400 Subject: [PATCH 09/13] fix potentially uninitialized local variable in stateless transduser Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 345c24d7a274..b3e4d9fcf63c 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -965,18 +965,18 @@ def batch_score_hypothesis( tokens, state=dec_states, add_sos=False, batch_size=batch ) # [B, 1, H], B x List([L, 1, H]) - # Update final states and cache shared by entire batch. - processed_idx = 0 - for final_idx in range(final_batch): - if final[final_idx] is None: - # Select sample's state from the batch state list - new_state = self.batch_select_state(dec_states, processed_idx) - - # Cache [1, H] scores of the current y_j, and its corresponding state - final[final_idx] = (dec_out[processed_idx], new_state) - cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state) - - processed_idx += 1 + # Update final states and cache shared by entire batch. 
+ processed_idx = 0 + for final_idx in range(final_batch): + if final[final_idx] is None: + # Select sample's state from the batch state list + new_state = self.batch_select_state(dec_states, processed_idx) + + # Cache [1, H] scores of the current y_j, and its corresponding state + final[final_idx] = (dec_out[processed_idx], new_state) + cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state) + + processed_idx += 1 return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final] From 710aeac13d000dc7a5e99a95006446430b342271 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 11 Oct 2024 12:24:00 +0400 Subject: [PATCH 10/13] fix test Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index b3e4d9fcf63c..5825e48a78e5 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1037,7 +1037,7 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[to for state_id in range(len(batch_states[0])): batch_list = [] for sample_id in range(len(batch_states)): - tensor = torch.stack(batch_states[sample_id][state_id]) # [L, H] + tensor = torch.stack(batch_states[sample_id][state_id]) if not isinstance(batch_states[sample_id][state_id], torch.Tensor) else batch_states[sample_id][state_id] # [L, H] tensor = tensor.unsqueeze(0) # [1, L, H] batch_list.append(tensor) From f3a10c495a2b6dfa13d14aff74bdeb8d7b9c86fb Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 11 Oct 2024 08:25:48 +0000 Subject: [PATCH 11/13] Apply isort and black reformatting Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 5825e48a78e5..297ca867639f 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1037,7 +1037,11 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[to for state_id in range(len(batch_states[0])): batch_list = [] for sample_id in range(len(batch_states)): - tensor = torch.stack(batch_states[sample_id][state_id]) if not isinstance(batch_states[sample_id][state_id], torch.Tensor) else batch_states[sample_id][state_id] # [L, H] + tensor = ( + torch.stack(batch_states[sample_id][state_id]) + if not isinstance(batch_states[sample_id][state_id], torch.Tensor) + else batch_states[sample_id][state_id] + ) # [L, H] tensor = tensor.unsqueeze(0) # [1, L, H] batch_list.append(tensor) From 098a1c2c6441af06fb383b53b18fb243962e31b1 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 11 Oct 2024 21:59:37 +0400 Subject: [PATCH 12/13] fix docstring, rm comment Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 12 ++++++------ nemo/collections/asr/modules/rnnt_abstract.py | 7 ++++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 5825e48a78e5..a58c6c9fc4d0 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -986,17 +986,18 @@ def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]) -> L Args: decoder_states (list of list of list of torch.Tensor): list of decoder states - [B, L, 1, H] + [B, layer, L, H] - B: Batch size. 
- - L: Number of layers in prediction RNN (e.g., for LSTM, this is 2: hidden and cell states). + - layer: e.g., for LSTM, this is 2: hidden and cell states + - L: Number of layers in prediction RNN. - H: Dimensionality of the hidden state. Returns: batch_states (list of torch.Tensor): batch of decoder states - [L x torch.Tensor[1 x B x H] + [layer x torch.Tensor[L x B x H] """ - # stack decoder states into tensor of shape [B x L x 1 x H] - # permute to the target shape [L x 1 x B x H] + # stack decoder states into tensor of shape [B x layers x L x H] + # permute to the target shape [layers x L x B x H] stacked_states = torch.stack([torch.stack(decoder_state) for decoder_state in decoder_states]) permuted_states = stacked_states.permute(1, 2, 0, 3) @@ -1015,7 +1016,6 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List (tuple): decoder states for given id ([L x (1, H)], [L x (1, H)]) """ - # print("###", len(batch_states), batch_states[0].shape, self.pred_rnn_layers) if batch_states is not None: return [state[:, idx] for state in batch_states] diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py index fc2e20d1302b..deaf00e92001 100644 --- a/nemo/collections/asr/modules/rnnt_abstract.py +++ b/nemo/collections/asr/modules/rnnt_abstract.py @@ -248,14 +248,15 @@ def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]): Args: decoder_states (list of list of list of torch.Tensor): list of decoder states - [B, L, 1, H] + [B, layer, L, H] - B: Batch size. - - L: Number of layers in prediction RNN (e.g., for LSTM, this is 2: hidden and cell states). + - layer: e.g., for LSTM, this is 2: hidden and cell states + - L: Number of layers in prediction RNN. - H: Dimensionality of the hidden state. Returns: batch_states (list of torch.Tensor): batch of decoder states - [L x torch.Tensor[1 x B x H] + [layer x torch.Tensor[L x B x H] """ raise NotImplementedError() From 0917cebdf5024e2d2a97e5fe71251cb9566b93c5 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 11 Oct 2024 22:31:00 +0400 Subject: [PATCH 13/13] fix dosctrings Signed-off-by: lilithgrigoryan --- nemo/collections/asr/modules/rnnt.py | 6 +++--- nemo/collections/asr/modules/rnnt_abstract.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index a58c6c9fc4d0..7be4b2e1916f 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -986,15 +986,15 @@ def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]) -> L Args: decoder_states (list of list of list of torch.Tensor): list of decoder states - [B, layer, L, H] + [B, C, L, H] - B: Batch size. - - layer: e.g., for LSTM, this is 2: hidden and cell states + - C: e.g., for LSTM, this is 2: hidden and cell states - L: Number of layers in prediction RNN. - H: Dimensionality of the hidden state. 

        Returns:
            batch_states (list of torch.Tensor): batch of decoder states
-                [layer x torch.Tensor[L x B x H]
+                [C x torch.Tensor[L x B x H]]
        """
        # stack decoder states into tensor of shape [B x layers x L x H]
        # permute to the target shape [layers x L x B x H]
diff --git a/nemo/collections/asr/modules/rnnt_abstract.py b/nemo/collections/asr/modules/rnnt_abstract.py
index deaf00e92001..c895fc6deaf1 100644
--- a/nemo/collections/asr/modules/rnnt_abstract.py
+++ b/nemo/collections/asr/modules/rnnt_abstract.py
@@ -248,15 +248,15 @@ def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]):

        Args:
            decoder_states (list of list of list of torch.Tensor): list of decoder states
-                [B, layer, L, H]
+                [B, C, L, H]
                - B: Batch size.
-                - layer: e.g., for LSTM, this is 2: hidden and cell states
+                - C: e.g., for LSTM, this is 2: hidden and cell states
                - L: Number of layers in prediction RNN.
                - H: Dimensionality of the hidden state.

        Returns:
            batch_states (list of torch.Tensor): batch of decoder states
-                [layer x torch.Tensor[L x B x H]
+                [C x torch.Tensor[L x B x H]]
        """
        raise NotImplementedError()
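To close out the series, here is a self-contained sketch of the caching pattern that `batch_score_hypothesis` relies on throughout these patches: token sequences already scored are served from `cache`, and only the remainder is batched through the prediction network. The predictor is mocked and all sizes are invented for illustration.

```python
import torch

H = 8       # hidden size, invented for illustration
cache = {}  # maps token-sequence tuples to prediction-network outputs

def mock_predict(tokens: torch.Tensor) -> torch.Tensor:
    """Stand-in for self.predict: one [1, H] output per batch row."""
    return torch.randn(tokens.shape[0], 1, H)

def score(sequences):
    final, tokens, to_process = [None] * len(sequences), [], []
    for i, seq in enumerate(sequences):
        if seq in cache:              # cache hit: reuse the earlier output
            final[i] = cache[seq]
        else:                         # cache miss: batch the last token for the predictor
            tokens.append(seq[-1])
            to_process.append((i, seq))
    if to_process:
        outputs = mock_predict(torch.tensor(tokens).view(len(tokens), -1))
        for row, (i, seq) in enumerate(to_process):
            final[i] = cache[seq] = outputs[row]
    return final

first = score([(1, 2), (3,)])
second = score([(1, 2), (1, 2, 5)])     # (1, 2) is now served from the cache
assert torch.equal(first[0], second[0])
```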