Skip to content

Commit

Permalink
[ASR] Fix GPU memory leak in transcribe_speech.py (#7249)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <rlangman@nvidia.com>
  • Loading branch information
rlangman authored Aug 17, 2023
1 parent 2a5ecce commit b7c8ef1
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,13 +421,18 @@ def transcribe_partial_audio(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
)
logits, logits_len = outputs[0], outputs[1]

if isinstance(asr_model, EncDecHybridRNNTCTCModel) and decoder_type == "ctc":
logits = asr_model.ctc_decoder(encoder_output=logits)

logits = logits.cpu()

if logprobs:
logits = logits.numpy()
# dump log probs per file
for idx in range(logits.shape[0]):
lg = logits[idx][: logits_len[idx]]
hypotheses.append(lg.cpu().numpy())
hypotheses.append(lg)
else:
current_hypotheses, all_hyp = decode_function(logits, logits_len, return_hypotheses=return_hypotheses,)

Expand Down

0 comments on commit b7c8ef1

Please sign in to comment.