remove stacking operation from batched functions (#10524)
* remove stacking operations

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* fixes in base class

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* clean up

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <lilithgrigoryan@users.noreply.github.com>

* remove potentially uninitialized local variable

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* restore batch_initialize_states funcname

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* fix typo

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* fix potentially uninitialized local variable

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* fix potentially uninitialized local variable
in stateless transducer

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* fix test

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <lilithgrigoryan@users.noreply.github.com>

* fix docstring, rm comment

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

* fix docstrings

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>

---------

Signed-off-by: lilithgrigoryan <lgrigoryan@nvidia.com>
Signed-off-by: lilithgrigoryan <lilithgrigoryan@users.noreply.github.com>
Co-authored-by: lilithgrigoryan <lgrigoryan@nvidia.com>
Co-authored-by: lilithgrigoryan <lilithgrigoryan@users.noreply.github.com>
3 people authored and Yashaswi Karnati committed Oct 20, 2024
1 parent ec4c840 commit 9a9260a
Showing 3 changed files with 139 additions and 194 deletions.
215 changes: 86 additions & 129 deletions nemo/collections/asr/modules/rnnt.py
@@ -315,20 +315,18 @@ def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
]
return state

def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]):
"""
Create batch of decoder states.
Creates stacked decoder states to be passed to the prediction network.
Args:
batch_states (list): batch of decoder states
([(B, H)])
decoder_states (list of list): list of decoder states
[B x ([(1, C)]]
decoder_states (list of list of torch.Tensor): list of decoder states
[B, 1, C]
- B: Batch size.
- C: Dimensionality of the hidden state.
Returns:
batch_states (tuple): batch of decoder states
([(B, C)])
batch_states (list of torch.Tensor): batch of decoder states [[B x C]]
"""
new_state = torch.stack([s[0] for s in decoder_states])
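
A minimal sketch of the stacking that now happens inside the stateless decoder's batch_initialize_states; the shapes (B hypotheses, each carrying a single (1, C) state tensor) are assumed from the docstring above and the concrete sizes are illustrative only.

```python
import torch

# Sketch only: B per-hypothesis states are stacked into one batch tensor.
B, C = 4, 8
decoder_states = [[torch.zeros(1, C)] for _ in range(B)]   # B x [(1, C)]
new_state = torch.stack([s[0] for s in decoder_states])    # (B, 1, C)
assert new_state.shape == (B, 1, C)
```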

@@ -452,86 +450,69 @@ def mask_select_states(
return [states[0][mask]]

def batch_score_hypothesis(
self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
self,
hypotheses: List[rnnt_utils.Hypothesis],
cache: Dict[Tuple[int], Any],
) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
"""
Used for batched beam search algorithms. Similar to score_hypothesis method.
Args:
hypothesis: List of Hypotheses. Refer to rnnt_utils.Hypothesis.
cache: Dict which contains a cache to avoid duplicate computations.
batch_states: List of torch.Tensor which represent the states of the RNN for this batch.
Each state is of shape [L, B, H]
Returns:
Returns a tuple (b_y, b_states, lm_tokens) such that:
b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses.
b_state is a list of list of RNN states, each of shape [L, B, H].
Represented as B x List[states].
lm_token is a list of the final integer tokens of the hypotheses in the batch.
Returns a tuple (batch_dec_out, batch_dec_states) such that:
batch_dec_out: a list of torch.Tensor [1, H] representing the prediction network outputs for the last tokens in the Hypotheses.
batch_dec_states: a list of list of RNN states, each of shape [L, B, H]. Represented as B x List[states].
"""
final_batch = len(hypotheses)

if final_batch == 0:
raise ValueError("No hypotheses was provided for the batch!")

_p = next(self.parameters())
device = _p.device
dtype = _p.dtype

tokens = []
process = []
done = [None for _ in range(final_batch)]
to_process = []
final = [None for _ in range(final_batch)]

# For each hypothesis, cache the last token of the sequence and the current states
for i, hyp in enumerate(hypotheses):
for final_idx, hyp in enumerate(hypotheses):
sequence = tuple(hyp.y_sequence)

if sequence in cache:
done[i] = cache[sequence]
final[final_idx] = cache[sequence]
else:
tokens.append(hyp.y_sequence[-1])
process.append((sequence, hyp.dec_state))
to_process.append((sequence, hyp.dec_state))

if process:
batch = len(process)
if to_process:
batch = len(to_process)

# convert list of tokens to torch.Tensor, then reshape.
tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1)
dec_states = self.initialize_state(tokens) # [B, C]
dec_states = self.batch_initialize_states(dec_states, [d_state for seq, d_state in process])
dec_states = self.batch_initialize_states([d_state for _, d_state in to_process])

y, dec_states = self.predict(
dec_outputs, dec_states = self.predict(
tokens, state=dec_states, add_sos=False, batch_size=batch
) # [B, 1, H], List([L, 1, H])

dec_states = tuple(state.to(dtype=dtype) for state in dec_states)
) # [B, 1, H], B x List([L, 1, H])

# Update done states and cache shared by entire batch.
j = 0
for i in range(final_batch):
if done[i] is None:
# Select sample's state from the batch state list
new_state = self.batch_select_state(dec_states, j)
# Update final states and cache shared by entire batch.
processed_idx = 0
for final_idx in range(final_batch):
if to_process and final[final_idx] is None:
# Select sample's state from the batch state list
new_state = self.batch_select_state(dec_states, processed_idx)

# Cache [1, H] scores of the current y_j, and its corresponding state
done[i] = (y[j], new_state)
cache[process[j][0]] = (y[j], new_state)
# Cache [1, H] scores of the current y_j, and its corresponding state
final[final_idx] = (dec_outputs[processed_idx], new_state)
cache[to_process[processed_idx][0]] = (dec_outputs[processed_idx], new_state)

j += 1
processed_idx += 1

# Set the incoming batch states with the new states obtained from `done`.
batch_states = self.batch_initialize_states(batch_states, [d_state for y_j, d_state in done])

# Create batch of all output scores
# List[1, 1, H] -> [B, 1, H]
batch_y = torch.stack([y_j for y_j, d_state in done])

# Extract the last tokens from all hypotheses and convert to a tensor
lm_tokens = torch.tensor([h.y_sequence[-1] for h in hypotheses], device=device, dtype=torch.long).view(
final_batch
)

return batch_y, batch_states, lm_tokens
return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final]
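
The batched scoring above keys a cache on each hypothesis' full token sequence so duplicate prefixes are not re-scored. Below is a self-contained sketch of that bookkeeping, assuming nothing from the committed code beyond the pattern itself: a random tensor stands in for the prediction network, and the helper name is invented for the illustration.

```python
import torch

def batch_score(sequences, cache):
    # Serve already-scored sequences from the cache; batch the rest once.
    final, to_process = [None] * len(sequences), []
    for i, seq in enumerate(sequences):
        key = tuple(seq)
        if key in cache:
            final[i] = cache[key]
        else:
            to_process.append((i, key))
    if to_process:
        outs = torch.randn(len(to_process), 1, 8)  # stand-in for one batched predict()
        for j, (i, key) in enumerate(to_process):
            final[i] = cache[key] = outs[j]
    return final

cache = {}
first = batch_score([[1, 2], [3], [1, 2]], cache)
second = batch_score([[1, 2]], cache)              # served entirely from the cache
assert len(first) == 3 and len(cache) == 2 and second[0] is cache[(1, 2)]
```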


class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMixin):
@@ -935,23 +916,21 @@ def score_hypothesis(
return y, new_state, lm_token

def batch_score_hypothesis(
self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
self,
hypotheses: List[rnnt_utils.Hypothesis],
cache: Dict[Tuple[int], Any],
) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
"""
Used for batched beam search algorithms. Similar to score_hypothesis method.
Args:
hypothesis: List of Hypotheses. Refer to rnnt_utils.Hypothesis.
cache: Dict which contains a cache to avoid duplicate computations.
batch_states: List of torch.Tensor which represent the states of the RNN for this batch.
Each state is of shape [L, B, H]
Returns:
Returns a tuple (b_y, b_states, lm_tokens) such that:
b_y is a torch.Tensor of shape [B, 1, H] representing the scores of the last tokens in the Hypotheses.
b_state is a list of list of RNN states, each of shape [L, B, H].
Represented as B x List[states].
lm_token is a list of the final integer tokens of the hypotheses in the batch.
Returns a tuple (batch_dec_out, batch_dec_states) such that:
batch_dec_out: a list of torch.Tensor [1, H] representing the prediction network outputs for the last tokens in the Hypotheses.
batch_dec_states: a list of list of RNN states, each of shape [L, B, H]. Represented as B x List[states].
"""
final_batch = len(hypotheses)

@@ -960,90 +939,69 @@ def batch_score_hypothesis(

_p = next(self.parameters())
device = _p.device
dtype = _p.dtype

tokens = []
process = []
done = [None for _ in range(final_batch)]
to_process = []
final = [None for _ in range(final_batch)]

# For each hypothesis, cache the last token of the sequence and the current states
for i, hyp in enumerate(hypotheses):
for final_idx, hyp in enumerate(hypotheses):
sequence = tuple(hyp.y_sequence)

if sequence in cache:
done[i] = cache[sequence]
final[final_idx] = cache[sequence]
else:
tokens.append(hyp.y_sequence[-1])
process.append((sequence, hyp.dec_state))
to_process.append((sequence, hyp.dec_state))

if process:
batch = len(process)
if to_process:
batch = len(to_process)

# convert list of tokens to torch.Tensor, then reshape.
tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1)
dec_states = self.initialize_state(tokens.to(dtype=dtype)) # [L, B, H]
dec_states = self.batch_initialize_states(dec_states, [d_state for seq, d_state in process])
dec_states = self.batch_initialize_states([d_state for _, d_state in to_process])

y, dec_states = self.predict(
dec_out, dec_states = self.predict(
tokens, state=dec_states, add_sos=False, batch_size=batch
) # [B, 1, H], List([L, 1, H])

dec_states = tuple(state.to(dtype=dtype) for state in dec_states)

# Update done states and cache shared by entire batch.
j = 0
for i in range(final_batch):
if done[i] is None:
# Select sample's state from the batch state list
new_state = self.batch_select_state(dec_states, j)
) # [B, 1, H], B x List([L, 1, H])

# Cache [1, H] scores of the current y_j, and its corresponding state
done[i] = (y[j], new_state)
cache[process[j][0]] = (y[j], new_state)
# Update final states and cache shared by entire batch.
processed_idx = 0
for final_idx in range(final_batch):
if final[final_idx] is None:
# Select sample's state from the batch state list
new_state = self.batch_select_state(dec_states, processed_idx)

j += 1
# Cache [1, H] scores of the current y_j, and its corresponding state
final[final_idx] = (dec_out[processed_idx], new_state)
cache[to_process[processed_idx][0]] = (dec_out[processed_idx], new_state)

# Set the incoming batch states with the new states obtained from `done`.
batch_states = self.batch_initialize_states(batch_states, [d_state for y_j, d_state in done])

# Create batch of all output scores
# List[1, 1, H] -> [B, 1, H]
batch_y = torch.stack([y_j for y_j, d_state in done])

# Extract the last tokens from all hypotheses and convert to a tensor
lm_tokens = torch.tensor([h.y_sequence[-1] for h in hypotheses], device=device, dtype=torch.long).view(
final_batch
)
processed_idx += 1

return batch_y, batch_states, lm_tokens
return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final]
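
For callers, the visible change is the interface: no batch_states argument, and two aligned lists returned instead of a stacked score tensor plus lm_tokens. A hypothetical usage sketch follows; ToyDecoder and the H=8 size are stand-ins invented for the illustration, not part of the committed code.

```python
import torch

class ToyDecoder:
    # Stand-in with the same return shape as the new batch_score_hypothesis.
    def batch_score_hypothesis(self, seqs, cache):
        outs = [torch.randn(1, 8) for _ in seqs]         # per-hypothesis [1, H]
        states = [[torch.zeros(1, 1, 8)] for _ in seqs]  # per-hypothesis states
        return outs, states

decoder, cache = ToyDecoder(), {}
dec_outs, dec_states = decoder.batch_score_hypothesis([[1, 2], [3]], cache)
for out, state in zip(dec_outs, dec_states):  # pair each hypothesis with its own output/state
    assert out.shape == (1, 8)
```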

def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""
Create batch of decoder states.
Creates stacked decoder states to be passed to the prediction network.
Args:
batch_states (list): batch of decoder states
([L x (B, H)], [L x (B, H)])
decoder_states (list of list): list of decoder states
[B x ([L x (1, H)], [L x (1, H)])]
decoder_states (list of list of list of torch.Tensor): list of decoder states
[B, C, L, H]
- B: Batch size.
- C: e.g., for LSTM, this is 2: hidden and cell states
- L: Number of layers in prediction RNN.
- H: Dimensionality of the hidden state.
Returns:
batch_states (tuple): batch of decoder states
([L x (B, H)], [L x (B, H)])
batch_states (list of torch.Tensor): batch of decoder states
[C x torch.Tensor[L x B x H]]
"""
# LSTM has 2 states
new_states = [[] for _ in range(len(decoder_states[0]))]
for layer in range(self.pred_rnn_layers):
for state_id in range(len(decoder_states[0])):
# batch_states[state_id][layer] = torch.stack([s[state_id][layer] for s in decoder_states])
new_state_for_layer = torch.stack([s[state_id][layer] for s in decoder_states])
new_states[state_id].append(new_state_for_layer)
# stack decoder states into tensor of shape [B x layers x L x H]
# permute to the target shape [layers x L x B x H]
stacked_states = torch.stack([torch.stack(decoder_state) for decoder_state in decoder_states])
permuted_states = stacked_states.permute(1, 2, 0, 3)

for state_id in range(len(decoder_states[0])):
new_states[state_id] = torch.stack([state for state in new_states[state_id]])

return new_states
return list(permuted_states.contiguous())
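
A minimal sketch of the stack-and-permute above for an LSTM-style decoder (two state tensors, hidden and cell); the sizes B, L, H are assumptions chosen only to make the shape transformation visible.

```python
import torch

B, L, H = 4, 2, 8                       # batch, prediction-RNN layers, hidden size
per_hyp = [[torch.zeros(L, H), torch.zeros(L, H)] for _ in range(B)]  # B x 2 x [L, H]
stacked = torch.stack([torch.stack(s) for s in per_hyp])    # [B, 2, L, H]
batched = list(stacked.permute(1, 2, 0, 3).contiguous())    # 2 x [L, B, H]
assert len(batched) == 2 and batched[0].shape == (L, B, H)
```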

def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]:
"""Get decoder state from batch of states, for given id.
@@ -1059,14 +1017,9 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List
([L x (1, H)], [L x (1, H)])
"""
if batch_states is not None:
state_list = []
for state_id in range(len(batch_states)):
states = [batch_states[state_id][layer][idx] for layer in range(self.pred_rnn_layers)]
state_list.append(states)
return [state[:, idx] for state in batch_states]

return state_list
else:
return None
return None
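
Selecting one hypothesis back out of the stacked state is now a plain index into the batch dimension. The sketch below, with assumed sizes, checks that this indexing recovers per-sample [L, H] tensors from the [L, B, H] batch states built above.

```python
import torch

L, B, H = 2, 4, 8
batch_states = [torch.randn(L, B, H), torch.randn(L, B, H)]  # (hidden, cell)
idx = 1
sample = [state[:, idx] for state in batch_states]           # 2 x [L, H]
assert sample[0].shape == (L, H)
```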

def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Concatenate a batch of decoder state to a packed state.
@@ -1084,7 +1037,11 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[to
for state_id in range(len(batch_states[0])):
batch_list = []
for sample_id in range(len(batch_states)):
tensor = torch.stack(batch_states[sample_id][state_id]) # [L, H]
tensor = (
torch.stack(batch_states[sample_id][state_id])
if not isinstance(batch_states[sample_id][state_id], torch.Tensor)
else batch_states[sample_id][state_id]
) # [L, H]
tensor = tensor.unsqueeze(0) # [1, L, H]
batch_list.append(tensor)
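
The conditional stack above handles per-sample states that may arrive either as a list of per-layer tensors or already as a single [L, H] tensor (the form the new batch_select_state returns). A small sketch of that guard, with assumed sizes, is shown below.

```python
import torch

L, H = 2, 8
as_list = [torch.zeros(H) for _ in range(L)]    # list of L per-layer tensors
as_tensor = torch.zeros(L, H)                   # already a single [L, H] tensor
for state in (as_list, as_tensor):
    t = torch.stack(state) if not isinstance(state, torch.Tensor) else state
    assert t.shape == (L, H)
```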
