
Ngram lm fusion for RNNT maes decoding #6118

Merged: 32 commits merged into NVIDIA:main on Mar 14, 2023

Conversation

andrusenkoau (Collaborator)

What does this PR do?

Ngram LM fusion for RNNT modified adaptive expansion search (maes) decoding.

Collection: [ASR]

Changelog

  • Add ngram_lm option to maes decoding algorithm
  • Add EncDecRNNTBPEModel and EncDecRNNTModel to kenlm_utils.py for ngram lm building
  • Add a new file, eval_beamsearch_ngram_transducer.py (based on Vahid's old branch for ngram RNNT beam search), for testing ngram RNNT decoding

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

andrusenkoau and others added 8 commits February 24, 2023 07:21
words_count = 0
chars_count = 0
if preds_output_file:
    out_file = open(preds_output_file, 'w')

Check warning (Code scanning / CodeQL): File is not always closed

File may not be closed if an exception is raised.
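A context manager is the usual fix for this class of warning; a minimal self-contained sketch (the path and the write call are placeholders, not the script's actual logic):

import contextlib

preds_output_file = 'preds.tsv'  # placeholder path; a falsy value disables writing

with contextlib.ExitStack() as stack:
    out_file = None
    if preds_output_file:
        # enter_context guarantees the file is closed when the block exits,
        # even if an exception is raised inside it
        out_file = stack.enter_context(open(preds_output_file, 'w', encoding='utf-8'))
    if out_file is not None:
        out_file.write('example hypothesis\n')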
@titu1994 (Collaborator) left a comment:

Looks overall good, needs minor changes

@@ -1122,6 +1159,8 @@ def modified_adaptive_expansion_search(
                timestep=hyp.timestep[:],
                length=t,
            )
            if self.ngram_lm:
                new_hyp.ngram_lm_state = hyp.ngram_lm_state.__deepcopy__()
Collaborator:

Do you need deepcopy? It's very expensive for large objects.

Collaborator Author:

You are right, deepcopy is redundant here.
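A sketch of the resolved version, assuming the LM state is only read during scoring and never mutated in place (the scoring code below writes into a fresh kenlm.State rather than modifying the old one):

if self.ngram_lm:
    # sharing the reference is enough; scoring produces a new
    # kenlm.State instead of mutating this one, so no deepcopy
    new_hyp.ngram_lm_state = hyp.ngram_lm_state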

@@ -213,6 +227,9 @@ def __init__(
        language_model: Optional[Dict[str, Any]] = None,
        softmax_temperature: float = 1.0,
        preserve_alignments: bool = False,
        ngram_lm_model: Optional[str] = None,
        ngram_lm_alpha: float = 0.0,
        tokens_type: str = "subword",
Collaborator:

Do not require the user to provide the token type. Determine it from the n-gram LM or from the model type if possible.

Collaborator Author:

OK, I redid it as in ctc_beam_decoding.py, via a set_decoding_type function.
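A sketch of that approach, mirroring the snippet from ctc_beam_decoding.py quoted later in this thread (DEFAULT_TOKEN_OFFSET stands in for the hardcoded 100 that the PR later extracts into a constant):

def set_decoding_type(self, decoding_type: str):
    # the token offset is only needed for BPE-based (subword) models;
    # char models score labels without shifting
    if decoding_type == 'subword':
        self.token_offset = DEFAULT_TOKEN_OFFSET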

                hyp_i.score + float(logp[hyp_j.y_sequence[pref_id]]) + self.ngram_lm_alpha * lm_score
            )
        else:
            curr_score = hyp_i.score + float(logp[hyp_j.y_sequence[pref_id]])
Collaborator:

Why not keep the original code and just add an if branch after it to add the LM score?

Collaborator Author:

Sounds reasonable, thanks)
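The suggested shape, sketched with the names from the diffs in this PR (compute_ngram_score and ngram_lm_state both appear in the changes above):

curr_score = hyp_i.score + float(logp[hyp_j.y_sequence[pref_id]])
if self.ngram_lm:
    # add the weighted LM term only when an n-gram LM is attached
    lm_score, next_state = self.compute_ngram_score(
        hyp_i.ngram_lm_state, int(hyp_j.y_sequence[pref_id])
    )
    curr_score += self.ngram_lm_alpha * lm_score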

            lm_score, next_state = self.compute_ngram_score(next_state, int(hyp_j.y_sequence[k + 1]))
            curr_score += float(logp[hyp_j.y_sequence[k + 1]]) + self.ngram_lm_alpha * lm_score
        else:
            curr_score += float(logp[hyp_j.y_sequence[k + 1]])
Collaborator:

Same as above

Collaborator Author:

Agree.

# https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/asr_language_modeling.html


import argparse
Collaborator:

Instead of two different scripts for beam search, would it be possible to use the high-level beam search API of RNNT and merge this script with the one we already have?

Collaborator Author:

I will try to check this possibility

Collaborator Author:

@titu1994, I tried to merge RNNT beam search decoding into eval_beamsearch_ngram_transducer.py, but it looks a bit overloaded due to the different decoding parameters and methods for CTC and RNNT. Furthermore, we would need to add an option to return logits from the rnnt_model.transcribe method. We can stop at the current version with two different files for this PR. I can try to merge them in a separate future PR.

Collaborator:

Transducer does not compute full logits; you can get back the alignment matrix with preserve_alignments and return the hypotheses, so no need for an additional logprobs flag.

I suppose we can keep a separate script for now though. Let me know when the PR is finalized.

Collaborator Author:

I meant the RNNT encoder logits, which are used for asr_model.decoding.rnnt_decoder_predictions_tensor, by analogy with model.decoding.ctc_decoder_predictions_tensor as in the eval_beamsearch_ngram_transducer.py script.

Collaborator:

I get that, but adding an argument to RNNT transcribe just for that changes the signatures of RNNT and CTC. You can use the logprobs flag, but you won't get actual logprobs back, only the encoder output. Please use the hypothesis output for the encoder logits, or repurpose logprobs as the encoder output.
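A sketch of the hypothesis route (return_hypotheses is an existing transcribe flag; exactly which fields are populated depends on the decoding config):

# enable preserve_alignments in the decoding config beforehand
hypotheses = asr_model.transcribe(audio_file_paths, return_hypotheses=True)
for hyp in hypotheses:
    print(hyp.text)  # best transcript
    # hyp.alignments is populated when preserve_alignments is enabled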

Collaborator Author:

Ok, I will try to do it in another PR.
Could you check the current PR and approve it if everything is OK? I finalized all the work.

andrusenkoau and others added 7 commits March 1, 2023 04:16
@VahidooX (Collaborator) left a comment:

Have you used the same formula used in my branch to calculate the lm scores?

# Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need
# TOKEN_OFFSET for BPE-based models
if decoding_type == 'subword':
    self.token_offset = 100
Collaborator:

Let's define this number in one of the files and then import it here.
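For context, a minimal illustration of what the offset does, assuming the encoding scheme from train_kenlm.py, where each BPE token ID is stored in the LM as a single unicode character shifted past the low code points:

DEFAULT_TOKEN_OFFSET = 100  # the hardcoded value from the snippet above

token_id = 42
word = chr(token_id + DEFAULT_TOKEN_OFFSET)  # token ID encoded as one unicode char
assert ord(word) - DEFAULT_TOKEN_OFFSET == token_id  # decodes back to the same ID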

Collaborator Author:

Good suggestion! I imported the token_offset variable from the train_kenlm.py script.

Collaborator Author:

No, importing from train_kenlm.py was a bad idea. Do you have any suggestion where in NeMo I can add a default token_offset value? I could create nemo/collections/asr/parts/submodules/decoder_constants.py with this parameter.

BTW: CTC beam search also has a hardcoded value for token_offset.

@titu1994 (Collaborator), Mar 6, 2023:

We can keep the offset inside the CTC beam decoding file. The offset is not used for RNNT. Import the constant from the NeMo CTC beam search decoding file inside the train_kenlm file.

Collaborator Author:

Ok, I will add the offset value inside ctc_beam_decoding.py. The idea is to use the same n-gram LM model for CTC_bpe and RNNT_bpe models. That is why I want to add the offset for the RNNT_bpe model too.

@andrusenkoau (Collaborator Author):

Have you used the same formula used in my branch to calculate the lm scores?

@VahidooX -- yes, I used your formula except for the ngram_lm_beta scaling, which I removed. I did not find any mention of lm_beta in papers on external LM fusion for RNNT. Some authors use hypothesis length normalization at the end of the beam search algorithm (final_hyp_score / len(hyp_text)). This technique did not work for me.
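For reference, the fusion in the diffs above is plain shallow fusion: per expanded token, curr_score += log P_RNNT(token) + ngram_lm_alpha * log P_LM(token), with no separate beta length or word-insertion term.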

andrusenkoau and others added 4 commits March 9, 2023 04:55
@titu1994 (Collaborator) left a comment:

Minor changes required - let's keep the signatures of the decoding strategies close to each other.

@@ -202,6 +203,9 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec):
            decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id + joint.num_extra_outputs
        )

        if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
            self.decoding.set_decoding_type('subword')
Collaborator:

What about adding the Tokenizer here? And also the vocabulary for char models?

Collaborator Author:

Do we need these for RNNT beam decoding? MAES decoding works with the default integer values of the predicted labels. The token offset is needed only for the KenLM score computation inside a separate function, which does not shift the integer label values in the hypothesis results.

"""

next_state = kenlm.State()
lm_score = self.ngram_lm.BaseScore(current_lm_state, chr(label + self.token_offset), next_state)
Collaborator:

The token offset should only be applied for subword LMs, not char.

@andrusenkoau (Collaborator Author), Mar 14, 2023:

Yes, that was a bug. I fixed it.
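A sketch of the fixed helper (kenlm.State and BaseScore are the KenLM Python API; the char-model branch is illustrative, since the thread only establishes that the offset must be skipped for char models):

import kenlm

def compute_ngram_score(self, current_lm_state, label: int):
    # apply the offset only for subword (BPE) models; for char models
    # self.token_offset is unset and the raw label is used for lookup
    if self.token_offset:
        word = chr(label + self.token_offset)
    else:
        word = self.vocab[label]  # hypothetical char-model vocabulary lookup
    next_state = kenlm.State()
    lm_score = self.ngram_lm.BaseScore(current_lm_state, word, next_state)
    return lm_score, next_state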

@@ -0,0 +1,264 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Collaborator:

Needs to get updated to 2022.

Collaborator Author:

Done)

# limitations under the License.
#

# This script would evaluate an N-gram language model trained with KenLM library (https://github.com/kpu/kenlm) in
Collaborator:

The docs here need an update.

Collaborator Author:

I am going to merge eval_beamsearch_ngram_transducer.py with eval_beamsearch_ngram.py in another PR. Could we leave eval_beamsearch_ngram_transducer.py as is for now? I will update the docs according to the new eval_beamsearch_ngram.py.

# --decoding_mode=maes
# ...
#
# You may find more info on how to use this script at:
Collaborator:

Let's add to the main doc page that we now support KenLM for both CTC and Transducer models, and add the link to this script to the main page at the following address.
Please also add it to the main README page of the NeMo docs if missing.

Collaborator Author:

I can do it in the other PR (merging eval_beamsearch_ngram_transducer.py with eval_beamsearch_ngram.py) that I mentioned above.

@@ -44,6 +44,7 @@
import torch

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
Collaborator:

If you import DEFAULT_TOKEN_OFFSET, then why not use DEFAULT_TOKEN_OFFSET instead of TOKEN_OFFSET in the code? The other way would be to use the name TOKEN_OFFSET instead of DEFAULT_TOKEN_OFFSET.

Collaborator Author:

I kept only DEFAULT_TOKEN_OFFSET.

andrusenkoau and others added 5 commits March 13, 2023 23:28
@VahidooX (Collaborator) left a comment:

We may merge it now and address the rest of the comments in another PR.

@titu1994 (Collaborator) left a comment:

LGTM for now

@titu1994 titu1994 merged commit 3dbc64e into NVIDIA:main Mar 14, 2023
titu1994 pushed a commit to titu1994/NeMo that referenced this pull request Mar 24, 2023
* add parameters for ngram_lm
* add parameters for ngram lm
* add RNNT model types for kenlm training
* add ngram lm fusion to maes decoding mode
* add a script for the rnnt beam search decoding with a ngram lm fusion for maes
* minor fixes
* minor fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix autocast
* typing fix
* minor fixes
* add set_decoding_type function
* remove tokens_type
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove tokens_type from config
* import token_offset from train_kenlm
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add DEFAULT_TOKEN_OFFSET variable
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix applying token_offset for char models
* fixe copyright year
* leave DEFAULT_TOKEN_OFFSET only

Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request Jun 2, 2023