diff --git a/docs/source/nlp/models.rst b/docs/source/nlp/models.rst index 932be201bfb2..ad50d976db9f 100755 --- a/docs/source/nlp/models.rst +++ b/docs/source/nlp/models.rst @@ -9,6 +9,7 @@ NeMo's NLP collection supports provides the following task-specific models: :maxdepth: 1 punctuation_and_capitalization_models + spellchecking_asr_customization token_classification joint_intent_slot text_classification diff --git a/docs/source/nlp/spellchecking_asr_customization.rst b/docs/source/nlp/spellchecking_asr_customization.rst new file mode 100644 index 000000000000..f9009b520361 --- /dev/null +++ b/docs/source/nlp/spellchecking_asr_customization.rst @@ -0,0 +1,128 @@ +.. _spellchecking_asr_customization: + +SpellMapper (Spellchecking ASR Customization) Model +===================================================== + +SpellMapper is a non-autoregressive model for postprocessing of ASR output. It gets as input a single ASR hypothesis (text) and a custom vocabulary and predicts which fragments in the ASR hypothesis should be replaced by which custom words/phrases if any. Unlike traditional spellchecking approaches, which aim to correct known words using language models, SpellMapper's goal is to correct highly specific user terms, out-of-vocabulary (OOV) words or spelling variations (e.g., "John Koehn", "Jon Cohen"). + +This model is an alternative to word boosting/shallow fusion approaches: + +- does not require retraining ASR model; +- does not require beam-search/language model (LM); +- can be applied on top of any English ASR model output; + +Model Architecture +------------------ +Though SpellMapper is based on `BERT `__ :cite:`nlp-ner-devlin2018bert` architecture, it uses some non-standard tricks that make it different from other BERT-based models: + +- ten separators (``[SEP]`` tokens) are used to combine the ASR hypothesis and ten candidate phrases into a single input; +- the model works on character level; +- subword embeddings are concatenated to the embeddings of each character that belongs to this subword; + + .. code:: + + Example input: [CLS] a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o [SEP] d i d i e r _ s a u m o n [SEP] a s t r o n o m i e [SEP] t r i s t a n _ g u i l l o t [SEP] ... + Input segments: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 + Example output: 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3 0 ... + +The model calculates logits for each character x 11 labels: + +- ``0`` - character doesn't belong to any candidate, +- ``1..10`` - character belongs to candidate with this id. + +At inference average pooling is applied to calculate replacement probability for the whole fragments. + +Quick Start Guide +----------------- + +We recommend you try this model in a Jupyter notebook (need GPU): +`NeMo/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb `__. + +A pretrained English checkpoint can be found at `HuggingFace `__. + +An example inference pipeline can be found here: `NeMo/examples/nlp/spellchecking_asr_customization/run_infer.sh `__. + +An example script on how to train the model can be found here: `NeMo/examples/nlp/spellchecking_asr_customization/run_training.sh `__. + +An example script on how to train on large datasets can be found here: `NeMo/examples/nlp/spellchecking_asr_customization/run_training_tarred.sh `__. 
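+
+Beyond the ready-made scripts, inference can also be run programmatically. The snippet below is a
+minimal sketch: it assumes the pretrained checkpoint has already been downloaded (as in
+``run_infer.sh``) and that the input file follows the format described in the next section; file
+names are illustrative.
+
+.. code::
+
+    from nemo.collections.nlp.models import SpellcheckingAsrCustomizationModel
+    from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
+
+    # Restore the pretrained SpellMapper checkpoint (path as downloaded by run_infer.sh).
+    model = SpellcheckingAsrCustomizationModel.restore_from(
+        "spellmapper_asr_customization_en/training_10m_5ep.nemo",
+        save_restore_connector=NLPSaveRestoreConnector(),
+    )
+
+    # Allow longer inputs, as done in the example inference script.
+    model.max_sequence_len = 512
+    model.builder._max_seq_length = 512
+
+    # Dataloader settings for inference (same keys as in the example inference script).
+    dataloader_cfg = {"batch_size": 16, "num_workers": 4, "pin_memory": False}
+
+    # "spellmapper_input.txt" must follow the tab-separated input format described below;
+    # predictions are written to "spellmapper_output.txt".
+    model.infer(dataloader_cfg, "spellmapper_input.txt", "spellmapper_output.txt")
+
+This mirrors what ``spellchecking_asr_customization_infer.py`` does internally; for end-to-end
+customization of a NeMo ASR manifest (candidate retrieval and post-processing included), use the
+``run_infer.sh`` pipeline mentioned above.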
+ +The default configuration file for the model can be found here: `NeMo/examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml `__. + +.. _dataset_spellchecking_asr_customization: + +Input/Output Format at Inference stage +-------------------------------------- +Here we describe input/output format of the SpellMapper model. + +.. note:: + + If you use `inference pipeline `__ this format will be hidden inside and you only need to provide an input manifest and user vocabulary and you will get a corrected manifest. + +An input line should consist of 4 tab-separated columns: + 1. text of ASR-hypothesis + 2. texts of 10 candidates separated by semicolon + 3. 1-based ids of non-dummy candidates, separated by space + 4. approximate start/end coordinates of non-dummy candidates (correspond to ids in third column) + +Example input (in one line): + +.. code:: + + t h e _ t a r a s i c _ o o r d a _ i s _ a _ p a r t _ o f _ t h e _ a o r t a _ l o c a t e d _ i n _ t h e _ t h o r a x + h e p a t i c _ c i r r h o s i s;u r a c i l;c a r d i a c _ a r r e s t;w e a n;a p g a r;p s y c h o m o t o r;t h o r a x;t h o r a c i c _ a o r t a;a v f;b l o c k a d e d + 1 2 6 7 8 9 10 + CUSTOM 6 23;CUSTOM 4 10;CUSTOM 4 15;CUSTOM 56 62;CUSTOM 5 19;CUSTOM 28 31;CUSTOM 39 48 + +Each line in SpellMapper output is tab-separated and consists of 4 columns: + 1. ASR-hypothesis (same as in input) + 2. 10 candidates separated by semicolon (same as in input) + 3. fragment predictions, separated by semicolon, each prediction is a tuple (start, end, candidate_id, probability) + 4. letter predictions - candidate_id predicted for each letter (this is only for debug purposes) + +Example output (in one line): + +.. code:: + + t h e _ t a r a s i c _ o o r d a _ i s _ a _ p a r t _ o f _ t h e _ a o r t a _ l o c a t e d _ i n _ t h e _ t h o r a x + h e p a t i c _ c i r r h o s i s;u r a c i l;c a r d i a c _ a r r e s t;w e a n;a p g a r;p s y c h o m o t o r;t h o r a x;t h o r a c i c _ a o r t a;a v f;b l o c k a d e d + 56 62 7 0.99998;4 20 8 0.95181;12 20 8 0.44829;4 17 8 0.99464;12 17 8 0.97645 + 8 8 8 0 8 8 8 8 8 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 7 7 7 7 7 + +Training Data Format +-------------------- + +For training, the data should consist of 5 files: + +- ``config.json`` - BERT config +- ``label_map.txt`` - labels from 0 to 10, do not change +- ``semiotic_classes.txt`` - currently there are only two classes: ``PLAIN`` and ``CUSTOM``, do not change +- ``train.tsv`` - training examples +- ``test.tsv`` - validation examples + +Note that since all these examples are synthetic, we do not reserve a set for final testing. Instead, we run `inference pipeline `__ and compare resulting word error rate (WER) to the WER of baseline ASR output. + +One (non-tarred) training example should consist of 4 tab-separated columns: + 1. text of ASR-hypothesis + 2. texts of 10 candidates separated by semicolon + 3. 1-based ids of correct candidates, separated by space, or 0 if none + 4. start/end coordinates of correct candidates (correspond to ids in third column) + +Example (in one line): + +.. 
code:: + + a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o + d i d i e r _ s a u m o n;a s t r o n o m i e;t r i s t a n _ g u i l l o t;t r i s t e s s e;m o n a d e;c h r i s t i a n;a s t r o n o m e r;s o l o m o n;d i d i d i d i d i;m e r c y + 1 3 + CUSTOM 12 23;CUSTOM 28 41 + +For data preparation see `this script `__ + + +References +---------- + +.. bibliography:: nlp_all.bib + :style: plain + :labelprefix: NLP-NER + :keyprefix: nlp-ner- diff --git a/docs/source/starthere/tutorials.rst b/docs/source/starthere/tutorials.rst index cb81aecc1109..9c960053398b 100644 --- a/docs/source/starthere/tutorials.rst +++ b/docs/source/starthere/tutorials.rst @@ -130,6 +130,9 @@ To run a tutorial: * - NLP - Punctuation and Capitalization - `Punctuation and Capitalization `_ + * - NLP + - Spellchecking ASR Customization - SpellMapper + - `Spellchecking ASR Customization - SpellMapper `_ * - NLP - Entity Linking - `Entity Linking `_ diff --git a/examples/nlp/spellchecking_asr_customization/README.md b/examples/nlp/spellchecking_asr_customization/README.md new file mode 100644 index 000000000000..2d83fd8d11ad --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/README.md @@ -0,0 +1,32 @@ +# SpellMapper - spellchecking model for ASR Customization + +This model is inspired by Microsoft's paper https://arxiv.org/pdf/2203.00888.pdf, but does not repeat its implementation. +The goal is to build a model that gets as input a single ASR hypothesis (text) and a vocabulary of custom words/phrases and predicts which fragments in the ASR hypothesis should be replaced by which custom words/phrases if any. +Our model is non-autoregressive (NAR) based on transformer architecture (BERT with multiple separators). + +As initial data we use about 5 mln entities from [YAGO corpus](https://www.mpi-inf.mpg.de/departments/databases-and-information-systems/research/yago-naga/yago/downloads/). These entities are short phrases from Wikipedia headings. +In order to get misspelled predictions we feed these data to TTS model and then to ASR model. +Having a "parallel" corpus of "correct + misspelled" phrases, we use statistical machine translation techniques to create a dictionary of possible ngram mappings with their respective frequencies. +We create an auxiliary algorithm that takes as input a sentence (ASR hypothesis) and a large custom dictionary (e.g. 5000 phrases) and selects top 10 candidate phrases that are probably contained in this sentence in a misspelled way. +The task of our final neural model is to predict which fragments in the ASR hypothesis should be replaced by which of top-10 candidate phrases if any. + +The pipeline consists of multiple steps: + +1. Download or generate training data. + See `https://github.com/bene-ges/nemo_compatible/tree/main/scripts/nlp/en_spellmapper/dataset_preparation` + +2. [Optional] Convert training dataset to tarred files. + `convert_dataset_to_tarred.sh` + +3. Train spellchecking model. + `run_training.sh` + or + `run_training_tarred.sh` + +4. Run evaluation. + - [test_on_kensho.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_kensho.sh) + - [test_on_userlibri.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_kensho.sh) + - [test_on_spoken_wikipedia.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_kensho.sh) + +5. Run inference. 
+ `python run_infer.sh` diff --git a/examples/nlp/spellchecking_asr_customization/checkpoint_to_nemo.py b/examples/nlp/spellchecking_asr_customization/checkpoint_to_nemo.py new file mode 100644 index 000000000000..c2f514f3e67e --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/checkpoint_to_nemo.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script converts checkpoint .ckpt to .nemo file. + +This script uses the `examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml` +config file by default. The other option is to set another config file via command +line arguments by `--config-name=CONFIG_FILE_PATH'. +""" + +from omegaconf import DictConfig, OmegaConf + +from nemo.collections.nlp.models import SpellcheckingAsrCustomizationModel +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +@hydra_runner(config_path="conf", config_name="spellchecking_asr_customization_config") +def main(cfg: DictConfig) -> None: + logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}') + SpellcheckingAsrCustomizationModel.load_from_checkpoint(cfg.checkpoint_path).save_to(cfg.target_nemo_path) + + +if __name__ == "__main__": + main() diff --git a/examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml b/examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml new file mode 100644 index 000000000000..c98915cdfc6f --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml @@ -0,0 +1,97 @@ +name: &name spellchecking +lang: ??? # e.g. 'ru', 'en' + +# Pretrained Nemo Models +pretrained_model: null + +trainer: + devices: 1 # the number of gpus, 0 for CPU + num_nodes: 1 + max_epochs: 3 # the number of training epochs + enable_checkpointing: false # provided by exp_manager + logger: false # provided by exp_manager + accumulate_grad_batches: 1 # accumulates grads every k batches + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + strategy: ddp + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + +model: + do_training: true + label_map: ??? # path/.../label_map.txt + semiotic_classes: ??? 
# path/.../semiotic_classes.txt + max_sequence_len: 128 + lang: ${lang} + hidden_size: 768 + + optim: + name: adamw + lr: 3e-5 + weight_decay: 0.1 + + sched: + name: WarmupAnnealing + + # pytorch lightning args + monitor: val_loss + reduce_on_plateau: false + + # scheduler config override + warmup_ratio: 0.1 + last_epoch: -1 + + language_model: + pretrained_model_name: bert-base-uncased # For ru, try DeepPavlov/rubert-base-cased | For de or multilingual, try bert-base-multilingual-cased + lm_checkpoint: null + config_file: null # json file, precedence over config + config: null + + tokenizer: + tokenizer_name: ${model.language_model.pretrained_model_name} # or sentencepiece + vocab_file: null # path to vocab file + tokenizer_model: null # only used if tokenizer is sentencepiece + special_tokens: null + +exp_manager: + exp_dir: nemo_experiments # where to store logs and checkpoints + name: training # name of experiment + create_tensorboard_logger: True + create_checkpoint_callback: True + checkpoint_callback_params: + save_top_k: 3 + monitor: "val_loss" + mode: "min" + +tokenizer: + tokenizer_name: ${model.transformer} # or sentencepiece + vocab_file: null # path to vocab file + tokenizer_model: null # only used if tokenizer is sentencepiece + special_tokens: null + +# Data +data: + train_ds: + data_path: ??? # provide the full path to the file + batch_size: 8 + shuffle: true + num_workers: 3 + pin_memory: false + drop_last: false + + validation_ds: + data_path: ??? # provide the full path to the file. + batch_size: 8 + shuffle: false + num_workers: 3 + pin_memory: false + drop_last: false + + +# Inference +inference: + from_file: null # Path to the raw text, no labels required. Each sentence on a separate line + out_file: null # Path to the output file + batch_size: 16 # batch size for inference.from_file diff --git a/examples/nlp/spellchecking_asr_customization/convert_data_to_tarred.sh b/examples/nlp/spellchecking_asr_customization/convert_data_to_tarred.sh new file mode 100644 index 000000000000..d4265eb4beb6 --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/convert_data_to_tarred.sh @@ -0,0 +1,50 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Path to NeMo repository +NEMO_PATH=NeMo + +DATA_PATH="data_folder" + +## data_folder_example +## ├── tarred_data +## | └── (output) +## ├── config.json +##   ├── label_map.txt +##   ├── semiotic_classes.txt +## ├── test.tsv +## ├── 1.tsv +## ├── ... +## └── 200.tsv + +## Each of {1-200}.tsv input files are 110'000 examples subsets of all.tsv (except for validation part), +## generated by https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/dataset_preparation/build_training_data.sh +## Note that in this example we use 110'000 as input and only pack 100'000 of them to tar file. +## This is because some input examples, e.g. 
too long, can be skipped during preprocessing, and we want all tar files to contain fixed equal number of examples. + +for part in {1..200} +do + python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/create_tarred_dataset.py \ + lang="en" \ + data.train_ds.data_path=${DATA_PATH}/${part}.tsv \ + data.validation_ds.data_path=${DATA_PATH}/test.tsv \ + model.max_sequence_len=256 \ + model.language_model.pretrained_model_name=huawei-noah/TinyBERT_General_6L_768D \ + model.language_model.config_file=${DATA_PATH}/config.json \ + model.label_map=${DATA_PATH}/label_map.txt \ + model.semiotic_classes=${DATA_PATH}/semiotic_classes.txt \ + +output_tar_file=${DATA_PATH}/tarred_data/part${part}.tar \ + +take_first_n_lines=100000 +done diff --git a/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py b/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py new file mode 100644 index 000000000000..07d64ec5b723 --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py @@ -0,0 +1,72 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script is used to create an index of custom vocabulary and save it to file. +See "examples/nlp/spellchecking_asr_customization/run_infer.sh" for the whole inference pipeline. +""" + +from argparse import ArgumentParser + +from nemo.collections.nlp.data.spellchecking_asr_customization.utils import get_index, load_ngram_mappings + +parser = ArgumentParser(description="Create an index of custom vocabulary and save it to file") + +parser.add_argument( + "--input_name", required=True, type=str, help="Path to input file with custom vocabulary (plain text)" +) +parser.add_argument( + "--ngram_mappings", required=True, type=str, help="Path to input file with n-gram mapping vocabulary" +) +parser.add_argument("--output_name", required=True, type=str, help="Path to output file with custom vocabulary index") +parser.add_argument("--min_log_prob", default=-4.0, type=float, help="Threshold on log probability") +parser.add_argument( + "--max_phrases_per_ngram", + default=500, + type=int, + help="Threshold on number of phrases that can be stored for one n-gram key in index. 
Keys with more phrases are discarded.", +) +parser.add_argument( + "--max_misspelled_freq", default=125000, type=int, help="Threshold on maximum frequency of misspelled n-gram" +) + +args = parser.parse_args() + +# Load custom vocabulary +custom_phrases = set() +with open(args.input_name, "r", encoding="utf-8") as f: + for line in f: + phrase = line.strip() + custom_phrases.add(" ".join(list(phrase.replace(" ", "_")))) +print("Size of customization vocabulary:", len(custom_phrases)) + +# Load n-gram mappings vocabulary +ngram_mapping_vocab, ban_ngram = load_ngram_mappings(args.ngram_mappings, max_misspelled_freq=125000) + +# Generate index of custom phrases +phrases, ngram2phrases = get_index( + custom_phrases, + ngram_mapping_vocab, + ban_ngram, + min_log_prob=args.min_log_prob, + max_phrases_per_ngram=args.max_phrases_per_ngram, +) + +# Save index to file +with open(args.output_name, "w", encoding="utf-8") as out: + for ngram in ngram2phrases: + for phrase_id, begin, size, logprob in ngram2phrases[ngram]: + phrase = phrases[phrase_id] + out.write(ngram + "\t" + phrase + "\t" + str(begin) + "\t" + str(size) + "\t" + str(logprob) + "\n") diff --git a/examples/nlp/spellchecking_asr_customization/create_tarred_dataset.py b/examples/nlp/spellchecking_asr_customization/create_tarred_dataset.py new file mode 100644 index 000000000000..d0bdc2c9bd30 --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/create_tarred_dataset.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script is used to create a tarred dataset for SpellcheckingAsrCustomizationModel. + +This script uses the `/examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml` +config file by default. The other option is to set another config file via command +line arguments by `--config-name=CONFIG_FILE_PATH'. Probably it is worth looking +at the example config file to see the list of parameters used for training. + +USAGE Example: +1. Obtain a processed dataset +2. 
Run: + python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/create_tarred_dataset.py \ + lang=${LANG} \ + data.train_ds.data_path=${DATA_PATH}/train.tsv \ + model.language_model.pretrained_model_name=${LANGUAGE_MODEL} \ + model.label_map=${DATA_PATH}/label_map.txt \ + +output_tar_file=tarred/part1.tar \ + +take_first_n_lines=100000 + +""" +import pickle +import tarfile +from io import BytesIO + +from helpers import MODEL, instantiate_model_and_trainer +from omegaconf import DictConfig, OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +@hydra_runner(config_path="conf", config_name="spellchecking_asr_customization_config") +def main(cfg: DictConfig) -> None: + logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}') + logging.info("Start creating tar file from " + cfg.data.train_ds.data_path + " ...") + _, model = instantiate_model_and_trainer( + cfg, MODEL, True + ) # instantiate model like for training because we may not have pretrained model + dataset = model._train_dl.dataset + archive = tarfile.open(cfg.output_tar_file, mode="w") + max_lines = int(cfg.take_first_n_lines) + for i in range(len(dataset)): + if i >= max_lines: + logging.info("Reached " + str(max_lines) + " examples") + break + ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + labels_mask, + labels, + spans, + ) = dataset[i] + + # do not store masks as they are just arrays of 1 + content = { + "input_ids": input_ids, + "input_mask": input_mask, + "segment_ids": segment_ids, + "input_ids_for_subwords": input_ids_for_subwords, + "input_mask_for_subwords": input_mask_for_subwords, + "segment_ids_for_subwords": segment_ids_for_subwords, + "character_pos_to_subword_pos": character_pos_to_subword_pos, + "labels_mask": labels_mask, + "labels": labels, + "spans": spans, + } + b = BytesIO() + pickle.dump(content, b) + b.seek(0) + tarinfo = tarfile.TarInfo(name="example_" + str(i) + ".pkl") + tarinfo.size = b.getbuffer().nbytes + archive.addfile(tarinfo=tarinfo, fileobj=b) + + archive.close() + logging.info("Tar file " + cfg.output_tar_file + " created!") + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/spellchecking_asr_customization/helpers.py b/examples/nlp/spellchecking_asr_customization/helpers.py new file mode 100644 index 000000000000..2db11b0e7d96 --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/helpers.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
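+
+"""
+Helper functions shared by the SpellMapper example scripts (training, tarred dataset creation and
+inference). instantiate_model_and_trainer() builds a PyTorch Lightning trainer from the config and
+creates a SpellcheckingAsrCustomizationModel either from scratch, from a local .nemo checkpoint,
+or from a pretrained model name.
+"""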
+ + +import os +from typing import Tuple + +import pytorch_lightning as pl +from omegaconf import DictConfig + +from nemo.collections.nlp.models import SpellcheckingAsrCustomizationModel +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import logging + +__all__ = ["MODEL", "MODEL_NAMES", "instantiate_model_and_trainer"] + +MODEL = "spellchecking" +MODEL_NAMES = [MODEL] + + +def instantiate_model_and_trainer( + cfg: DictConfig, model_name: str, do_training: bool +) -> Tuple[pl.Trainer, SpellcheckingAsrCustomizationModel]: + """ Function for instantiating a model and a trainer + Args: + cfg: The config used to instantiate the model and the trainer. + model_name: A str indicates the model direction, currently only 'itn'. + do_training: A boolean flag indicates whether the model will be trained or evaluated. + + Returns: + trainer: A PyTorch Lightning trainer + model: A SpellcheckingAsrCustomizationModel + """ + + if model_name not in MODEL_NAMES: + raise ValueError(f"{model_name} is unknown model type") + + # Get configs for the corresponding models + trainer_cfg = cfg.get("trainer") + model_cfg = cfg.get("model") + pretrained_cfg = cfg.get("pretrained_model", None) + trainer = pl.Trainer(**trainer_cfg) + if not pretrained_cfg: + logging.info(f"Initializing {model_name} model") + if model_name == MODEL: + model = SpellcheckingAsrCustomizationModel(model_cfg, trainer=trainer) + else: + raise ValueError(f"{model_name} is unknown model type") + elif os.path.exists(pretrained_cfg): + logging.info(f"Restoring pretrained {model_name} model from {pretrained_cfg}") + save_restore_connector = NLPSaveRestoreConnector() + model = SpellcheckingAsrCustomizationModel.restore_from( + pretrained_cfg, save_restore_connector=save_restore_connector + ) + else: + logging.info(f"Loading pretrained model {pretrained_cfg}") + if model_name == MODEL: + if pretrained_cfg not in SpellcheckingAsrCustomizationModel.get_available_model_names(): + raise ( + ValueError( + f"{pretrained_cfg} not in the list of available Tagger models." + f"Select from {SpellcheckingAsrCustomizationModel.list_available_models()}" + ) + ) + model = SpellcheckingAsrCustomizationModel.from_pretrained(pretrained_cfg) + else: + raise ValueError(f"{model_name} is unknown model type") + + # Setup train and validation data + if do_training: + model.setup_training_data(train_data_config=cfg.data.train_ds) + model.setup_validation_data(val_data_config=cfg.data.validation_ds) + + logging.info(f"Model {model_name} -- Device {model.device}") + return trainer, model diff --git a/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py b/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py new file mode 100644 index 000000000000..871d5e5c0c0c --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script is used to postprocess SpellMapper results and generate an updated nemo ASR manifest. +See "examples/nlp/spellchecking_asr_customization/run_infer.sh" for the whole inference pipeline. +""" + +from argparse import ArgumentParser + +from nemo.collections.nlp.data.spellchecking_asr_customization.utils import ( + update_manifest_with_spellmapper_corrections, +) + +parser = ArgumentParser(description="Postprocess SpellMapper results and generate an updated nemo ASR manifest") + +parser.add_argument("--input_manifest", required=True, type=str, help="Path to input nemo ASR manifest") +parser.add_argument( + "--field_name", default="pred_text", type=str, help="Name of json field with original ASR hypothesis text" +) +parser.add_argument( + "--short2full_name", + required=True, + type=str, + help="Path to input file with correspondence between sentence fragments and full sentences", +) +parser.add_argument( + "--spellmapper_results", required=True, type=str, help="Path to input file with SpellMapper inference results" +) +parser.add_argument("--output_manifest", required=True, type=str, help="Path to output nemo ASR manifest") +parser.add_argument("--min_prob", default=0.5, type=float, help="Threshold on replacement probability") +parser.add_argument( + "--use_dp", + action="store_true", + help="Whether to use additional replacement filtering by using dynamic programming", +) +parser.add_argument( + "--replace_hyphen_to_space", + action="store_true", + help="Whether to use space instead of hyphen in replaced fragments", +) +parser.add_argument( + "--ngram_mappings", type=str, required=True, help="File with ngram mappings, only needed if use_dp=true" +) +parser.add_argument( + "--min_dp_score_per_symbol", + default=-1.5, + type=float, + help="Minimum dynamic programming sum score averaged by hypothesis length", +) + +args = parser.parse_args() + +update_manifest_with_spellmapper_corrections( + input_manifest_name=args.input_manifest, + short2full_name=args.short2full_name, + output_manifest_name=args.output_manifest, + spellmapper_results_name=args.spellmapper_results, + min_prob=args.min_prob, + replace_hyphen_to_space=args.replace_hyphen_to_space, + field_name=args.field_name, + use_dp=args.use_dp, + ngram_mappings=args.ngram_mappings, + min_dp_score_per_symbol=args.min_dp_score_per_symbol, +) + +print("Resulting manifest saved to: ", args.output_manifest) diff --git a/examples/nlp/spellchecking_asr_customization/prepare_input_from_manifest.py b/examples/nlp/spellchecking_asr_customization/prepare_input_from_manifest.py new file mode 100644 index 000000000000..6fd5e524390a --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/prepare_input_from_manifest.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +This script contains an example on how to prepare input for SpellMapper inference from a nemo ASR manifest. +It splits sentences to shorter fragments, runs candidate retrieval and generates input in the required format. +It produces two output files: + 1. File with correspondence between sentence fragments and full sentences. + 2. File that will serve as input for SpellMapper inference. + +See "examples/nlp/spellchecking_asr_customization/run_infer.sh" for the whole inference pipeline. +""" + +from argparse import ArgumentParser + +from nemo.collections.nlp.data.spellchecking_asr_customization.utils import ( + extract_and_split_text_from_manifest, + get_candidates, + load_index, +) + +parser = ArgumentParser(description="Prepare input for SpellMapper inference from a nemo ASR manifest") +parser.add_argument("--manifest", required=True, type=str, help="Path to input manifest file") +parser.add_argument( + "--custom_vocab_index", required=True, type=str, help="Path to input file with custom vocabulary index" +) +parser.add_argument( + "--big_sample", + required=True, + type=str, + help="Path to input file with big sample of phrases to sample dummy candidates if there less than 10 are found by retrieval", +) +parser.add_argument( + "--short2full_name", + required=True, + type=str, + help="Path to output file with correspondence between sentence fragments and full sentences", +) +parser.add_argument( + "--output_name", + required=True, + type=str, + help="Path to output file that will serve as input for SpellMapper inference", +) +parser.add_argument("--field_name", default="pred_text", type=str, help="Name of json field with ASR hypothesis text") +parser.add_argument("--len_in_words", default=16, type=int, help="Maximum fragment length in words") +parser.add_argument( + "--step_in_words", + default=8, + type=int, + help="Step in words for moving to next fragment. If less than len_in_words, fragments will intersect", +) + +args = parser.parse_args() + +# Split ASR hypotheses to shorter fragments, because SpellMapper can't handle arbitrarily long sequences. +# The correspondence between short and original fragments is saved to a file and will be used at post-processing. 
+extract_and_split_text_from_manifest( + input_name=args.manifest, + output_name=args.short2full_name, + field_name=args.field_name, + len_in_words=args.len_in_words, + step_in_words=args.step_in_words, +) + +# Load index of custom vocabulary from file +phrases, ngram2phrases = load_index(args.custom_vocab_index) + +# Load big sample of phrases to sample dummy candidates if there less than 10 are found by retrieval +big_sample_of_phrases = set() +with open(args.big_sample, "r", encoding="utf-8") as f: + for line in f: + phrase, freq = line.strip().split("\t") + if int(freq) > 50: # do not want to use frequent phrases as dummy candidates + continue + if len(phrase) < 6 or len(phrase) > 15: # do not want to use too short or too long phrases as dummy candidates + continue + big_sample_of_phrases.add(phrase) + +big_sample_of_phrases = list(big_sample_of_phrases) + +# Generate input for SpellMapper inference +out = open(args.output_name, "w", encoding="utf-8") +with open(args.short2full_name, "r", encoding="utf-8") as f: + for line in f: + short_sent, _ = line.strip().split("\t") + sent = "_".join(short_sent.split()) + letters = list(sent) + candidates = get_candidates(ngram2phrases, phrases, letters, big_sample_of_phrases) + if len(candidates) == 0: + continue + if len(candidates) != 10: + raise ValueError("expect 10 candidates, got: ", len(candidates)) + + # We add two columns with targets and span_info. + # They have same format as during training, but start and end positions are APPROXIMATE, they will be adjusted when constructing BertExample. + targets = [] + span_info = [] + for idx, c in enumerate(candidates): + if c[1] == -1: + continue + targets.append(str(idx + 1)) # targets are 1-based + start = c[1] + # ensure that end is not outside sentence length (it can happen because c[2] is candidate length used as approximation) + end = min(c[1] + c[2], len(letters)) + span_info.append("CUSTOM " + str(start) + " " + str(end)) + out.write( + " ".join(letters) + + "\t" + + ";".join([x[0] for x in candidates]) + + "\t" + + " ".join(targets) + + "\t" + + ";".join(span_info) + + "\n" + ) +out.close() diff --git a/examples/nlp/spellchecking_asr_customization/run_infer.sh b/examples/nlp/spellchecking_asr_customization/run_infer.sh new file mode 100644 index 000000000000..09da98171c16 --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/run_infer.sh @@ -0,0 +1,99 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +## RUN INFERENCE ON NEMO MANIFEST AND CUSTOM VOCABULARY + +## Path to NeMo repository +NEMO_PATH=NeMo + +## Download model repo from Hugging Face (if clone doesn't work, run "git lfs install" and try again) +git clone https://huggingface.co/bene-ges/spellmapper_asr_customization_en +## Download repo with test data +git clone https://huggingface.co/datasets/bene-ges/spellmapper_en_evaluation + +## Files in model repo +PRETRAINED_MODEL=spellmapper_asr_customization_en/training_10m_5ep.nemo +NGRAM_MAPPINGS=spellmapper_asr_customization_en/replacement_vocab_filt.txt +BIG_SAMPLE=spellmapper_asr_customization_en/big_sample.txt + +## Override these two files if you want to test on your own data +## File with input nemo ASR manifest +INPUT_MANIFEST=spellmapper_en_evaluation/medical_manifest_ctc.json +## File containing custom words and phrases (plain text) +CUSTOM_VOCAB=spellmapper_en_evaluation/medical_custom_vocab.json + +## Other files will be created +## File with index of custom vocabulary +INDEX="index.txt" +## File with short fragments and corresponding original sentences +SHORT2FULL="short2full.txt" +## File with input for SpellMapper inference +SPELLMAPPER_INPUT="spellmapper_input.txt" +## File with output of SpellMapper inference +SPELLMAPPER_OUTPUT="spellmapper_output.txt" +## File with output nemo ASR manifest +OUTPUT_MANIFEST="out_manifest.json" + + +# Create index of custom vocabulary +python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py \ + --input_name ${CUSTOM_VOCAB} \ + --ngram_mappings ${NGRAM_MAPPINGS} \ + --output_name ${INDEX} \ + --min_log_prob -4.0 \ + --max_phrases_per_ngram 600 + +# Prepare input for SpellMapper inference +python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/prepare_input_from_manifest.py \ + --manifest ${INPUT_MANIFEST} \ + --custom_vocab_index ${INDEX} \ + --big_sample ${BIG_SAMPLE} \ + --short2full_name ${SHORT2FULL} \ + --output_name ${SPELLMAPPER_INPUT} \ + --field_name "pred_text" \ + --len_in_words 16 \ + --step_in_words 8 + +# Run SpellMapper inference +python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py \ + pretrained_model=${PRETRAINED_MODEL} \ + model.max_sequence_len=512 \ + inference.from_file=${SPELLMAPPER_INPUT} \ + inference.out_file=${SPELLMAPPER_OUTPUT} \ + inference.batch_size=16 \ + lang=en + +# Postprocess and create output corrected manifest +python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py \ + --input_manifest ${INPUT_MANIFEST} \ + --short2full_name ${SHORT2FULL} \ + --output_manifest ${OUTPUT_MANIFEST} \ + --spellmapper_result ${SPELLMAPPER_OUTPUT} \ + --replace_hyphen_to_space \ + --field_name "pred_text" \ + --use_dp \ + --ngram_mappings ${NGRAM_MAPPINGS} \ + --min_dp_score_per_symbol -1.5 + +# Check WER of initial manifest +python ${NEMO_PATH}/examples/asr/speech_to_text_eval.py \ + dataset_manifest=${INPUT_MANIFEST} \ + use_cer=False \ + only_score_manifest=True + +# Check WER of corrected manifest +python ${NEMO_PATH}/examples/asr/speech_to_text_eval.py \ + dataset_manifest=${OUTPUT_MANIFEST} \ + use_cer=False \ + only_score_manifest=True diff --git a/examples/nlp/spellchecking_asr_customization/run_training.sh b/examples/nlp/spellchecking_asr_customization/run_training.sh new file mode 100644 index 000000000000..85dddbb2a038 --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/run_training.sh @@ -0,0 +1,56 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +## TRAIN WITH NON-TARRED DATA + +# Path to NeMo repository +NEMO_PATH=NeMo + +## Download repo with training data (very small example) +## If clone doesn't work, run "git lfs install" and try again +git clone https://huggingface.co/datasets/bene-ges/spellmapper_en_train_micro + +DATA_PATH=spellmapper_en_train_micro + +## Example of all files needed to run training with non-tarred data: +## spellmapper_en_train_micro +## ├── config.json +##   ├── label_map.txt +##   ├── semiotic_classes.txt +## ├── test.tsv +## └── train.tsv + +## To generate files config.json, label_map.txt, semiotic_classes.txt - run generate_configs.sh +## Files "train.tsv" and "test.tsv" contain training examples. +## For data preparation see https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/dataset_preparation/build_training_data.sh + +## Note that training with non-tarred data only works on single gpu. It makes sense if you use 1-2 million examples or less. + +python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py \ + lang="en" \ + data.validation_ds.data_path=${DATA_PATH}/test.tsv \ + data.train_ds.data_path=${DATA_PATH}/train.tsv \ + data.train_ds.batch_size=32 \ + data.train_ds.num_workers=8 \ + model.max_sequence_len=512 \ + model.language_model.pretrained_model_name=huawei-noah/TinyBERT_General_6L_768D \ + model.language_model.config_file=${DATA_PATH}/config.json \ + model.label_map=${DATA_PATH}/label_map.txt \ + model.semiotic_classes=${DATA_PATH}/semiotic_classes.txt \ + model.optim.lr=3e-5 \ + trainer.devices=[1] \ + trainer.num_nodes=1 \ + trainer.accelerator=gpu \ + trainer.strategy=ddp \ + trainer.max_epochs=5 diff --git a/examples/nlp/spellchecking_asr_customization/run_training_tarred.sh b/examples/nlp/spellchecking_asr_customization/run_training_tarred.sh new file mode 100644 index 000000000000..655c3e23e610 --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/run_training_tarred.sh @@ -0,0 +1,63 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +## TRAIN WITH TARRED DATA + +# Path to NeMo repository +NEMO_PATH=NeMo + +DATA_PATH=data_folder + +## data_folder_example +## ├── train_tarred +## | ├── part1.tar +## | ├── ... 
+## | └── part200.tar +## ├── config.json +##   ├── label_map.txt +##   ├── semiotic_classes.txt +## └── test.tsv +## To generate files config.json, label_map.txt, semiotic_classes.txt, run generate_configs.sh +## To prepare data, see ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/dataset_preparation/build_training_data.sh +## To convert data to tarred format, split all.tsv to pieces of 110'000 examples (except for validation part) and use ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/dataset_preparation/convert_data_to_tarred.sh +## To run training with tarred data, use ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/run_training_tarred.sh + +## ATTENTION: How to calculate model.optim.sched.max_steps: +## Suppose, you have 2'000'000 training examples, and want to train for 5 epochs on 4 gpus with batch size 32. +## 5 (epochs) * 32 (bs) * 4 (gpus) +## 1 step consumes 128 examples (32(bs) * 4(gpus)) +## 1 epoch makes 2000000/128=15625 steps (updates) +## 5 epochs make 5*15625=78125 steps + +python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py \ + lang="en" \ + data.validation_ds.data_path=${DATA_PATH}/test.tsv \ + data.train_ds.data_path=${DATA_PATH}/train_tarred/part_OP_1..100_CL_.tar \ + data.train_ds.batch_size=32 \ + data.train_ds.num_workers=16 \ + +data.train_ds.use_tarred_dataset=true \ + data.train_ds.shuffle=false \ + data.validation_ds.batch_size=16 \ + model.max_sequence_len=512 \ + model.language_model.pretrained_model_name=huawei-noah/TinyBERT_General_6L_768D \ + model.language_model.config_file=${DATA_PATH}/config.json \ + model.label_map=${DATA_PATH}/label_map.txt \ + model.semiotic_classes=${DATA_PATH}/semiotic_classes.txt \ + model.optim.sched.name=CosineAnnealing \ + +model.optim.sched.max_steps=195313 \ + trainer.devices=8 \ + trainer.num_nodes=1 \ + trainer.accelerator=gpu \ + trainer.strategy=ddp \ + trainer.max_epochs=5 diff --git a/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py b/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py new file mode 100644 index 000000000000..593264f14a5d --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script contains an example on how to run inference with the SpellcheckingAsrCustomizationModel. + +An input line should consist of 4 tab-separated columns: + 1. text of ASR-hypothesis + 2. texts of 10 candidates separated by semicolon + 3. 1-based ids of non-dummy candidates + 4. 
approximate start/end coordinates of non-dummy candidates (correspond to ids in third column) + +Example input (in one line): + t h e _ t a r a s i c _ o o r d a _ i s _ a _ p a r t _ o f _ t h e _ a o r t a _ l o c a t e d _ i n _ t h e _ t h o r a x + h e p a t i c _ c i r r h o s i s;u r a c i l;c a r d i a c _ a r r e s t;w e a n;a p g a r;p s y c h o m o t o r;t h o r a x;t h o r a c i c _ a o r t a;a v f;b l o c k a d e d + 1 2 6 7 8 9 10 + CUSTOM 6 23;CUSTOM 4 10;CUSTOM 4 15;CUSTOM 56 62;CUSTOM 5 19;CUSTOM 28 31;CUSTOM 39 48 + +Each line in SpellMapper output is tab-separated and consists of 4 columns: + 1. ASR-hypothesis (same as in input) + 2. 10 candidates separated with semicolon (same as in input) + 3. fragment predictions, separated with semicolon, each prediction is a tuple (start, end, candidate_id, probability) + 4. letter predictions - candidate_id predicted for each letter (this is only for debug purposes) + +Example output (in one line): + t h e _ t a r a s i c _ o o r d a _ i s _ a _ p a r t _ o f _ t h e _ a o r t a _ l o c a t e d _ i n _ t h e _ t h o r a x + h e p a t i c _ c i r r h o s i s;u r a c i l;c a r d i a c _ a r r e s t;w e a n;a p g a r;p s y c h o m o t o r;t h o r a x;t h o r a c i c _ a o r t a;a v f;b l o c k a d e d + 56 62 7 0.99998;4 20 8 0.95181;12 20 8 0.44829;4 17 8 0.99464;12 17 8 0.97645 + 8 8 8 0 8 8 8 8 8 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 7 7 7 7 7 + + +USAGE Example: +1. Train a model, or use a pretrained checkpoint. +2. Run on a single file: + python nemo/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py \ + pretrained_model=${PRETRAINED_NEMO_CHECKPOINT} \ + model.max_sequence_len=512 \ + inference.from_file=input.txt \ + inference.out_file=output.txt \ + inference.batch_size=16 \ + lang=en +or on multiple files: + python ${NEMO_PATH}/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py \ + pretrained_model=${PRETRAINED_NEMO_CHECKPOINT} \ + model.max_sequence_len=512 \ + +inference.from_filelist=filelist.txt \ + +inference.output_folder=output_folder \ + inference.batch_size=16 \ + lang=en + +This script uses the `/examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml` +config file by default. The other option is to set another config file via command +line arguments by `--config-name=CONFIG_FILE_PATH'. 
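+
+The filelist passed via +inference.from_filelist should contain one input file path per line.
+Predictions for each input are written, under the same file name, to the folder given
+by +inference.output_folder; input paths that do not exist are skipped.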
+""" + + +import os + +from helpers import MODEL, instantiate_model_and_trainer +from omegaconf import DictConfig, OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +@hydra_runner(config_path="conf", config_name="spellchecking_asr_customization_config") +def main(cfg: DictConfig) -> None: + logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}') + + if cfg.pretrained_model is None: + raise ValueError("A pre-trained model should be provided.") + _, model = instantiate_model_and_trainer(cfg, MODEL, False) + + if cfg.model.max_sequence_len != model.max_sequence_len: + model.max_sequence_len = cfg.model.max_sequence_len + model.builder._max_seq_length = cfg.model.max_sequence_len + input_filenames = [] + output_filenames = [] + + if "from_filelist" in cfg.inference and "output_folder" in cfg.inference: + filelist_file = cfg.inference.from_filelist + output_folder = cfg.inference.output_folder + with open(filelist_file, "r", encoding="utf-8") as f: + for line in f: + path = line.strip() + input_filenames.append(path) + folder, name = os.path.split(path) + output_filenames.append(os.path.join(output_folder, name)) + else: + text_file = cfg.inference.from_file + logging.info(f"Running inference on {text_file}...") + if not os.path.exists(text_file): + raise ValueError(f"{text_file} not found.") + input_filenames.append(text_file) + output_filenames.append(cfg.inference.out_file) + + dataloader_cfg = { + "batch_size": cfg.inference.get("batch_size", 8), + "num_workers": cfg.inference.get("num_workers", 4), + "pin_memory": cfg.inference.get("num_workers", False), + } + for input_filename, output_filename in zip(input_filenames, output_filenames): + if not os.path.exists(input_filename): + logging.info(f"Skip non-existing {input_filename}.") + continue + model.infer(dataloader_cfg, input_filename, output_filename) + logging.info(f"Predictions saved to {output_filename}.") + + +if __name__ == "__main__": + main() diff --git a/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py b/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py new file mode 100644 index 000000000000..7ea9314d196d --- /dev/null +++ b/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_train.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script contains an example on how to train SpellMapper (SpellcheckingAsrCustomizationModel). +It uses the `examples/nlp/spellchecking_asr_customization/conf/spellchecking_asr_customization_config.yaml` +config file by default. The other option is to set another config file via command +line arguments by `--config-name=CONFIG_FILE_PATH'. Probably it is worth looking +at the example config file to see the list of parameters used for training. + +USAGE Example: + See `examples/nlp/spellchecking_asr_customization/run_training.sh` for training on non-tarred data. 
+ and + `examples/nlp/spellchecking_asr_customization/run_training_tarred.sh` for training on tarred data. + +One (non-tarred) training example should consist of 4 tab-separated columns: + 1. text of ASR-hypothesis + 2. texts of 10 candidates separated by semicolon + 3. 1-based ids of correct candidates, or 0 if none + 4. start/end coordinates of correct candidates (correspond to ids in third column) +Example (in one line): + a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o + d i d i e r _ s a u m o n;a s t r o n o m i e;t r i s t a n _ g u i l l o t;t r i s t e s s e;m o n a d e;c h r i s t i a n;a s t r o n o m e r;s o l o m o n;d i d i d i d i d i;m e r c y + 1 3 + CUSTOM 12 23;CUSTOM 28 41 +""" + +from helpers import MODEL, instantiate_model_and_trainer +from omegaconf import DictConfig, OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="spellchecking_asr_customization_config") +def main(cfg: DictConfig) -> None: + logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}') + + # Train the model + if cfg.model.do_training: + logging.info( + "================================================================================================" + ) + logging.info('Start training...') + trainer, model = instantiate_model_and_trainer(cfg, MODEL, True) + spellchecking_exp_manager = cfg.get('exp_manager', None) + exp_manager(trainer, spellchecking_exp_manager) + trainer.fit(model) + logging.info('Training finished!') + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/text_normalization_as_tagging/dataset_preparation/extract_giza_alignments.py b/examples/nlp/text_normalization_as_tagging/dataset_preparation/extract_giza_alignments.py index e2ae48a37a0b..f5a53b1f331d 100644 --- a/examples/nlp/text_normalization_as_tagging/dataset_preparation/extract_giza_alignments.py +++ b/examples/nlp/text_normalization_as_tagging/dataset_preparation/extract_giza_alignments.py @@ -19,9 +19,14 @@ import re from argparse import ArgumentParser -from typing import List, Tuple -import numpy as np +from nemo.collections.nlp.data.text_normalization_as_tagging.utils import ( + check_monotonicity, + fill_alignment_matrix, + get_targets, + get_targets_from_back, +) + parser = ArgumentParser(description='Extract final alignments from GIZA++ alignments') parser.add_argument('--mode', type=str, required=True, help='tn or itn') @@ -34,211 +39,13 @@ args = parser.parse_args() -def fill_alignment_matrix( - fline2: str, fline3: str, gline2: str, gline3: str -) -> Tuple[np.ndarray, List[str], List[str]]: - """Parse Giza++ direct and reverse alignment results and represent them as an alignment matrix - - Args: - fline2: e.g. "_2 0 1 4_" - fline3: e.g. "NULL ({ }) twenty ({ 1 }) fourteen ({ 2 3 4 })" - gline2: e.g. "twenty fourteen" - gline3: e.g. "NULL ({ }) _2 ({ 1 }) 0 ({ }) 1 ({ }) 4_ ({ 2 })" - - Returns: - matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment - the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the - words were aligned only in one direction, 0 - no alignment. - srctokens: e.g. ["twenty", "fourteen"] - dsttokens: e.g. 
["_2", "0", "1", "4_"] - - For example, the alignment matrix for the above example may look like: - [[3, 0, 0, 0] - [0, 2, 2, 3]] - """ - if fline2 is None or gline2 is None or fline3 is None or gline3 is None: - raise ValueError(f"empty params") - srctokens = gline2.split() - dsttokens = fline2.split() - pattern = r"([^ ]+) \(\{ ([^\(\{\}\)]*) \}\)" - src2dst = re.findall(pattern, fline3.replace("({ })", "({ })")) - dst2src = re.findall(pattern, gline3.replace("({ })", "({ })")) - if len(src2dst) != len(srctokens) + 1: - raise ValueError( - "length mismatch: len(src2dst)=" - + str(len(src2dst)) - + "; len(srctokens)" - + str(len(srctokens)) - + "\n" - + gline2 - + "\n" - + fline3 - ) - if len(dst2src) != len(dsttokens) + 1: - raise ValueError( - "length mismatch: len(dst2src)=" - + str(len(dst2src)) - + "; len(dsttokens)" - + str(len(dsttokens)) - + "\n" - + fline2 - + "\n" - + gline3 - ) - matrix = np.zeros((len(srctokens), len(dsttokens))) - for i in range(1, len(src2dst)): - token, to_str = src2dst[i] - if to_str == "": - continue - to = list(map(int, to_str.split())) - for t in to: - matrix[i - 1][t - 1] = 2 - - for i in range(1, len(dst2src)): - token, to_str = dst2src[i] - if to_str == "": - continue - to = list(map(int, to_str.split())) - for t in to: - matrix[t - 1][i - 1] += 1 - - return matrix, srctokens, dsttokens - - -def check_monotonicity(matrix: np.ndarray) -> bool: - """Check if alignment is monotonous - i.e. the relative order is preserved (no swaps). - - Args: - matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment - the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the - words were aligned only in one direction, 0 - no alignment. - """ - is_sorted = lambda k: np.all(k[:-1] <= k[1:]) - - a = np.argwhere(matrix == 3) - b = np.argwhere(matrix == 2) - c = np.vstack((a, b)) - d = c[c[:, 1].argsort()] # sort by second column (less important) - d = d[d[:, 0].argsort(kind="mergesort")] - return is_sorted(d[:, 1]) - - -def get_targets(matrix: np.ndarray, dsttokens: List[str]) -> List[str]: - """Join some of the destination tokens, so that their number becomes the same as the number of input words. - Unaligned tokens tend to join to the left aligned token. - - Args: - matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment - the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the - words were aligned only in one direction, 0 - no alignment. - dsttokens: e.g. ["_2", "0", "1", "4_"] - Returns: - targets: list of string tokens, with one-to-one correspondence to matrix.shape[0] - - Example: - If we get - matrix=[[3, 0, 0, 0] - [0, 2, 2, 3]] - dsttokens=["_2", "0", "1", "4_"] - it gives - targets = ["_201", "4_"] - Actually, this is a mistake instead of ["_20", "14_"]. That will be further corrected by regular expressions. 
- """ - targets = [] - last_covered_dst_id = -1 - for i in range(len(matrix)): - dstlist = [] - for j in range(last_covered_dst_id + 1, len(dsttokens)): - # matrix[i][j] == 3: safe alignment point - if matrix[i][j] == 3 or ( - j == last_covered_dst_id + 1 - and np.all(matrix[i, :] == 0) # if the whole line does not have safe points - and np.all(matrix[:, j] == 0) # and the whole column does not have safe points, match them - ): - if len(targets) == 0: # if this is first safe point, attach left unaligned columns to it, if any - for k in range(0, j): - if np.all(matrix[:, k] == 0): # if column k does not have safe points - dstlist.append(dsttokens[k]) - else: - break - dstlist.append(dsttokens[j]) - last_covered_dst_id = j - for k in range(j + 1, len(dsttokens)): - if np.all(matrix[:, k] == 0): # if column k does not have safe points - dstlist.append(dsttokens[k]) - last_covered_dst_id = k - else: - break - - if len(dstlist) > 0: - if args.mode == "tn": - targets.append("_".join(dstlist)) - else: - targets.append("".join(dstlist)) - else: - targets.append("") - return targets - - -def get_targets_from_back(matrix: np.ndarray, dsttokens: List[str]) -> List[str]: - """Join some of the destination tokens, so that their number becomes the same as the number of input words. - Unaligned tokens tend to join to the right aligned token. - - Args: - matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment - the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the - words were aligned only in one direction, 0 - no alignment. - dsttokens: e.g. ["_2", "0", "1", "4_"] - Returns: - targets: list of string tokens, with one-to-one correspondence to matrix.shape[0] - - Example: - If we get - matrix=[[3, 0, 0, 0] - [0, 2, 2, 3]] - dsttokens=["_2", "0", "1", "4_"] - it gives - targets = ["_2", "014_"] - Actually, this is a mistake instead of ["_20", "14_"]. That will be further corrected by regular expressions. - """ - - targets = [] - last_covered_dst_id = len(dsttokens) - for i in range(len(matrix) - 1, -1, -1): - dstlist = [] - for j in range(last_covered_dst_id - 1, -1, -1): - if matrix[i][j] == 3 or ( - j == last_covered_dst_id - 1 and np.all(matrix[i, :] == 0) and np.all(matrix[:, j] == 0) - ): - if len(targets) == 0: - for k in range(len(dsttokens) - 1, j, -1): - if np.all(matrix[:, k] == 0): - dstlist.append(dsttokens[k]) - else: - break - dstlist.append(dsttokens[j]) - last_covered_dst_id = j - for k in range(j - 1, -1, -1): - if np.all(matrix[:, k] == 0): - dstlist.append(dsttokens[k]) - last_covered_dst_id = k - else: - break - if len(dstlist) > 0: - if args.mode == "tn": - targets.append("_".join(list(reversed(dstlist)))) - else: - targets.append("".join(list(reversed(dstlist)))) - else: - targets.append("") - return list(reversed(targets)) - - def main() -> None: g = open(args.giza_dir + "/GIZA++." + args.giza_suffix, "r", encoding="utf-8") f = open(args.giza_dir + "/GIZA++reverse." 
+ args.giza_suffix, "r", encoding="utf-8") + target_inner_delimiter = "" if args.mode == "tn": g, f = f, g + target_inner_delimiter = "_" out = open(args.giza_dir + "/" + args.out_filename, "w", encoding="utf-8") cache = {} good_count, not_mono_count, not_covered_count, exception_count = 0, 0, 0, 0 @@ -277,8 +84,8 @@ def main() -> None: else: matrix[matrix <= 2] = 0 # leave only 1-to-1 alignment points if check_monotonicity(matrix): - targets = get_targets(matrix, dsttokens) - targets_from_back = get_targets_from_back(matrix, dsttokens) + targets = get_targets(matrix, dsttokens, delimiter=target_inner_delimiter) + targets_from_back = get_targets_from_back(matrix, dsttokens, delimiter=target_inner_delimiter) if len(targets) != len(srctokens): raise ValueError( "targets length doesn't match srctokens length: len(targets)=" diff --git a/nemo/collections/nlp/data/spellchecking_asr_customization/__init__.py b/nemo/collections/nlp/data/spellchecking_asr_customization/__init__.py new file mode 100644 index 000000000000..4e786276108c --- /dev/null +++ b/nemo/collections/nlp/data/spellchecking_asr_customization/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.nlp.data.spellchecking_asr_customization.dataset import ( + SpellcheckingAsrCustomizationDataset, + SpellcheckingAsrCustomizationTestDataset, + TarredSpellcheckingAsrCustomizationDataset, +) diff --git a/nemo/collections/nlp/data/spellchecking_asr_customization/bert_example.py b/nemo/collections/nlp/data/spellchecking_asr_customization/bert_example.py new file mode 100644 index 000000000000..803d0eaf8aed --- /dev/null +++ b/nemo/collections/nlp/data/spellchecking_asr_customization/bert_example.py @@ -0,0 +1,593 @@ +# Copyright 2019 The Google Research Authors. +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import OrderedDict +from os import path +from typing import Dict, List, Optional, Tuple, Union + +from transformers import PreTrainedTokenizerBase + +"""Build BERT Examples from asr hypothesis, customization candidates, target labels, span info. +""" + + +class BertExample(object): + """Class for training and inference examples for BERT. + + Attributes: + features: Feature dictionary. 
+ """ + + def __init__( + self, + input_ids: List[int], + input_mask: List[int], + segment_ids: List[int], + input_ids_for_subwords: List[int], + input_mask_for_subwords: List[int], + segment_ids_for_subwords: List[int], + character_pos_to_subword_pos: List[int], + fragment_indices: List[Tuple[int, int, int]], + labels_mask: List[int], + labels: List[int], + spans: List[Tuple[int, int, int]], + default_label: int, + ) -> None: + """Inputs to the example wrapper + + Args: + input_ids: indices of single characters (treated as subwords) + input_mask: list of bools with 0s in place of input_ids to be masked + segment_ids: list of ints from 0 to 10 to denote the text segment type ( + 0 - for tokens of ASR hypothesis, + 1 - for tokens of the first candidate + ... + 10 - for tokens of the tenth candidate + ) + input_ids_for_subwords: indices of real subwords (as tokenized by bert tokenizer) + input_mask_for_subwords: list of bools with 0s in place of input_ids_for_subwords to be masked + segment_ids_for_subwords: same as segment_ids but for input_ids_for_subwords + character_pos_to_subword_pos: list of size=len(input_ids), value=(position of corresponding subword in input_ids_for_subwords) + fragment_indices: list of tuples (start_position, end_position, candidate_id), end is exclusive, candidate_id can be -1 if not set + labels_mask: bool tensor with 0s in place of label tokens to be masked + labels: indices of semiotic classes which should be predicted from each of the + corresponding input tokens + spans: list of tuples (class_id, start_position, end_position), end is exclusive, class is always 1(CUSTOM) + default_label: The default label + """ + input_len = len(input_ids) + if not ( + input_len == len(input_mask) + and input_len == len(segment_ids) + and input_len == len(labels_mask) + and input_len == len(labels) + and input_len == len(character_pos_to_subword_pos) + ): + raise ValueError("All feature lists should have the same length ({})".format(input_len)) + + input_len_for_subwords = len(input_ids_for_subwords) + if not ( + input_len_for_subwords == len(input_mask_for_subwords) + and input_len_for_subwords == len(segment_ids_for_subwords) + ): + raise ValueError( + "All feature lists for subwords should have the same length ({})".format(input_len_for_subwords) + ) + + self.features = OrderedDict( + [ + ("input_ids", input_ids), + ("input_mask", input_mask), + ("segment_ids", segment_ids), + ("input_ids_for_subwords", input_ids_for_subwords), + ("input_mask_for_subwords", input_mask_for_subwords), + ("segment_ids_for_subwords", segment_ids_for_subwords), + ("character_pos_to_subword_pos", character_pos_to_subword_pos), + ("fragment_indices", fragment_indices), + ("labels_mask", labels_mask), + ("labels", labels), + ("spans", spans), + ] + ) + self._default_label = default_label + + +class BertExampleBuilder(object): + """Builder class for BertExample objects.""" + + def __init__( + self, + label_map: Dict[str, int], + semiotic_classes: Dict[str, int], + tokenizer: PreTrainedTokenizerBase, + max_seq_length: int, + ) -> None: + """Initializes an instance of BertExampleBuilder. + + Args: + label_map: Mapping from tags to tag IDs. + semiotic_classes: Mapping from semiotic classes to their ids. + tokenizer: Tokenizer object. + max_seq_length: Maximum sequence length. 
+ """ + self._label_map = label_map + self._semiotic_classes = semiotic_classes + self._tokenizer = tokenizer + self._max_seq_length = max_seq_length + # one span usually covers one or more words and it only exists for custom phrases, so there are much less spans than characters. + self._max_spans_length = max(4, int(max_seq_length / 20)) + self._pad_id = self._tokenizer.pad_token_id + self._default_label = 0 + + def build_bert_example( + self, hyp: str, ref: str, target: Optional[str] = None, span_info: Optional[str] = None, infer: bool = False + ) -> Optional[BertExample]: + """Constructs a BERT Example. + + Args: + hyp: Hypothesis text. + ref: Candidate customization variants divided by ';' + target: + if infer==False, string of labels (each label is 1-based index of correct candidate) or 0. + if infer==True, it can be None or string of labels (each label is 1-based index of some candidate). In inference this can be used to get corresponding fragments to fragment_indices. + span_info: + string of format "CUSTOM 6 20;CUSTOM 40 51", number of parts corresponds to number of targets. Can be empty if target is 0. + If infer==False, numbers are correct start and end(exclusive) positions of the corresponding target candidate in the text. + If infer==True, numbers are EXPECTED positions in the text. In inference this can be used to get corresponding fragments to fragment_indices. + infer: inference mode + Returns: + BertExample, or None if the conversion from text to tags was infeasible + + Example (infer=False): + hyp: "a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o" + ref: "d i d i e r _ s a u m o n;a s t r o n o m i e;t r i s t a n _ g u i l l o t;t r i s t e s s e;m o n a d e;c h r i s t i a n;a s t r o n o m e r;s o l o m o n;d i d i d i d i d i;m e r c y" + target: "1 3" + span_info: "CUSTOM 12 23;CUSTOM 28 41" + """ + if not ref.count(";") == 9: + raise ValueError("Expect 10 candidates: " + ref) + + span_info_parts = [] + targets = [] + + if len(target) > 0 and target != "0": + span_info_parts = span_info.split(";") + targets = list(map(int, target.split(" "))) + if len(span_info_parts) != len(targets): + raise ValueError( + "len(span_info_parts)=" + + str(len(span_info_parts)) + + " is different from len(target_parts)=" + + str(len(targets)) + ) + + tags = [0 for _ in hyp.split()] + if not infer: + for p, t in zip(span_info_parts, targets): + c, start, end = p.split(" ") + start = int(start) + end = int(end) + tags[start:end] = [t for i in range(end - start)] + + # get input features for characters + (input_ids, input_mask, segment_ids, labels_mask, labels, _, _,) = self._get_input_features( + hyp=hyp, ref=ref, tags=tags + ) + + # get input features for words + hyp_with_words = hyp.replace(" ", "").replace("_", " ") + ref_with_words = ref.replace(" ", "").replace("_", " ") + ( + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + _, + _, + _, + _, + ) = self._get_input_features(hyp=hyp_with_words, ref=ref_with_words, tags=None) + + # used in forward to concatenate subword embeddings to character embeddings + character_pos_to_subword_pos = self._map_characters_to_subwords(input_ids, input_ids_for_subwords) + + fragment_indices = [] + if infer: + # used in inference to take argmax over whole fragments instead of separate characters to get more consistent predictions + fragment_indices = self._get_fragment_indices(hyp, targets, span_info_parts) + + spans = [] + if not infer: + # during training spans are used in validation 
step to calculate accuracy on whole custom phrases instead of separate characters + spans = self._get_spans(span_info_parts) + + if len(input_ids) > self._max_seq_length or len(spans) > self._max_spans_length: + print( + "Max len exceeded: len(input_ids)=", + len(input_ids), + "; _max_seq_length=", + self._max_seq_length, + "; len(spans)=", + len(spans), + "; _max_spans_length=", + self._max_spans_length, + ) + return None + + example = BertExample( + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + input_ids_for_subwords=input_ids_for_subwords, + input_mask_for_subwords=input_mask_for_subwords, + segment_ids_for_subwords=segment_ids_for_subwords, + character_pos_to_subword_pos=character_pos_to_subword_pos, + fragment_indices=fragment_indices, + labels_mask=labels_mask, + labels=labels, + spans=spans, + default_label=self._default_label, + ) + return example + + def _get_spans(self, span_info_parts: List[str]) -> List[Tuple[int, int, int]]: + """ Converts span_info string into a list of (class_id, start, end) where start, end are coordinates of starting and ending(exclusive) tokens in input_ids of BertExample + + Example: + span_info_parts: ["CUSTOM 37 41", "CUSTOM 47 52", "CUSTOM 42 46", "CUSTOM 0 7"] + result: [(1, 38, 42), (1, 48, 53), (1, 43, 47), (1, 1, 8)] + """ + result_spans = [] + + for p in span_info_parts: + if p == "": + break + c, start, end = p.split(" ") + if c not in self._semiotic_classes: + raise KeyError("class=" + c + " not found in self._semiotic_classes") + cid = self._semiotic_classes[c] + # +1 because this should be indexing on input_ids which has [CLS] token at beginning + start = int(start) + 1 + end = int(end) + 1 + result_spans.append((cid, start, end)) + return result_spans + + def _get_fragment_indices( + self, hyp: str, targets: List[int], span_info_parts: List[str] + ) -> Tuple[List[Tuple[int, int, int]]]: + """ Build fragment indices for real candidates. + This is used only at inference. + After external candidate retrieval we know approximately, where the candidate is located in the text (from the positions of matched n-grams). + In this function we + 1) adjust start/end positions to match word borders (possibly in multiple ways). + 2) generate content for fragment_indices tensor (it will be used during inference to average all predictions inside each fragment). + + Args: + hyp: ASR-hypothesis where space separates single characters (real space is replaced to underscore). + targets: list of candidate ids (only for real candidates, not dummy) + span_info_parts: list of strings of format like "CUSTOM 12 25", corresponding to each of targets, with start/end coordinates in text. + Returns: + List of tuples (start, end, target) where start and end are positions in ASR-hypothesis, target is candidate_id. + Note that returned fragments can be unsorted and can overlap, it's ok. + Example: + hyp: "a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o" + targets: [1 2 3 4 6 7 9] + span_info_parts: ["CUSTOM 12 25", "CUSTOM 0 10", "CUSTOM 27 42", ...], where numbers are EXPECTED start/end positions of corresponding target candidates in the text. These positions will be adjusted in this functuion. 
+ fragment_indices: [(1, 12, 2), (13, 24, 1), (13, 28, 1), ..., (29, 42, 3)] + """ + + fragment_indices = [] + + letters = hyp.split() + + for target, p in zip(targets, span_info_parts): + _, start, end = p.split(" ") + start = int(start) + end = min(int(end), len(hyp)) # guarantee that end is not outside length + + # Adjusting strategy 1: expand both sides to the nearest space. + # Adjust start by finding the nearest left space or beginning of text. If start is already some word beginning, it won't change. + k = start + while k > 0 and letters[k] != '_': + k -= 1 + adjusted_start = k if k == 0 else k + 1 + + # Adjust end by finding the nearest right space. If end is already space or sentence end, it won't change. + k = end + while k < len(letters) and letters[k] != '_': + k += 1 + adjusted_end = k + + # +1 because this should be indexing on input_ids which has [CLS] token at beginning + fragment_indices.append((adjusted_start + 1, adjusted_end + 1, target)) + + # Adjusting strategy 2: try to shrink to the closest space (from left or right or both sides). + # For example, here the candidate "shippers" has a matching n-gram covering part of previous word + # a b o u t _ o u r _ s h i p e r s _ b u t _ y o u _ k n o w + # 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 + expanded_fragment = "".join(letters[adjusted_start:adjusted_end]) + left_space_position = expanded_fragment.find("_") + right_space_position = expanded_fragment.rfind("_") + is_left_shrink = False + is_right_shrink = False + if left_space_position > -1 and left_space_position < len(expanded_fragment) / 2: + # +1 because of CLS token, another +1 to put start position after found space + fragment_indices.append((adjusted_start + 1 + left_space_position + 1, adjusted_end + 1, target)) + is_left_shrink = True + if right_space_position > -1 and right_space_position > len(expanded_fragment) / 2: + fragment_indices.append((adjusted_start + 1, adjusted_start + 1 + right_space_position, target)) + is_right_shrink = True + if is_left_shrink and is_right_shrink: + fragment_indices.append( + (adjusted_start + 1 + left_space_position + 1, adjusted_start + 1 + right_space_position, target) + ) + + return fragment_indices + + def _map_characters_to_subwords(self, input_ids: List[int], input_ids_for_subwords: List[int]) -> List[int]: + """ Maps each single character to the position of its corresponding subword. + + Args: + input_ids: List of character token ids. + input_ids_for_subwords: List of subword token ids. + Returns: + List of subword positions in input_ids_for_subwords. Its length is equal to len(input_ids) + + Example: + input_ids: [101, 1037, 1055, 1056, 1054, 1051, 1050, ..., 1051, 102, 1040, ..., 1050, 102, 1037, ..., 1041, 102, ..., 102] + input_ids_for_subwords: [101, 26357, 2106, 2666, 2061, 8202, 1998, 13012, 16643, 2319, 1043, 7174, 102, 2106, 3771, 7842, 2819, 2239, 102, ..., 102] + result: [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5, ... 
, 45, 46, 46, 46, 46, 46, 47] + """ + character_pos_to_subword_pos = [0 for _ in input_ids] + + ## '[CLS]', 'a', 's', 't', 'r', 'o', 'n', 'o', 'm', 'e', 'r', 's', '_', 'd', 'i', ..., 'l', 'o', '[SEP]', 'd', 'i', 'd', 'i', 'e', 'r', '_', 's', 'a', 'u', 'm', 'o', 'n', ..., '[SEP]' + tokens = self._tokenizer.convert_ids_to_tokens(input_ids) + ## '[CLS]', 'astronomers', 'did', '##ie', 'so', '##mon', 'and', 'tri', '##sti', '##an', 'g', '##llo', '[SEP]', 'did', '##ier', 'sa', '##um', '##on', '[SEP]', 'astro', '##no', '##mie', '[SEP]', 'tristan', 'gui', '##llo', '##t', '[SEP]', ..., '[SEP]', 'mercy', '[SEP]'] + tokens_for_subwords = self._tokenizer.convert_ids_to_tokens(input_ids_for_subwords) + j = 0 # index for tokens_for_subwords + j_offset = 0 # current letter index within subword + for i in range(len(tokens)): + character = tokens[i] + subword = tokens_for_subwords[j] + if character == "[CLS]" and subword == "[CLS]": + character_pos_to_subword_pos[i] = j + j += 1 + continue + if character == "[SEP]" and subword == "[SEP]": + character_pos_to_subword_pos[i] = j + j += 1 + continue + if character == "[CLS]" or character == "[SEP]" or subword == "[CLS]" or subword == "[SEP]": + raise IndexError( + "character[" + + str(i) + + "]=" + + character + + "; subword[" + + str(j) + + ";=" + + subword + + "subwords=" + + str(tokens_for_subwords) + ) + # At this point we expect that + # subword either 1) is a normal first token of a word or 2) starts with "##" (not first word token) + # character either 1) is a normal character or 2) is a space character "_" + if character == "_": + character_pos_to_subword_pos[i] = j - 1 # space is assigned to previous subtoken + continue + if j_offset < len(subword): + if character == subword[j_offset]: + character_pos_to_subword_pos[i] = j + j_offset += 1 + else: + raise IndexError( + "character mismatch:" + + "i=" + + str(i) + + "j=" + + str(j) + + "j_offset=" + + str(j_offset) + + "; len(tokens)=" + + str(len(tokens)) + + "; len(subwords)=" + + str(len(tokens_for_subwords)) + ) + # if subword is finished, increase j + if j_offset >= len(subword): + j += 1 + j_offset = 0 + if j >= len(tokens_for_subwords): + break + if tokens_for_subwords[j].startswith("##"): + j_offset = 2 + # check that all subword tokens are processed + if j < len(tokens_for_subwords): + raise IndexError( + "j=" + + str(j) + + "; len(tokens)=" + + str(len(tokens)) + + "; len(subwords)=" + + str(len(tokens_for_subwords)) + ) + return character_pos_to_subword_pos + + def _get_input_features( + self, hyp: str, ref: str, tags: List[int] + ) -> Tuple[List[int], List[int], List[int], List[int], List[int], List[str], List[int]]: + """Converts given ASR-hypothesis(hyp) and candidate string(ref) to features(token ids, mask, segment ids, etc). + + Args: + hyp: Hypothesis text. + ref: Candidate customization variants divided by ';' + tags: List of labels corresponding to each token of ASR-hypothesis or None when building an example during inference. + Returns: + Features (input_ids, input_mask, segment_ids, labels_mask, labels, hyp_tokens, token_start_indices) + + Note that this method is called both for character-based example and for word-based example (to split to subwords). 
+ + Character-based example: + hyp: "a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o" + ref: "d i d i e r _ s a u m o n;a s t r o n o m i e;t r i s t a n _ g u i l l o t;t r i s t e s s e;m o n a d e;c h r i s t i a n;a s t r o n o m e r;s o l o m o n;d i d i d i d i d i;m e r c y" + tags: "0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3" + + resulting token sequence: + '[CLS]', 'a', 's', 't', 'r', 'o', 'n', 'o', 'm', 'e', 'r', 's', '_', 'd', 'i', ..., 'l', 'o', '[SEP]', 'd', 'i', 'd', 'i', 'e', 'r', '_', 's', 'a', 'u', 'm', 'o', 'n', ..., '[SEP]' + + Word-based example: + hyp: "astronomers didie somon and tristian gllo" + ref: "didier saumon;astronomie;tristan guillot;tristesse;monade;christian;astronomer;solomon;dididididi;mercy" + tags: None (not used for word-based case) + + resulting token sequence: + '[CLS]', 'astronomers', 'did', '##ie', 'so', '##mon', 'and', 'tri', '##sti', '##an', 'g', '##llo', '[SEP]', 'did', '##ier', 'sa', '##um', '##on', '[SEP]', 'astro', '##no', '##mie', '[SEP]', 'tristan', 'gui', '##llo', '##t', '[SEP]', ..., '[SEP]', 'mercy', '[SEP]'] + """ + + labels_mask = [] + labels = [] + if tags is None: + hyp_tokens, token_start_indices = self._split_to_wordpieces(hyp.split()) + else: + hyp_tokens, labels, token_start_indices = self._split_to_wordpieces_with_labels(hyp.split(), tags) + references = ref.split(";") + all_ref_tokens = [] + all_ref_segment_ids = [] + for i in range(len(references)): + ref_tokens, _ = self._split_to_wordpieces(references[i].split()) + all_ref_tokens.extend(ref_tokens + ["[SEP]"]) + all_ref_segment_ids.extend([i + 1] * (len(ref_tokens) + 1)) + + input_tokens = ["[CLS]"] + hyp_tokens + ["[SEP]"] + all_ref_tokens # ends with [SEP] + input_ids = self._tokenizer.convert_tokens_to_ids(input_tokens) + input_mask = [1] * len(input_ids) + segment_ids = [0] + [0] * len(hyp_tokens) + [0] + all_ref_segment_ids + if len(input_ids) != len(segment_ids): + raise ValueError( + "len(input_ids)=" + + str(len(input_ids)) + + " is different from len(segment_ids)=" + + str(len(segment_ids)) + ) + + if tags: + labels_mask = [0] + [1] * len(labels) + [0] + [0] * len(all_ref_tokens) + labels = [0] + labels + [0] + [0] * len(all_ref_tokens) + return (input_ids, input_mask, segment_ids, labels_mask, labels, hyp_tokens, token_start_indices) + + def _split_to_wordpieces_with_labels( + self, tokens: List[str], labels: List[int] + ) -> Tuple[List[str], List[int], List[int]]: + """Splits tokens (and the labels accordingly) to WordPieces. + + Args: + tokens: Tokens to be split. + labels: Labels (one per token) to be split. + + Returns: + 3-tuple with the split tokens, split labels, and the indices of starting tokens of words + """ + bert_tokens = [] # Original tokens split into wordpieces. + bert_labels = [] # Label for each wordpiece. + # Index of each wordpiece that starts a new token. + token_start_indices = [] + for i, token in enumerate(tokens): + # '+ 1' is because bert_tokens will be prepended by [CLS] token later. + token_start_indices.append(len(bert_tokens) + 1) + pieces = self._tokenizer.tokenize(token) + bert_tokens.extend(pieces) + bert_labels.extend([labels[i]] * len(pieces)) + return bert_tokens, bert_labels, token_start_indices + + def _split_to_wordpieces(self, tokens: List[str]) -> Tuple[List[str], List[int]]: + """Splits tokens to WordPieces. + + Args: + tokens: Tokens to be split. + + Returns: + tuple with the split tokens, and the indices of the WordPieces that start a token. 
+ """ + bert_tokens = [] # Original tokens split into wordpieces. + # Index of each wordpiece that starts a new token. + token_start_indices = [] + for i, token in enumerate(tokens): + # '+ 1' is because bert_tokens will be prepended by [CLS] token later. + token_start_indices.append(len(bert_tokens) + 1) + pieces = self._tokenizer.tokenize(token) + bert_tokens.extend(pieces) + return bert_tokens, token_start_indices + + def read_input_file( + self, input_filename: str, infer: bool = False + ) -> Union[List['BertExample'], Tuple[List['BertExample'], Tuple[str, str]]]: + """Reads in Tab Separated Value file and converts to training/inference-ready examples. + + Args: + example_builder: Instance of BertExampleBuilder + input_filename: Path to the TSV input file. + infer: If true, input examples do not contain target info. + + Returns: + examples: List of converted examples (BertExample). + or + (examples, hyps_refs): If infer==true, returns h + """ + + if not path.exists(input_filename): + raise ValueError("Cannot find file: " + input_filename) + examples = [] # output list of BertExample + hyps_refs = [] # output list of tuples (ASR-hypothesis, candidate_str) + with open(input_filename, 'r') as f: + for line in f: + if len(examples) % 1000 == 0: + logging.info("{} examples processed.".format(len(examples))) + if infer: + parts = line.rstrip('\n').split('\t') + hyp, ref, target, span_info = parts[0], parts[1], None, None + if len(parts) == 4: + target, span_info = parts[2], parts[3] + try: + example = self.build_bert_example(hyp, ref, target=target, span_info=span_info, infer=infer) + except Exception as e: + logging.warning(str(e)) + logging.warning(line) + continue + if example is None: + logging.info("cannot create example: ") + logging.info(line) + continue + hyps_refs.append((hyp, ref)) + examples.append(example) + else: + hyp, ref, target, semiotic_info = line.rstrip('\n').split('\t') + try: + example = self.build_bert_example( + hyp, ref, target=target, span_info=semiotic_info, infer=infer + ) + except Exception as e: + logging.warning(str(e)) + logging.warning(line) + continue + if example is None: + logging.info("cannot create example: ") + logging.info(line) + continue + examples.append(example) + logging.info(f"Done. {len(examples)} examples converted.") + if infer: + return examples, hyps_refs + return examples diff --git a/nemo/collections/nlp/data/spellchecking_asr_customization/dataset.py b/nemo/collections/nlp/data/spellchecking_asr_customization/dataset.py new file mode 100644 index 000000000000..69705ec21b9d --- /dev/null +++ b/nemo/collections/nlp/data/spellchecking_asr_customization/dataset.py @@ -0,0 +1,521 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pickle +from io import BytesIO +from typing import Dict, List, Optional, Tuple + +import braceexpand +import numpy as np +import torch +import webdataset as wd + +from nemo.collections.nlp.data.spellchecking_asr_customization.bert_example import BertExampleBuilder +from nemo.core.classes.dataset import Dataset, IterableDataset +from nemo.core.neural_types import ChannelType, IntType, LabelsType, MaskType, NeuralType +from nemo.utils import logging + +__all__ = [ + "SpellcheckingAsrCustomizationDataset", + "SpellcheckingAsrCustomizationTestDataset", + "TarredSpellcheckingAsrCustomizationDataset", +] + + +def collate_train_dataset( + batch: List[ + Tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + ] + ], + pad_token_id: int, +) -> Tuple[ + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, +]: + """collate batch of training items + Args: + batch: A list of tuples of (input_ids, input_mask, segment_ids, input_ids_for_subwords, input_mask_for_subwords, segment_ids_for_subwords, character_pos_to_subword_pos, labels_mask, labels, spans). + pad_token_id: integer id of padding token (to use in padded_input_ids, padded_input_ids_for_subwords) + """ + max_length = 0 + max_length_for_subwords = 0 + max_length_for_spans = 1 # to avoid empty tensor + for ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + labels_mask, + labels, + spans, + ) in batch: + if len(input_ids) > max_length: + max_length = len(input_ids) + if len(input_ids_for_subwords) > max_length_for_subwords: + max_length_for_subwords = len(input_ids_for_subwords) + if len(spans) > max_length_for_spans: + max_length_for_spans = len(spans) + + padded_input_ids = [] + padded_input_mask = [] + padded_segment_ids = [] + padded_input_ids_for_subwords = [] + padded_input_mask_for_subwords = [] + padded_segment_ids_for_subwords = [] + padded_character_pos_to_subword_pos = [] + padded_labels_mask = [] + padded_labels = [] + padded_spans = [] + for ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + labels_mask, + labels, + spans, + ) in batch: + if len(input_ids) < max_length: + pad_length = max_length - len(input_ids) + padded_input_ids.append(np.pad(input_ids, pad_width=[0, pad_length], constant_values=pad_token_id)) + padded_input_mask.append(np.pad(input_mask, pad_width=[0, pad_length], constant_values=0)) + padded_segment_ids.append(np.pad(segment_ids, pad_width=[0, pad_length], constant_values=0)) + padded_labels_mask.append(np.pad(labels_mask, pad_width=[0, pad_length], constant_values=0)) + padded_labels.append(np.pad(labels, pad_width=[0, pad_length], constant_values=0)) + padded_character_pos_to_subword_pos.append( + np.pad(character_pos_to_subword_pos, pad_width=[0, pad_length], constant_values=0) + ) + else: + padded_input_ids.append(input_ids) + padded_input_mask.append(input_mask) + padded_segment_ids.append(segment_ids) + padded_labels_mask.append(labels_mask) + padded_labels.append(labels) + padded_character_pos_to_subword_pos.append(character_pos_to_subword_pos) + + if len(input_ids_for_subwords) < max_length_for_subwords: + pad_length = max_length_for_subwords - 
len(input_ids_for_subwords) + padded_input_ids_for_subwords.append( + np.pad(input_ids_for_subwords, pad_width=[0, pad_length], constant_values=pad_token_id) + ) + padded_input_mask_for_subwords.append( + np.pad(input_mask_for_subwords, pad_width=[0, pad_length], constant_values=0) + ) + padded_segment_ids_for_subwords.append( + np.pad(segment_ids_for_subwords, pad_width=[0, pad_length], constant_values=0) + ) + else: + padded_input_ids_for_subwords.append(input_ids_for_subwords) + padded_input_mask_for_subwords.append(input_mask_for_subwords) + padded_segment_ids_for_subwords.append(segment_ids_for_subwords) + + if len(spans) < max_length_for_spans: + padded_spans.append(np.ones((max_length_for_spans, 3), dtype=int) * -1) # pad value is [-1, -1, -1] + if len(spans) > 0: + padded_spans[-1][: spans.shape[0], : spans.shape[1]] = spans # copy actual spans to the beginning + else: + padded_spans.append(spans) + + return ( + torch.LongTensor(np.array(padded_input_ids)), + torch.LongTensor(np.array(padded_input_mask)), + torch.LongTensor(np.array(padded_segment_ids)), + torch.LongTensor(np.array(padded_input_ids_for_subwords)), + torch.LongTensor(np.array(padded_input_mask_for_subwords)), + torch.LongTensor(np.array(padded_segment_ids_for_subwords)), + torch.LongTensor(np.array(padded_character_pos_to_subword_pos)), + torch.LongTensor(np.array(padded_labels_mask)), + torch.LongTensor(np.array(padded_labels)), + torch.LongTensor(np.array(padded_spans)), + ) + + +def collate_test_dataset( + batch: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]], + pad_token_id: int, +) -> Tuple[ + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, +]: + """collate batch of test items + Args: + batch: A list of tuples of (input_ids, input_mask, segment_ids, input_ids_for_subwords, input_mask_for_subwords, segment_ids_for_subwords, character_pos_to_subword_pos, fragment_indices). 
+ pad_token_id: integer id of padding token (to use in padded_input_ids, padded_input_ids_for_subwords) + """ + max_length = 0 + max_length_for_subwords = 0 + max_length_for_fragment_indices = 1 # to avoid empty tensor + for ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + fragment_indices, + ) in batch: + if len(input_ids) > max_length: + max_length = len(input_ids) + if len(input_ids_for_subwords) > max_length_for_subwords: + max_length_for_subwords = len(input_ids_for_subwords) + if len(fragment_indices) > max_length_for_fragment_indices: + max_length_for_fragment_indices = len(fragment_indices) + + padded_input_ids = [] + padded_input_mask = [] + padded_segment_ids = [] + padded_input_ids_for_subwords = [] + padded_input_mask_for_subwords = [] + padded_segment_ids_for_subwords = [] + padded_character_pos_to_subword_pos = [] + padded_fragment_indices = [] + for ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + fragment_indices, + ) in batch: + if len(input_ids) < max_length: + pad_length = max_length - len(input_ids) + padded_input_ids.append(np.pad(input_ids, pad_width=[0, pad_length], constant_values=pad_token_id)) + padded_input_mask.append(np.pad(input_mask, pad_width=[0, pad_length], constant_values=0)) + padded_segment_ids.append(np.pad(segment_ids, pad_width=[0, pad_length], constant_values=0)) + padded_character_pos_to_subword_pos.append( + np.pad(character_pos_to_subword_pos, pad_width=[0, pad_length], constant_values=0) + ) + else: + padded_input_ids.append(input_ids) + padded_input_mask.append(input_mask) + padded_segment_ids.append(segment_ids) + padded_character_pos_to_subword_pos.append(character_pos_to_subword_pos) + + if len(input_ids_for_subwords) < max_length_for_subwords: + pad_length = max_length_for_subwords - len(input_ids_for_subwords) + padded_input_ids_for_subwords.append( + np.pad(input_ids_for_subwords, pad_width=[0, pad_length], constant_values=pad_token_id) + ) + padded_input_mask_for_subwords.append( + np.pad(input_mask_for_subwords, pad_width=[0, pad_length], constant_values=0) + ) + padded_segment_ids_for_subwords.append( + np.pad(segment_ids_for_subwords, pad_width=[0, pad_length], constant_values=0) + ) + else: + padded_input_ids_for_subwords.append(input_ids_for_subwords) + padded_input_mask_for_subwords.append(input_mask_for_subwords) + padded_segment_ids_for_subwords.append(segment_ids_for_subwords) + + if len(fragment_indices) < max_length_for_fragment_indices: + # we use [0, 1, 0] as padding value for fragment_indices, it corresponds to [CLS] token, which is ignored and won't affect anything + p = np.zeros((max_length_for_fragment_indices, 3), dtype=int) + p[:, 1] = 1 + p[:, 2] = 0 + padded_fragment_indices.append(p) + if len(fragment_indices) > 0: + padded_fragment_indices[-1][ + : fragment_indices.shape[0], : fragment_indices.shape[1] + ] = fragment_indices # copy actual fragment_indices to the beginning + else: + padded_fragment_indices.append(fragment_indices) + + return ( + torch.LongTensor(np.array(padded_input_ids)), + torch.LongTensor(np.array(padded_input_mask)), + torch.LongTensor(np.array(padded_segment_ids)), + torch.LongTensor(np.array(padded_input_ids_for_subwords)), + torch.LongTensor(np.array(padded_input_mask_for_subwords)), + torch.LongTensor(np.array(padded_segment_ids_for_subwords)), + 
torch.LongTensor(np.array(padded_character_pos_to_subword_pos)), + torch.LongTensor(np.array(padded_fragment_indices)), + ) + + +class SpellcheckingAsrCustomizationDataset(Dataset): + """ + Dataset as used by the SpellcheckingAsrCustomizationModel for training and validation pipelines. + + Args: + input_file (str): path to tsv-file with data + example_builder: instance of BertExampleBuilder + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + """ + return { + "input_ids": NeuralType(('B', 'T'), ChannelType()), + "input_mask": NeuralType(('B', 'T'), MaskType()), + "segment_ids": NeuralType(('B', 'T'), ChannelType()), + "input_ids_for_subwords": NeuralType(('B', 'T'), ChannelType()), + "input_mask_for_subwords": NeuralType(('B', 'T'), MaskType()), + "segment_ids_for_subwords": NeuralType(('B', 'T'), ChannelType()), + "character_pos_to_subword_pos": NeuralType(('B', 'T'), ChannelType()), + "labels_mask": NeuralType(('B', 'T'), MaskType()), + "labels": NeuralType(('B', 'T'), LabelsType()), + "spans": NeuralType(('B', 'T', 'C'), IntType()), + } + + def __init__(self, input_file: str, example_builder: BertExampleBuilder) -> None: + self.example_builder = example_builder + self.examples = self.example_builder.read_input_file(input_file, infer=False) + self.pad_token_id = self.example_builder._pad_id + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx: int): + example = self.examples[idx] + input_ids = np.array(example.features["input_ids"], dtype=np.int16) + input_mask = np.array(example.features["input_mask"], dtype=np.int8) + segment_ids = np.array(example.features["segment_ids"], dtype=np.int8) + input_ids_for_subwords = np.array(example.features["input_ids_for_subwords"], dtype=np.int16) + input_mask_for_subwords = np.array(example.features["input_mask_for_subwords"], dtype=np.int8) + segment_ids_for_subwords = np.array(example.features["segment_ids_for_subwords"], dtype=np.int8) + character_pos_to_subword_pos = np.array(example.features["character_pos_to_subword_pos"], dtype=np.int16) + labels_mask = np.array(example.features["labels_mask"], dtype=np.int8) + labels = np.array(example.features["labels"], dtype=np.int8) + spans = np.array(example.features["spans"], dtype=np.int16) + return ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + labels_mask, + labels, + spans, + ) + + def _collate_fn(self, batch): + """collate batch of items + Args: + batch: A list of tuples of (input_ids, input_mask, segment_ids, input_ids_for_subwords, input_mask_for_subwords, segment_ids_for_subwords, character_pos_to_subword_pos, labels_mask, labels, spans). + """ + return collate_train_dataset(batch, pad_token_id=self.pad_token_id) + + +class TarredSpellcheckingAsrCustomizationDataset(IterableDataset): + """ + This Dataset loads training examples from tarred tokenized pickle files. + If using multiple processes the number of shards should be divisible by the number of workers to ensure an + even split among workers. If it is not divisible, logging will give a warning but training will proceed. + Additionally, please note that the len() of this DataLayer is assumed to be the number of tokens + of the text data. Shard strategy is scatter - each node gets a unique set of shards, which are permanently + pre-allocated and never changed at runtime. 
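+
+    For example, with 8 shard files and world_size=4, each rank is assigned
+    8 // 4 = 2 shards: rank 0 reads shards [0, 2), rank 1 reads shards [2, 4), and so on.
+    If the number of shards is not divisible by world_size, the remainder shards are not used.
+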
+ Args: + text_tar_filepaths: a string (can be brace-expandable). + shuffle_n (int): How many samples to look ahead and load to be shuffled. + See WebDataset documentation for more details. + Defaults to 0. + global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. + world_size (int): Total number of processes, used for partitioning shards. Defaults to 1. + pad_token_id: id of pad token (used in collate_fn) + """ + + def __init__( + self, + text_tar_filepaths: str, + shuffle_n: int = 1, + global_rank: int = 0, + world_size: int = 1, + pad_token_id: int = -1, # use real value or get error + ): + super(TarredSpellcheckingAsrCustomizationDataset, self).__init__() + if pad_token_id < 0: + raise ValueError("use non-negative pad_token_id: " + str(pad_token_id)) + + self.pad_token_id = pad_token_id + + # Replace '(', '[', '<' and '_OP_' with '{' + brace_keys_open = ['(', '[', '<', '_OP_'] + for bkey in brace_keys_open: + if bkey in text_tar_filepaths: + text_tar_filepaths = text_tar_filepaths.replace(bkey, "{") + + # Replace ')', ']', '>' and '_CL_' with '}' + brace_keys_close = [')', ']', '>', '_CL_'] + for bkey in brace_keys_close: + if bkey in text_tar_filepaths: + text_tar_filepaths = text_tar_filepaths.replace(bkey, "}") + + # Brace expand + text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths)) + + logging.info("Tarred dataset shards will be scattered evenly across all nodes.") + if len(text_tar_filepaths) % world_size != 0: + logging.warning( + f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible " + f"by number of distributed workers ({world_size}). " + f"Some shards will not be used ({len(text_tar_filepaths) % world_size})." + ) + begin_idx = (len(text_tar_filepaths) // world_size) * global_rank + end_idx = begin_idx + (len(text_tar_filepaths) // world_size) + logging.info('Begin Index : %d' % (begin_idx)) + logging.info('End Index : %d' % (end_idx)) + text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx] + logging.info( + "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx + ) + + self.tarpath = text_tar_filepaths + + # Put together WebDataset + self._dataset = wd.WebDataset(urls=text_tar_filepaths, nodesplitter=None) + + if shuffle_n > 0: + self._dataset = self._dataset.shuffle(shuffle_n, initial=shuffle_n) + else: + logging.info("WebDataset will not shuffle files within the tar files.") + + self._dataset = self._dataset.rename(pkl='pkl', key='__key__').to_tuple('pkl', 'key').map(f=self._build_sample) + + def _build_sample(self, fname): + # Load file + pkl_file, _ = fname + pkl_file = BytesIO(pkl_file) + data = pickle.load(pkl_file) + pkl_file.close() + input_ids = data["input_ids"] + input_mask = data["input_mask"] + segment_ids = data["segment_ids"] + input_ids_for_subwords = data["input_ids_for_subwords"] + input_mask_for_subwords = data["input_mask_for_subwords"] + segment_ids_for_subwords = data["segment_ids_for_subwords"] + character_pos_to_subword_pos = data["character_pos_to_subword_pos"] + labels_mask = data["labels_mask"] + labels = data["labels"] + spans = data["spans"] + + return ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + labels_mask, + labels, + spans, + ) + + def __iter__(self): + return self._dataset.__iter__() + + def _collate_fn(self, batch): + """collate batch of items + Args: + batch: A list of tuples of (input_ids, input_mask, 
segment_ids, input_ids_for_subwords, input_mask_for_subwords, segment_ids_for_subwords, character_pos_to_subword_pos, labels_mask, labels, spans). + """ + return collate_train_dataset(batch, pad_token_id=self.pad_token_id) + + +class SpellcheckingAsrCustomizationTestDataset(Dataset): + """ + Dataset for inference pipeline. + + Args: + sents: list of strings + example_builder: instance of BertExampleBuilder + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + """ + return { + "input_ids": NeuralType(('B', 'T'), ChannelType()), + "input_mask": NeuralType(('B', 'T'), MaskType()), + "segment_ids": NeuralType(('B', 'T'), ChannelType()), + "input_ids_for_subwords": NeuralType(('B', 'T'), ChannelType()), + "input_mask_for_subwords": NeuralType(('B', 'T'), MaskType()), + "segment_ids_for_subwords": NeuralType(('B', 'T'), ChannelType()), + "character_pos_to_subword_pos": NeuralType(('B', 'T'), ChannelType()), + "fragment_indices": NeuralType(('B', 'T', 'C'), IntType()), + } + + def __init__(self, input_file: str, example_builder: BertExampleBuilder) -> None: + self.example_builder = example_builder + self.examples, self.hyps_refs = self.example_builder.read_input_file(input_file, infer=True) + self.pad_token_id = self.example_builder._pad_id + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx: int): + example = self.examples[idx] + input_ids = np.array(example.features["input_ids"]) + input_mask = np.array(example.features["input_mask"]) + segment_ids = np.array(example.features["segment_ids"]) + input_ids_for_subwords = np.array(example.features["input_ids_for_subwords"]) + input_mask_for_subwords = np.array(example.features["input_mask_for_subwords"]) + segment_ids_for_subwords = np.array(example.features["segment_ids_for_subwords"]) + character_pos_to_subword_pos = np.array(example.features["character_pos_to_subword_pos"], dtype=np.int64) + fragment_indices = np.array(example.features["fragment_indices"], dtype=np.int16) + return ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + fragment_indices, + ) + + def _collate_fn(self, batch): + """collate batch of items + Args: + batch: A list of tuples of (input_ids, input_mask, segment_ids, input_ids_for_subwords, input_mask_for_subwords, segment_ids_for_subwords, character_pos_to_subword_pos). + """ + return collate_test_dataset(batch, pad_token_id=self.pad_token_id) diff --git a/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py b/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py new file mode 100644 index 000000000000..cda551189d78 --- /dev/null +++ b/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py @@ -0,0 +1,845 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
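For completeness, here is a minimal sketch of wrapping the non-tarred training dataset defined above in a PyTorch ``DataLoader``. Paths, batch size, tokenizer, and the label/semiotic dictionaries are illustrative assumptions; within NeMo this wiring is normally handled by ``SpellcheckingAsrCustomizationModel`` rather than by hand:

.. code:: python

    import torch
    from transformers import AutoTokenizer

    from nemo.collections.nlp.data.spellchecking_asr_customization.bert_example import BertExampleBuilder
    from nemo.collections.nlp.data.spellchecking_asr_customization.dataset import SpellcheckingAsrCustomizationDataset

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    builder = BertExampleBuilder(
        label_map={str(i): i for i in range(11)},
        semiotic_classes={"PLAIN": 0, "CUSTOM": 1},
        tokenizer=tokenizer,
        max_seq_length=256,
    )

    dataset = SpellcheckingAsrCustomizationDataset("train.tsv", example_builder=builder)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=8, shuffle=True, collate_fn=dataset._collate_fn
    )

    # Each batch is a 10-tuple of padded LongTensors (see collate_train_dataset above)
    input_ids, input_mask, segment_ids, *subword_and_pos, labels_mask, labels, spans = next(iter(loader))
    print(input_ids.shape, labels.shape, spans.shape)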
+ + +import json +import math +import random +import re +from collections import defaultdict, namedtuple +from typing import Dict, List, Set, Tuple, Union + +import numpy as np +from numba import jit + +"""Utility functions for Spellchecking ASR Customization.""" + + +def replace_diacritics(text): + text = re.sub(r"[éèëēêęěė]", "e", text) # latin + text = re.sub(r"[ё]", "е", text) # cyrillic + text = re.sub(r"[ãâāáäăàąåạảǎ]", "a", text) + text = re.sub(r"[úūüùưûů]", "u", text) + text = re.sub(r"[ôōóöõòőø]", "o", text) + text = re.sub(r"[ćçč]", "c", text) + text = re.sub(r"[ïīíîıì]", "i", text) + text = re.sub(r"[ñńňņ]", "n", text) + text = re.sub(r"[țťţ]", "t", text) + text = re.sub(r"[łľļ]", "l", text) + text = re.sub(r"[żžź]", "z", text) + text = re.sub(r"[ğ]", "g", text) + text = re.sub(r"[ďđ]", "d", text) + text = re.sub(r"[ķ]", "k", text) + text = re.sub(r"[ř]", "r", text) + text = re.sub(r"[ý]", "y", text) + text = re.sub(r"[æ]", "ae", text) + text = re.sub(r"[œ]", "oe", text) + text = re.sub(r"[șşšś]", "s", text) + return text + + +def load_ngram_mappings(input_name: str, max_misspelled_freq: int = 1000000000) -> Tuple[defaultdict, Set]: + """Loads n-gram mapping vocabularies in form required by dynamic programming + Args: + input_name: file with n-gram mappings + max_misspelled_freq: threshold on misspelled n-gram frequency + Returns: + vocab: dict {key=original_ngram, value=dict{key=misspelled_ngram, value=frequency}} + ban_ngram: set of banned misspelled n-grams + + Input format: + u t o u+i t o 49 8145 114 + u t o t e 63 8145 16970 + u t o o+_ t o 42 8145 1807 + """ + vocab = defaultdict(dict) + ban_ngram = set() + + with open(input_name, "r", encoding="utf-8") as f: + for line in f: + orig, misspelled, joint_freq, orig_freq, misspelled_freq = line.strip().split("\t") + if orig == "" or misspelled == "": + raise ValueError("Empty n-gram: orig=" + orig + "; misspelled=" + misspelled) + misspelled = misspelled.replace("", "=") + if misspelled.replace("=", "").strip() == "": # skip if resulting ngram doesn't contain any real character + continue + if int(misspelled_freq) > max_misspelled_freq: + ban_ngram.add(misspelled + " ") # space at the end is required within get_index function + vocab[orig][misspelled] = int(joint_freq) / int(orig_freq) + return vocab, ban_ngram + + +def load_ngram_mappings_for_dp(input_name: str) -> Tuple[defaultdict, defaultdict, defaultdict, int]: + """Loads n-gram mapping vocabularies in form required by dynamic programming + Args: + input_name: file with n-gram mappings + Returns: + joint_vocab: dict where key=(original_ngram, misspelled_ngram), value=frequency + orig_vocab: dict where key=original_ngram, value=frequency + misspelled_vocab: dict where key=misspelled_ngram, value=frequency + max_len: maximum n-gram length seen in vocabulary + + Input format: original \t misspelled \t joint_freq \t original_freq \t misspelled_freq + u t o u+i t o 49 8145 114 + u t o t e 63 8145 16970 + u t o o+_ t o 42 8145 1807 + """ + joint_vocab = defaultdict(int) + orig_vocab = defaultdict(int) + misspelled_vocab = defaultdict(int) + max_len = 0 + with open(input_name, "r", encoding="utf-8") as f: + for line in f: + orig, misspelled, joint_freq, _, _ = line.strip().split("\t") + if orig == "" or misspelled == "": + raise ValueError("Emty n-gram: orig=" + orig + "; misspelled=" + misspelled) + misspelled = misspelled.replace("", " ").replace("+", " ") + misspelled = " ".join(misspelled.split()) + if misspelled == "": # skip if resulting ngram doesn't contain any real 
character + continue + max_len = max(max_len, orig.count(" ") + 1, misspelled.count(" ") + 1) + joint_vocab[(orig, misspelled)] += int(joint_freq) + orig_vocab[orig] += int(joint_freq) + misspelled_vocab[misspelled] += int(joint_freq) + return joint_vocab, orig_vocab, misspelled_vocab, max_len + + +def get_alignment_by_dp( + ref_phrase: str, hyp_phrase: str, dp_data: Tuple[defaultdict, defaultdict, defaultdict, int] +) -> List[Tuple[str, str, float, float, int, int, int]]: + """Get best alignment path between a reference and (possibly) misspelled phrase using n-gram mappings vocabulary. + Args: + ref_phrase: candidate reference phrase (letters separated by space, real space replaced by underscore) + hyp_phrase: (possibly) misspelled phrase (letters separated by space, real space replaced by underscore) + dp_data: n-gram mapping vocabularies used by dynamic programming + Returns: + list of tuples (hyp_ngram, ref_ngram, logprob, sum_logprob, joint_freq, orig_freq, misspelled_freq) + This is best alignment path. + + Example: + ref_phrase: "a n h y d r i d e" + hyp_phrase: "a n d _ h y d r o d" + + Result: + [("*", "*", 0.0, 0.0, 0, 0, 0) + ("a n d _ h", "a n h", -2.34, -2.34, 226, 2338, 2203) + ("y d r o", "y d r i", -2.95, -5.29, 11, 211, 1584) + ("d", "d e", -1.99, -7.28, 60610, 444714, 2450334) + ] + Final path score is in path[-1][3]: -7.28 + Note that the order of ref_phrase and hyp_phrase matters, because n-gram mappings vocabulary is not symmetrical. + """ + joint_vocab, orig_vocab, misspelled_vocab, max_len = dp_data + hyp_letters = ["*"] + hyp_phrase.split() + ref_letters = ["*"] + ref_phrase.split() + DpInfo = namedtuple( + "DpInfo", ["hyp_pos", "ref_pos", "best_hyp_ngram_len", "best_ref_ngram_len", "score", "sum_score"] + ) + history = defaultdict(DpInfo) + history[(0, 0)] = DpInfo( + hyp_pos=0, ref_pos=0, best_hyp_ngram_len=1, best_ref_ngram_len=1, score=0.0, sum_score=0.0 + ) + for hyp_pos in range(len(hyp_letters)): + for ref_pos in range(len(ref_letters)): + if hyp_pos == 0 and ref_pos == 0: # cell (0, 0) is already defined + continue + # consider cell (hyp_pos, ref_pos) and find best path to get there + best_hyp_ngram_len = 0 + best_ref_ngram_len = 0 + best_ngram_score = float("-inf") + best_sum_score = float("-inf") + # loop over paths ending on non-empty ngram mapping + for hyp_ngram_len in range(1, 1 + min(max_len, hyp_pos + 1)): + hyp_ngram = " ".join(hyp_letters[(hyp_pos - hyp_ngram_len + 1) : (hyp_pos + 1)]) + for ref_ngram_len in range(1, 1 + min(max_len, ref_pos + 1)): + ref_ngram = " ".join(ref_letters[(ref_pos - ref_ngram_len + 1) : (ref_pos + 1)]) + if (ref_ngram, hyp_ngram) not in joint_vocab: + continue + joint_freq = joint_vocab[(ref_ngram, hyp_ngram)] + orig_freq = orig_vocab.get(ref_ngram, 1) + ngram_score = math.log(joint_freq / orig_freq) + previous_cell = (hyp_pos - hyp_ngram_len, ref_pos - ref_ngram_len) + if previous_cell not in history: + print("cell ", previous_cell, "does not exist") + continue + previous_score = history[previous_cell].sum_score + sum_score = ngram_score + previous_score + if sum_score > best_sum_score: + best_sum_score = sum_score + best_ngram_score = ngram_score + best_hyp_ngram_len = hyp_ngram_len + best_ref_ngram_len = ref_ngram_len + # loop over two variants with deletion of one character + deletion_score = -6.0 + insertion_score = -6.0 + if hyp_pos > 0: + previous_cell = (hyp_pos - 1, ref_pos) + previous_score = history[previous_cell].sum_score + sum_score = deletion_score + previous_score + if sum_score > best_sum_score: + 
best_sum_score = sum_score + best_ngram_score = deletion_score + best_hyp_ngram_len = 1 + best_ref_ngram_len = 0 + + if ref_pos > 0: + previous_cell = (hyp_pos, ref_pos - 1) + previous_score = history[previous_cell].sum_score + sum_score = insertion_score + previous_score + if sum_score > best_sum_score: + best_sum_score = sum_score + best_ngram_score = insertion_score + best_hyp_ngram_len = 0 + best_ref_ngram_len = 1 + + if best_hyp_ngram_len == 0 and best_ref_ngram_len == 0: + raise ValueError("best_hyp_ngram_len = 0 and best_ref_ngram_len = 0") + + # save cell to history + history[(hyp_pos, ref_pos)] = DpInfo( + hyp_pos=hyp_pos, + ref_pos=ref_pos, + best_hyp_ngram_len=best_hyp_ngram_len, + best_ref_ngram_len=best_ref_ngram_len, + score=best_ngram_score, + sum_score=best_sum_score, + ) + # now trace back on best path starting from last positions + path = [] + hyp_pos = len(hyp_letters) - 1 + ref_pos = len(ref_letters) - 1 + cell_info = history[(hyp_pos, ref_pos)] + path.append(cell_info) + while hyp_pos > 0 or ref_pos > 0: + hyp_pos -= cell_info.best_hyp_ngram_len + ref_pos -= cell_info.best_ref_ngram_len + cell_info = history[(hyp_pos, ref_pos)] + path.append(cell_info) + + result = [] + for info in reversed(path): + hyp_ngram = " ".join(hyp_letters[(info.hyp_pos - info.best_hyp_ngram_len + 1) : (info.hyp_pos + 1)]) + ref_ngram = " ".join(ref_letters[(info.ref_pos - info.best_ref_ngram_len + 1) : (info.ref_pos + 1)]) + joint_freq = joint_vocab.get((ref_ngram, hyp_ngram), 0) + orig_freq = orig_vocab.get(ref_ngram, 0) + misspelled_freq = misspelled_vocab.get(hyp_ngram, 0) + result.append((hyp_ngram, ref_ngram, info.score, info.sum_score, joint_freq, orig_freq, misspelled_freq)) + return result + + +def get_index( + custom_phrases: List[str], + vocab: defaultdict, + ban_ngram_global: Set[str], + min_log_prob: float = -4.0, + max_phrases_per_ngram: int = 100, +) -> Tuple[List[str], Dict[str, List[Tuple[int, int, int, float]]]]: + """Given a restricted vocabulary of replacements, + loops through custom phrases, + generates all possible conversions and creates index. + + Args: + custom_phrases: list of all custom phrases, characters should be split by space, real space replaced to underscore. + vocab: n-gram mappings vocabulary - dict {key=original_ngram, value=dict{key=misspelled_ngram, value=frequency}} + ban_ngram_global: set of banned misspelled n-grams + min_log_prob: minimum log probability, after which we stop growing this n-gram. + max_phrases_per_ngram: maximum phrases that we allow to store per one n-gram. N-grams exceeding that quantity get banned. + + Returns: + phrases - list of phrases. Position in this list is used as phrase_id. + ngram2phrases - resulting index, i.e. 
dict where key=ngram, value=list of tuples (phrase_id, begin_pos, size, logprob) + """ + + ban_ngram_local = set() # these ngrams are banned only for given custom_phrases + ngram_to_phrase_and_position = defaultdict(list) + + for custom_phrase in custom_phrases: + inputs = custom_phrase.split(" ") + begin = 0 + index_keys = [{} for _ in inputs] # key - letter ngram, index - beginning positions in phrase + + for begin in range(len(inputs)): + for end in range(begin + 1, min(len(inputs) + 1, begin + 5)): + inp = " ".join(inputs[begin:end]) + if inp not in vocab: + continue + for rep in vocab[inp]: + lp = math.log(vocab[inp][rep]) + + for b in range(max(0, end - 5), end): # try to grow previous ngrams with new replacement + new_ngrams = {} + for ngram in index_keys[b]: + lp_prev = index_keys[b][ngram] + if len(ngram) + len(rep) <= 10 and b + ngram.count(" ") == begin: + if lp_prev + lp > min_log_prob: + new_ngrams[ngram + rep + " "] = lp_prev + lp + index_keys[b].update(new_ngrams) # join two dictionaries + # add current replacement as ngram + if lp > min_log_prob: + index_keys[begin][rep + " "] = lp + + for b in range(len(index_keys)): + for ngram, lp in sorted(index_keys[b].items(), key=lambda item: item[1], reverse=True): + if ngram in ban_ngram_global: # here ngram ends with a space + continue + real_length = ngram.count(" ") + ngram = ngram.replace("+", " ").replace("=", " ") + ngram = " ".join(ngram.split()) # here ngram doesn't end with a space anymore + if ngram + " " in ban_ngram_global: # this can happen after deletion of + and = + continue + if ngram in ban_ngram_local: + continue + ngram_to_phrase_and_position[ngram].append((custom_phrase, b, real_length, lp)) + if len(ngram_to_phrase_and_position[ngram]) > max_phrases_per_ngram: + ban_ngram_local.add(ngram) + del ngram_to_phrase_and_position[ngram] + continue + + phrases = [] # id to phrase + phrase2id = {} # phrase to id + ngram2phrases = defaultdict(list) # ngram to list of tuples (phrase_id, begin, length, logprob) + + for ngram in ngram_to_phrase_and_position: + for phrase, b, length, lp in ngram_to_phrase_and_position[ngram]: + if phrase not in phrase2id: + phrases.append(phrase) + phrase2id[phrase] = len(phrases) - 1 + ngram2phrases[ngram].append((phrase2id[phrase], b, length, lp)) + + return phrases, ngram2phrases + + +def load_index(input_name: str) -> Tuple[List[str], Dict[str, List[Tuple[int, int, int, float]]]]: + """ Load index from file + Args: + input_name: file with index + Returns: + phrases: List of all phrases in custom vocabulary. Position corresponds to phrase_id. 
+ ngram2phrases: dict where key=ngram, value=list of tuples (phrase_id, begin_pos, size, logprob) + """ + phrases = [] # id to phrase + phrase2id = {} # phrase to id + ngram2phrases = defaultdict(list) # ngram to list of tuples (phrase_id, begin_pos, size, logprob) + with open(input_name, "r", encoding="utf-8") as f: + for line in f: + ngram, phrase, b, size, lp = line.split("\t") + b = int(b) + size = int(size) + lp = float(lp) + if phrase not in phrase2id: + phrases.append(phrase) + phrase2id[phrase] = len(phrases) - 1 + ngram2phrases[ngram].append((phrase2id[phrase], b, size, lp)) + return phrases, ngram2phrases + + +def search_in_index( + ngram2phrases: Dict[str, List[Tuple[int, int, int, float]]], phrases: List[str], letters: Union[str, List[str]] +) -> Tuple[np.ndarray, List[Set[str]]]: + """ Function used to search in index + + Args: + ngram2phrases: dict where key=ngram, value=list of tuples (phrase_id, begin_pos, size, logprob) + phrases: List of all phrases in custom vocabulary. Position corresponds to phrase_id. + letters: list of letters of ASR-hypothesis. Should not contain spaces - real spaces should be replaced with underscores. + + Returns: + phrases2positions: a matrix of size (len(phrases), len(letters)). + It is filled with 1.0 (hits) on intersection of letter n-grams and phrases that are indexed by these n-grams, 0.0 - elsewhere. + It is used later to find phrases with many hits within a contiguous window - potential matching candidates. + position2ngrams: positions in ASR-hypothesis mapped to sets of ngrams starting from that position. + It is used later to check how well each found candidate is covered by n-grams (to avoid cases where some repeating n-gram gives many hits to a phrase, but the phrase itself is not well covered). + """ + + if " " in letters: + raise ValueError("letters should not contain space: " + str(letters)) + + phrases2positions = np.zeros((len(phrases), len(letters)), dtype=float) + # positions mapped to sets of ngrams starting from that position + position2ngrams = [set() for _ in range(len(letters))] + + begin = 0 + for begin in range(len(letters)): + for end in range(begin + 1, min(len(letters) + 1, begin + 7)): + ngram = " ".join(letters[begin:end]) + if ngram not in ngram2phrases: + continue + for phrase_id, b, size, lp in ngram2phrases[ngram]: + phrases2positions[phrase_id, begin:end] = 1.0 + position2ngrams[begin].add(ngram) + return phrases2positions, position2ngrams + + +@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit +def get_all_candidates_coverage(phrases, phrases2positions): + """Get maximum hit coverage for each phrase - within a moving window of length of the phrase. + Args: + phrases: List of all phrases in custom vocabulary. Position corresponds to phrase_id. + phrases2positions: a matrix of size (len(phrases), len(ASR-hypothesis)). + It is filled with 1.0 (hits) on intersection of letter n-grams and phrases that are indexed by these n-grams, 0.0 - elsewhere. + Returns: + candidate2coverage: list of size len(phrases) containing coverage (0.0 to 1.0) in best window. + candidate2position: list of size len(phrases) containing starting position of best window. 
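+
+        Example (illustrative): for a phrase of 3 letters whose hits fall on positions 4..6 of the
+        ASR-hypothesis, the best window starts at position 4 with a hit sum of 3, and the reported
+        coverage is 3 / (3 + 2) = 0.6. The +2 in the denominator is a smoothing term, so even full
+        coverage of a very short phrase gives a score noticeably below 1.0.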
+ """ + candidate2coverage = [0.0] * len(phrases) + candidate2position = [-1] * len(phrases) + + for i in range(len(phrases)): + phrase_length = phrases[i].count(" ") + 1 + all_coverage = np.sum(phrases2positions[i]) / phrase_length + # if total coverage on whole ASR-hypothesis is too small, there is no sense in using moving window + if all_coverage < 0.4: + continue + moving_sum = np.sum(phrases2positions[i, 0:phrase_length]) + max_sum = moving_sum + best_pos = 0 + for pos in range(1, phrases2positions.shape[1] - phrase_length + 1): + moving_sum -= phrases2positions[i, pos - 1] + moving_sum += phrases2positions[i, pos + phrase_length - 1] + if moving_sum > max_sum: + max_sum = moving_sum + best_pos = pos + + coverage = max_sum / (phrase_length + 2) # smoothing + candidate2coverage[i] = coverage + candidate2position[i] = best_pos + return candidate2coverage, candidate2position + + +def get_candidates( + ngram2phrases: Dict[str, List[Tuple[int, int, int, float]]], + phrases: List[str], + letters: Union[str, List[str]], + pool_for_random_candidates: List[str], + min_phrase_coverage: float = 0.8, +) -> List[Tuple[str, int, int, float, float]]: + """Given an index of custom vocabulary and an ASR-hypothesis retrieve 10 candidates. + Args: + ngram2phrases: dict where key=ngram, value=list of tuples (phrase_id, begin_pos, size, logprob) + phrases: List of all phrases in custom vocabulary. Position corresponds to phrase_id. + letters: list of letters of ASR-hypothesis. Should not contain spaces - real spaces should be replaced with underscores. + pool_for_random_candidates: large list of strings, from which to sample random candidates in case when there are less than 10 real candidates + min_phrase_coverage: We discard candidates which are not covered by n-grams to at least to this extent + (to avoid cases where some repeating n-gram gives many hits to a phrase, but the phrase itself is not well covered). + Returns: + candidates: list of tuples (candidate_text, approximate_begin_position, length, coverage of window in ASR-hypothesis, coverage of phrase itself). 
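+
+        For instance (illustrative values), a returned tuple may look like
+        ("d i d i e r _ s a u m o n", 13, 13, 0.84, 1.0): the candidate "didier saumon"
+        (13 characters, counting the underscore) approximately starts at position 13 of the
+        ASR-hypothesis. The function returns exactly 10 tuples (or an empty list if no real
+        candidates were found); missing slots are filled with random dummy candidates taken from
+        pool_for_random_candidates, with begin=-1 and zero coverage, and the final list is shuffled.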
+ """ + phrases2positions, position2ngrams = search_in_index(ngram2phrases, phrases, letters) + candidate2coverage, candidate2position = get_all_candidates_coverage(phrases, phrases2positions) + + # mask for each custom phrase, how many which symbols are covered by input ngrams + phrases2coveredsymbols = [[0 for x in phrases[i].split(" ")] for i in range(len(phrases))] + candidates = [] + k = 0 + for idx, coverage in sorted(enumerate(candidate2coverage), key=lambda item: item[1], reverse=True): + begin = candidate2position[idx] # this is most likely beginning of this candidate + phrase_length = phrases[idx].count(" ") + 1 + for pos in range(begin, begin + phrase_length): + # we do not know exact end of custom phrase in text, it can be different from phrase length + if pos >= len(position2ngrams): + break + for ngram in position2ngrams[pos]: + for phrase_id, b, size, lp in ngram2phrases[ngram]: + if phrase_id != idx: + continue + for ppos in range(b, b + size): + if ppos >= phrase_length: + break + phrases2coveredsymbols[phrase_id][ppos] = 1 + k += 1 + if k > 100: + break + real_coverage = sum(phrases2coveredsymbols[idx]) / len(phrases2coveredsymbols[idx]) + if real_coverage < min_phrase_coverage: + continue + candidates.append((phrases[idx], begin, phrase_length, coverage, real_coverage)) + + # no need to process this sentence further if it does not contain any real candidates + if len(candidates) == 0: + print("WARNING: no real candidates", candidates) + return [] + + while len(candidates) < 10: + dummy = random.choice(pool_for_random_candidates) + dummy = " ".join(list(dummy.replace(" ", "_"))) + candidates.append((dummy, -1, dummy.count(" ") + 1, 0.0, 0.0)) + + candidates = candidates[:10] + random.shuffle(candidates) + if len(candidates) != 10: + print("WARNING: cannot get 10 candidates", candidates) + return [] + + return candidates + + +def read_spellmapper_predictions(filename: str) -> List[Tuple[str, List[Tuple[int, int, str, float]], List[int]]]: + """Read results of SpellMapper inference from file. 
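+
+    Character-level spacing is undone while reading: spaces between letters are removed and
+    underscores are converted back to real spaces, so the returned texts and candidates are plain
+    strings. Fragment predictions are returned sorted by their begin (and then end) position.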
+ Args: + filename: file with SpellMapper results + Returns: + list of tuples (sent, list of fragment predictions, list of letter predictions) + One fragment prediction is a tuple (begin, end, replacement_text, prob) + """ + results = [] + with open(filename, "r", encoding="utf-8") as f: + for line in f: + text, candidate_str, fragment_predictions_str, letter_predictions_str = line.strip().split("\t") + text = text.replace(" ", "").replace("_", " ") + candidate_str = candidate_str.replace(" ", "").replace("_", " ") + candidates = candidate_str.split(";") + letter_predictions = list(map(int, letter_predictions_str.split())) + if len(candidates) != 10: + raise IndexError("expect 10 candidates, got: ", len(candidates)) + if len(text) != len(letter_predictions): + raise IndexError("len(text)=", len(text), "; len(letter_predictions)=", len(letter_predictions)) + replacements = [] + if fragment_predictions_str != "": + for prediction in fragment_predictions_str.split(";"): + begin, end, candidate_id, prob = prediction.split(" ") + begin = int(begin) + end = int(end) + candidate_id = int(candidate_id) + prob = float(prob) + replacements.append((begin, end, candidates[candidate_id - 1], prob)) + replacements.sort() # it will sort by begin, then by end + results.append((text, replacements, letter_predictions)) + return results + + +def substitute_replacements_in_text( + text: str, replacements: List[Tuple[int, int, str, float]], replace_hyphen_to_space: bool +) -> str: + """Substitute replacements to the input text, iterating from end to beginning, so that indexing does not change. + Note that we expect intersecting replacements to be already filtered. + Args: + text: sentence; + replacements: list of replacements, each is a tuple (begin, end, text, probability); + replace_hyphen_to_space: if True, hyphens in replacements will be converted to spaces; + Returns: + corrected sentence + """ + replacements.sort() + last_begin = len(text) + 1 + corrected_text = text + for begin, end, candidate, prob in reversed(replacements): + if end > last_begin: + print("WARNING: skip intersecting replacement [", candidate, "] in text: ", text) + continue + if replace_hyphen_to_space: + candidate = candidate.replace("-", " ") + corrected_text = corrected_text[:begin] + candidate + corrected_text[end:] + last_begin = begin + return corrected_text + + +def apply_replacements_to_text( + text: str, + replacements: List[Tuple[int, int, str, float]], + min_prob: float = 0.5, + replace_hyphen_to_space: bool = False, + dp_data: Tuple[defaultdict, defaultdict, defaultdict, int] = None, + min_dp_score_per_symbol: float = -99.9, +) -> str: + """Filter and apply replacements to the input sentence. 
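+
+    Filtering involves several steps: a probability threshold (the probability is additionally
+    penalized when the candidate is longer than the fragment it replaces), a set of banned
+    replacement templates (see check_banned_replacements), an optional dynamic programming
+    alignment score check, and resolution of intersecting replacements in favor of the one with
+    the higher probability.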
+ Args: + text: input sentence; + replacements: list of proposed replacements (probably intersecting), each is a tuple (begin, end, text, probability); + min_prob: threshold on replacement probability; + replace_hyphen_to_space: if True, hyphens in replacements will be converted to spaces; + dp_data: n-gram mapping vocabularies used by dynamic programming, if None - dynamic programming is not used; + min_dp_score_per_symbol: threshold on dynamic programming sum score averaged by hypothesis length + Returns: + corrected sentence + """ + # sort replacements by positions + replacements.sort() + # filter replacements + # Note that we do not skip replacements with same text, otherwise intersecting candidates with lower probability can win + filtered_replacements = [] + for j in range(len(replacements)): + replacement = replacements[j] + begin, end, candidate, prob = replacement + fragment = text[begin:end] + candidate_spaced = " ".join(list(candidate.replace(" ", "_"))) + fragment_spaced = " ".join(list(fragment.replace(" ", "_"))) + # apply penalty if candidate length is bigger than fragment length + # to avoid cases like "forward-looking" replacing "looking" in "forward looking" resulting in "forward forward looking" + if len(candidate) > len(fragment): + penalty = len(fragment) / len(candidate) + prob *= penalty + # skip replacement with low probability + if prob < min_prob: + continue + # skip replacements with some predefined templates, e.g. "*'s" => "*s" + if check_banned_replacements(fragment, candidate): + continue + if dp_data is not None: + path = get_alignment_by_dp(candidate_spaced, fragment_spaced, dp_data) + # path[-1][3] is the sum of logprobs for best path of dynamic programming: divide sum_score by length + if path[-1][3] / (len(fragment)) < min_dp_score_per_symbol: + continue + + # skip replacement if it intersects with previous replacement and has lower probability, otherwise remove previous replacement + if len(filtered_replacements) > 0 and filtered_replacements[-1][1] > begin: + if filtered_replacements[-1][3] > prob: + continue + else: + filtered_replacements.pop() + filtered_replacements.append((begin, end, candidate, prob)) + + return substitute_replacements_in_text(text, filtered_replacements, replace_hyphen_to_space) + + +def update_manifest_with_spellmapper_corrections( + input_manifest_name: str, + short2full_name: str, + output_manifest_name: str, + spellmapper_results_name: str, + min_prob: float = 0.5, + replace_hyphen_to_space: bool = True, + field_name: str = "pred_text", + use_dp: bool = True, + ngram_mappings: Union[str, None] = None, + min_dp_score_per_symbol: float = -1.5, +) -> None: + """Post-process SpellMapper predictions and write corrected sentence to the specified field of nemo manifest. + The previous content of this field will be copied to "*_before_correction" field. + If the sentence was split into fragments before running SpellMapper, all replacements will be first gathered together and then applied to the original long sentence. 
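+    Replacement positions predicted for a short fragment are shifted by the offset of that
+    fragment within the full sentence before being applied, so overlapping fragments of the same
+    sentence contribute their replacements to a single list.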
+ Args: + input_manifest_name: input nemo manifest; + short2full_name: text file with two columns: short_sent \t full_sent; + output_manifest_name: output nemo manifest; + spellmapper_results_name: text file with SpellMapper inference results; + min_prob: threshold on replacement probability; + replace_hyphen_to_space: if True, hyphens in replacements will be converted to spaces; + field_name: name of json field whose text we want to correct; + use_dp: bool = If True, additional replacement filtering will be applied using dynamic programming (works slow); + ngram_mappings: file with n-gram mappings, only needed if use_dp=True + min_dp_score_per_symbol: threshold on dynamic programming sum score averaged by hypothesis length + """ + short2full_sent = defaultdict(list) + sent2corrections = defaultdict(dict) + with open(short2full_name, "r", encoding="utf-8") as f: + for line in f: + s = line.strip() + short_sent, full_sent = s.split("\t") + short2full_sent[short_sent].append(full_sent) + sent2corrections[full_sent] = [] + + spellmapper_results = read_spellmapper_predictions(spellmapper_results_name) + dp_data = None + if use_dp: + dp_data = load_ngram_mappings_for_dp(ngram_mappings) + + for text, replacements, _ in spellmapper_results: + short_sent = text + if short_sent not in short2full_sent: + continue + # it can happen that one short sentence occurred in multiple full sentences + for full_sent in short2full_sent[short_sent]: + offset = full_sent.find(short_sent) + for begin, end, candidate, prob in replacements: + sent2corrections[full_sent].append((begin + offset, end + offset, candidate, prob)) + + out = open(output_manifest_name, "w", encoding="utf-8") + with open(input_manifest_name, "r", encoding="utf-8") as f: + for line in f: + record = json.loads(line.strip()) + sent = record[field_name] + record[field_name + "_before_correction"] = record[field_name] + if sent in sent2corrections: + record[field_name] = apply_replacements_to_text( + sent, + sent2corrections[sent], + min_prob=min_prob, + replace_hyphen_to_space=replace_hyphen_to_space, + dp_data=dp_data, + min_dp_score_per_symbol=min_dp_score_per_symbol, + ) + out.write(json.dumps(record) + "\n") + out.close() + + +def extract_and_split_text_from_manifest( + input_name: str, output_name: str, field_name: str = "pred_text", len_in_words: int = 16, step_in_words: int = 8 +) -> None: + """Extract text of the specified field in nemo manifest and split it into fragments (possibly with intersection). + The result is saved to a text file with two columns: short_sent \t full_sent. + This is useful if we want to process shorter sentences and then apply the results to the original long sentence. + Args: + input_name: input nemo manifest, + output_name: output text file, + field_name: name of json field from which we extract the sentence text, + len_in_words: maximum number of words in a fragment, + step_in_words: on how many words we move at each step. + For example, if the len_in_words=16 and step_in_words=8 the fragments will be intersected by half. 
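+        Thus a sentence of 20 words with these settings is split into fragments covering
+        words [0:16], [8:20] and [16:20] (illustrative); the duplicate-free set of
+        (short_sent, full_sent) pairs is then written to the output file.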
+ """ + short2full_sent = set() + with open(input_name, "r", encoding="utf-8") as f: + for line in f: + record = json.loads(line.strip()) + sent = record[field_name] + if " " in sent: + raise ValueError("found multiple space in: " + sent) + words = sent.split() + for i in range(0, len(words), step_in_words): + short_sent = " ".join(words[i : i + len_in_words]) + short2full_sent.add((short_sent, sent)) + + with open(output_name, "w", encoding="utf-8") as out: + for short_sent, full_sent in short2full_sent: + out.write(short_sent + "\t" + full_sent + "\n") + + +def check_banned_replacements(src: str, dst: str) -> bool: + """This function is used to check is a pair of words/phrases is matching some common template that we don't want to replace with one another. + Args: + src: first phrase + dst: second phrase + Returns True if this replacement should be banned. + """ + # customers' => customer's + if src.endswith("s'") and dst.endswith("'s") and src[0:-2] == dst[0:-2]: + return True + # customer's => customers' + if src.endswith("'s") and dst.endswith("s'") and src[0:-2] == dst[0:-2]: + return True + # customers => customer's + if src.endswith("s") and dst.endswith("'s") and src[0:-1] == dst[0:-2]: + return True + # customer's => customers + if src.endswith("'s") and dst.endswith("s") and src[0:-2] == dst[0:-1]: + return True + # customers => customers' + if src.endswith("s") and dst.endswith("s'") and src[0:-1] == dst[0:-2]: + return True + # customers' => customers + if src.endswith("s'") and dst.endswith("s") and src[0:-2] == dst[0:-1]: + return True + # utilities => utility's + if src.endswith("ies") and dst.endswith("y's") and src[0:-3] == dst[0:-3]: + return True + # utility's => utilities + if src.endswith("y's") and dst.endswith("ies") and src[0:-3] == dst[0:-3]: + return True + # utilities => utility + if src.endswith("ies") and dst.endswith("y") and src[0:-3] == dst[0:-1]: + return True + # utility => utilities + if src.endswith("y") and dst.endswith("ies") and src[0:-1] == dst[0:-3]: + return True + # group is => group's + if src.endswith(" is") and dst.endswith("'s") and src[0:-3] == dst[0:-2]: + return True + # group's => group is + if src.endswith("'s") and dst.endswith(" is") and src[0:-2] == dst[0:-3]: + return True + # trex's => trex + if src.endswith("'s") and src[0:-2] == dst: + return True + # trex => trex's + if dst.endswith("'s") and dst[0:-2] == src: + return True + # increases => increase (but trimass => trimas is ok) + if src.endswith("s") and (not src.endswith("ss")) and src[0:-1] == dst: + return True + # increase => increases ((but trimas => trimass is ok)) + if dst.endswith("s") and (not dst.endswith("ss")) and dst[0:-1] == src: + return True + # anticipate => anticipated + if src.endswith("e") and dst.endswith("ed") and src[0:-1] == dst[0:-2]: + return True + # anticipated => anticipate + if src.endswith("ed") and dst.endswith("e") and src[0:-2] == dst[0:-1]: + return True + # regarded => regard + if src.endswith("ed") and src[0:-2] == dst: + return True + # regard => regarded + if dst.endswith("ed") and dst[0:-2] == src: + return True + # longer => long + if src.endswith("er") and src[0:-2] == dst: + return True + # long => longer + if dst.endswith("er") and dst[0:-2] == src: + return True + # discussed => discussing + if src.endswith("ed") and dst.endswith("ing") and src[0:-2] == dst[0:-3]: + return True + # discussing => discussed + if src.endswith("ing") and dst.endswith("ed") and src[0:-3] == dst[0:-2]: + return True + # discussion => discussing + if 
src.endswith("ion") and dst.endswith("ing") and src[0:-3] == dst[0:-3]: + return True + # discussing => discussion + if src.endswith("ing") and dst.endswith("ion") and src[0:-3] == dst[0:-3]: + return True + # dispensers => dispensing + if src.endswith("ers") and dst.endswith("ing") and src[0:-3] == dst[0:-3]: + return True + # dispensing => dispensers + if src.endswith("ing") and dst.endswith("ers") and src[0:-3] == dst[0:-3]: + return True + # discussion => discussed + if src.endswith("ion") and dst.endswith("ed") and src[0:-3] == dst[0:-2]: + return True + # discussed => discussion + if src.endswith("ed") and dst.endswith("ion") and src[0:-2] == dst[0:-3]: + return True + # incremental => increment + if src.endswith("ntal") and dst.endswith("nt") and src[0:-4] == dst[0:-2]: + return True + # increment => incremental + if src.endswith("nt") and dst.endswith("ntal") and src[0:-2] == dst[0:-4]: + return True + # delivery => deliverer + if src.endswith("ery") and dst.endswith("erer") and src[0:-3] == dst[0:-4]: + return True + # deliverer => delivery + if src.endswith("erer") and dst.endswith("ery") and src[0:-4] == dst[0:-3]: + return True + # comparably => comparable + if src.endswith("bly") and dst.endswith("ble") and src[0:-3] == dst[0:-3]: + return True + # comparable => comparably + if src.endswith("ble") and dst.endswith("bly") and src[0:-3] == dst[0:-3]: + return True + # beautiful => beautifully + if src.endswith("l") and dst.endswith("lly") and src[0:-1] == dst[0:-3]: + return True + # beautifully => beautiful + if src.endswith("lly") and dst.endswith("l") and src[0:-3] == dst[0:-1]: + return True + # america => american + if src.endswith("a") and dst.endswith("an") and src[0:-1] == dst[0:-2]: + return True + # american => america + if src.endswith("an") and dst.endswith("a") and src[0:-2] == dst[0:-1]: + return True + # reinvesting => investing + if src.startswith("re") and src[2:] == dst: + return True + # investing => reinvesting + if dst.startswith("re") and dst[2:] == src: + return True + # outperformance => performance + if src.startswith("out") and src[3:] == dst: + return True + # performance => outperformance + if dst.startswith("out") and dst[3:] == src: + return True + return False diff --git a/nemo/collections/nlp/data/text_normalization_as_tagging/utils.py b/nemo/collections/nlp/data/text_normalization_as_tagging/utils.py index 253f7a41c703..9d5f5b7b23ad 100644 --- a/nemo/collections/nlp/data/text_normalization_as_tagging/utils.py +++ b/nemo/collections/nlp/data/text_normalization_as_tagging/utils.py @@ -17,6 +17,8 @@ from itertools import groupby from typing import Dict, List, Tuple +import numpy as np + """Utility functions for Thutmose Tagger.""" @@ -305,3 +307,197 @@ def get_src_and_dst_for_alignment( ) return written_str, spoken, " ".join(same_begin), " ".join(same_end) + + +def fill_alignment_matrix( + fline2: str, fline3: str, gline2: str, gline3: str +) -> Tuple[np.ndarray, List[str], List[str]]: + """Parse Giza++ direct and reverse alignment results and represent them as an alignment matrix + + Args: + fline2: e.g. "_2 0 1 4_" + fline3: e.g. "NULL ({ }) twenty ({ 1 }) fourteen ({ 2 3 4 })" + gline2: e.g. "twenty fourteen" + gline3: e.g. 
"NULL ({ }) _2 ({ 1 }) 0 ({ }) 1 ({ }) 4_ ({ 2 })" + + Returns: + matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment + the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the + words were aligned only in one direction, 0 - no alignment. + srctokens: e.g. ["twenty", "fourteen"] + dsttokens: e.g. ["_2", "0", "1", "4_"] + + For example, the alignment matrix for the above example may look like: + [[3, 0, 0, 0] + [0, 2, 2, 3]] + """ + if fline2 is None or gline2 is None or fline3 is None or gline3 is None: + raise ValueError(f"empty params") + srctokens = gline2.split() + dsttokens = fline2.split() + pattern = r"([^ ]+) \(\{ ([^\(\{\}\)]*) \}\)" + src2dst = re.findall(pattern, fline3.replace("({ })", "({ })")) + dst2src = re.findall(pattern, gline3.replace("({ })", "({ })")) + if len(src2dst) != len(srctokens) + 1: + raise ValueError( + "length mismatch: len(src2dst)=" + + str(len(src2dst)) + + "; len(srctokens)" + + str(len(srctokens)) + + "\n" + + gline2 + + "\n" + + fline3 + ) + if len(dst2src) != len(dsttokens) + 1: + raise ValueError( + "length mismatch: len(dst2src)=" + + str(len(dst2src)) + + "; len(dsttokens)" + + str(len(dsttokens)) + + "\n" + + fline2 + + "\n" + + gline3 + ) + matrix = np.zeros((len(srctokens), len(dsttokens))) + for i in range(1, len(src2dst)): + token, to_str = src2dst[i] + if to_str == "": + continue + to = list(map(int, to_str.split())) + for t in to: + matrix[i - 1][t - 1] = 2 + + for i in range(1, len(dst2src)): + token, to_str = dst2src[i] + if to_str == "": + continue + to = list(map(int, to_str.split())) + for t in to: + matrix[t - 1][i - 1] += 1 + + return matrix, srctokens, dsttokens + + +def check_monotonicity(matrix: np.ndarray) -> bool: + """Check if alignment is monotonous - i.e. the relative order is preserved (no swaps). + + Args: + matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment + the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the + words were aligned only in one direction, 0 - no alignment. + """ + is_sorted = lambda k: np.all(k[:-1] <= k[1:]) + + a = np.argwhere(matrix == 3) + b = np.argwhere(matrix == 2) + c = np.vstack((a, b)) + d = c[c[:, 1].argsort()] # sort by second column (less important) + d = d[d[:, 0].argsort(kind="mergesort")] + return is_sorted(d[:, 1]) + + +def get_targets(matrix: np.ndarray, dsttokens: List[str], delimiter: str) -> List[str]: + """Join some of the destination tokens, so that their number becomes the same as the number of input words. + Unaligned tokens tend to join to the left aligned token. + + Args: + matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment + the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the + words were aligned only in one direction, 0 - no alignment. + dsttokens: e.g. ["_2", "0", "1", "4_"] + Returns: + targets: list of string tokens, with one-to-one correspondence to matrix.shape[0] + + Example: + If we get + matrix=[[3, 0, 0, 0] + [0, 2, 2, 3]] + dsttokens=["_2", "0", "1", "4_"] + it gives + targets = ["_201", "4_"] + Actually, this is a mistake instead of ["_20", "14_"]. That will be further corrected by regular expressions. 
+ """ + targets = [] + last_covered_dst_id = -1 + for i in range(len(matrix)): + dstlist = [] + for j in range(last_covered_dst_id + 1, len(dsttokens)): + # matrix[i][j] == 3: safe alignment point + if matrix[i][j] == 3 or ( + j == last_covered_dst_id + 1 + and np.all(matrix[i, :] == 0) # if the whole line does not have safe points + and np.all(matrix[:, j] == 0) # and the whole column does not have safe points, match them + ): + if len(targets) == 0: # if this is first safe point, attach left unaligned columns to it, if any + for k in range(0, j): + if np.all(matrix[:, k] == 0): # if column k does not have safe points + dstlist.append(dsttokens[k]) + else: + break + dstlist.append(dsttokens[j]) + last_covered_dst_id = j + for k in range(j + 1, len(dsttokens)): + if np.all(matrix[:, k] == 0): # if column k does not have safe points + dstlist.append(dsttokens[k]) + last_covered_dst_id = k + else: + break + + if len(dstlist) > 0: + targets.append(delimiter.join(dstlist)) + else: + targets.append("") + return targets + + +def get_targets_from_back(matrix: np.ndarray, dsttokens: List[str], delimiter: str) -> List[str]: + """Join some of the destination tokens, so that their number becomes the same as the number of input words. + Unaligned tokens tend to join to the right aligned token. + + Args: + matrix: a numpy array of shape (src_len, dst_len) filled with [0, 1, 2, 3], where 3 means a reliable alignment + the corresponding words were aligned to one another in direct and reverse alignment runs, 1 and 2 mean that the + words were aligned only in one direction, 0 - no alignment. + dsttokens: e.g. ["_2", "0", "1", "4_"] + Returns: + targets: list of string tokens, with one-to-one correspondence to matrix.shape[0] + + Example: + If we get + matrix=[[3, 0, 0, 0] + [0, 2, 2, 3]] + dsttokens=["_2", "0", "1", "4_"] + it gives + targets = ["_2", "014_"] + Actually, this is a mistake instead of ["_20", "14_"]. That will be further corrected by regular expressions. 
+ """ + + targets = [] + last_covered_dst_id = len(dsttokens) + for i in range(len(matrix) - 1, -1, -1): + dstlist = [] + for j in range(last_covered_dst_id - 1, -1, -1): + if matrix[i][j] == 3 or ( + j == last_covered_dst_id - 1 and np.all(matrix[i, :] == 0) and np.all(matrix[:, j] == 0) + ): + if len(targets) == 0: + for k in range(len(dsttokens) - 1, j, -1): + if np.all(matrix[:, k] == 0): + dstlist.append(dsttokens[k]) + else: + break + dstlist.append(dsttokens[j]) + last_covered_dst_id = j + for k in range(j - 1, -1, -1): + if np.all(matrix[:, k] == 0): + dstlist.append(dsttokens[k]) + last_covered_dst_id = k + else: + break + if len(dstlist) > 0: + targets.append(delimiter.join(list(reversed(dstlist)))) + else: + targets.append("") + return list(reversed(targets)) diff --git a/nemo/collections/nlp/models/__init__.py b/nemo/collections/nlp/models/__init__.py index 90e692a238a6..75b48f64df13 100644 --- a/nemo/collections/nlp/models/__init__.py +++ b/nemo/collections/nlp/models/__init__.py @@ -30,6 +30,7 @@ from nemo.collections.nlp.models.language_modeling.transformer_lm_model import TransformerLMModel from nemo.collections.nlp.models.machine_translation import MTEncDecModel from nemo.collections.nlp.models.question_answering.qa_model import QAModel +from nemo.collections.nlp.models.spellchecking_asr_customization import SpellcheckingAsrCustomizationModel from nemo.collections.nlp.models.text2sparql.text2sparql_model import Text2SparqlModel from nemo.collections.nlp.models.text_classification import TextClassificationModel from nemo.collections.nlp.models.text_normalization_as_tagging import ThutmoseTaggerModel diff --git a/nemo/collections/nlp/models/spellchecking_asr_customization/__init__.py b/nemo/collections/nlp/models/spellchecking_asr_customization/__init__.py new file mode 100644 index 000000000000..5e94de32e9aa --- /dev/null +++ b/nemo/collections/nlp/models/spellchecking_asr_customization/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.nlp.models.spellchecking_asr_customization.spellchecking_model import ( + SpellcheckingAsrCustomizationModel, +) diff --git a/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py new file mode 100644 index 000000000000..fc889de2dc63 --- /dev/null +++ b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py @@ -0,0 +1,526 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from time import perf_counter +from typing import Dict, Optional + +import torch +from omegaconf import DictConfig +from pytorch_lightning import Trainer + +from nemo.collections.common.losses import CrossEntropyLoss +from nemo.collections.nlp.data.spellchecking_asr_customization import ( + SpellcheckingAsrCustomizationDataset, + SpellcheckingAsrCustomizationTestDataset, + TarredSpellcheckingAsrCustomizationDataset, + bert_example, +) +from nemo.collections.nlp.data.text_normalization_as_tagging.utils import read_label_map +from nemo.collections.nlp.metrics.classification_report import ClassificationReport +from nemo.collections.nlp.models.nlp_model import NLPModel +from nemo.collections.nlp.modules.common.token_classifier import TokenClassifier +from nemo.collections.nlp.parts.utils_funcs import tensor2list +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.neural_types import LogitsType, NeuralType +from nemo.utils import logging +from nemo.utils.decorators import experimental + +__all__ = ["SpellcheckingAsrCustomizationModel"] + + +@experimental +class SpellcheckingAsrCustomizationModel(NLPModel): + """ + BERT-based model for Spellchecking ASR Customization. + It takes as input ASR hypothesis and candidate customization entries. + It labels the hypothesis with correct entry index or 0. + Example input: [CLS] a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o [SEP] d i d i e r _ s a u m o n [SEP] a s t r o n o m i e [SEP] t r i s t a n _ g u i l l o t [SEP] ... + Input segments: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 + Example output: 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3 0 ... 
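+
+    In this example the characters labeled 1 span the fragment "didie somon" of the ASR-hypothesis
+    (to be replaced by candidate 1, "didier saumon") and the characters labeled 3 span
+    "tristian gllo" (to be replaced by candidate 3, "tristan guillot"); candidate 2 ("astronomie")
+    is not used.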
+ """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "logits": NeuralType(('B', 'T', 'D'), LogitsType()), + } + + @property + def input_module(self): + return self + + @property + def output_module(self): + return self + + def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None: + super().__init__(cfg=cfg, trainer=trainer) + + # Label map contains 11 labels: 0 for nothing, 1..10 for target candidate ids + label_map_file = self.register_artifact("label_map", cfg.label_map, verify_src_exists=True) + + # Semiotic classes for this model consist only of classes CUSTOM(means fragment containing custom candidate) and PLAIN (any other single-character fragment) + # They are used only during validation step, to calculate accuracy for CUSTOM and PLAIN classes separately + semiotic_classes_file = self.register_artifact( + "semiotic_classes", cfg.semiotic_classes, verify_src_exists=True + ) + self.label_map = read_label_map(label_map_file) + self.semiotic_classes = read_label_map(semiotic_classes_file) + + self.num_labels = len(self.label_map) + self.num_semiotic_labels = len(self.semiotic_classes) + self.id_2_tag = {tag_id: tag for tag, tag_id in self.label_map.items()} + self.id_2_semiotic = {semiotic_id: semiotic for semiotic, semiotic_id in self.semiotic_classes.items()} + self.max_sequence_len = cfg.get('max_sequence_len', self.tokenizer.tokenizer.model_max_length) + + # Setup to track metrics + # We will have (len(self.semiotic_classes) + 1) labels. + # Last one stands for WRONG (span in which the predicted tags don't match the labels) + # This is needed to feed the sequence of classes to classification_report during validation + label_ids = self.semiotic_classes.copy() + label_ids["WRONG"] = len(self.semiotic_classes) + self.tag_classification_report = ClassificationReport( + len(self.semiotic_classes) + 1, label_ids=label_ids, mode='micro', dist_sync_on_step=True + ) + + self.hidden_size = cfg.hidden_size + + # hidden size is doubled because in forward we concatenate embeddings for characters and embeddings for subwords + self.logits = TokenClassifier( + self.hidden_size * 2, num_classes=self.num_labels, num_layers=1, log_softmax=False, dropout=0.1 + ) + + self.loss_fn = CrossEntropyLoss(logits_ndim=3) + + self.builder = bert_example.BertExampleBuilder( + self.label_map, self.semiotic_classes, self.tokenizer.tokenizer, self.max_sequence_len + ) + + @typecheck() + def forward( + self, + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + ): + """ + Same BERT-based model is used to calculate embeddings for sequence of single characters and for sequence of subwords. + Then we concatenate subword embeddings to each character corresponding to this subword. + We return logits for each character x 11 labels: 0 - character doesn't belong to any candidate, 1..10 - character belongs to candidate with this id. 
+ + # Arguments + input_ids: token_ids for single characters; .shape = [batch_size, char_seq_len]; .dtype = int64 + input_mask: mask for input_ids(1 - real, 0 - padding); .shape = [batch_size, char_seq_len]; .dtype = int64 + segment_ids: segment types for input_ids (0 - ASR-hypothesis, 1..10 - candidate); .shape = [batch_size, char_seq_len]; .dtype = int64 + input_ids_for_subwords: token_ids for subwords; .shape = [batch_size, subword_seq_len]; .dtype = int64 + input_mask_for_subwords: mask for input_ids_for_subwords(1 - real, 0 - padding); .shape = [batch_size, subword_seq_len]; .dtype = int64 + segment_ids_for_subwords: segment types for input_ids_for_subwords (0 - ASR-hypothesis, 1..10 - candidate); .shape = [batch_size, subword_seq_len]; .dtype = int64 + character_pos_to_subword_pos: tensor mapping character position in the input sequence to subword position; .shape = [batch_size, char_seq_len]; .dtype = int64 + """ + + # src_hiddens.shape = [batch_size, char_seq_len, bert_hidden_size]; .dtype=float32 + src_hiddens = self.bert_model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask) + # src_hiddens_for_subwords.shape = [batch_size, subword_seq_len, bert_hidden_size]; .dtype=float32 + src_hiddens_for_subwords = self.bert_model( + input_ids=input_ids_for_subwords, + token_type_ids=segment_ids_for_subwords, + attention_mask=input_mask_for_subwords, + ) + + # Next three commands concatenate subword embeddings to each character embedding of the corresponding subword + # index.shape = [batch_size, char_seq_len, bert_hidden_size]; .dtype=int64 + index = character_pos_to_subword_pos.unsqueeze(-1).expand((-1, -1, src_hiddens_for_subwords.shape[2])) + # src_hiddens_2.shape = [batch_size, char_seq_len, bert_hidden_size]; .dtype=float32 + src_hiddens_2 = torch.gather(src_hiddens_for_subwords, 1, index) + # src_hiddens.shape = [batch_size, char_seq_len, bert_hidden_size * 2]; .dtype=float32 + src_hiddens = torch.cat((src_hiddens, src_hiddens_2), 2) + + # logits.shape = [batch_size, char_seq_len, num_labels]; num_labels=11: ids from 0 to 10; .dtype=float32 + logits = self.logits(hidden_states=src_hiddens) + return logits + + # Training + def training_step(self, batch, batch_idx): + """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. + """ + + ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + labels_mask, + labels, + _, + ) = batch + logits = self.forward( + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + input_ids_for_subwords=input_ids_for_subwords, + input_mask_for_subwords=input_mask_for_subwords, + segment_ids_for_subwords=segment_ids_for_subwords, + character_pos_to_subword_pos=character_pos_to_subword_pos, + ) + loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask) + lr = self._optimizer.param_groups[0]['lr'] + self.log('train_loss', loss) + self.log('lr', lr, prog_bar=True) + return {'loss': loss, 'lr': lr} + + # Validation and Testing + def validation_step(self, batch, batch_idx): + """ + Lightning calls this inside the validation loop with the data from the validation dataloader + passed in as `batch`. 
+ """ + ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + labels_mask, + labels, + spans, + ) = batch + logits = self.forward( + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + input_ids_for_subwords=input_ids_for_subwords, + input_mask_for_subwords=input_mask_for_subwords, + segment_ids_for_subwords=segment_ids_for_subwords, + character_pos_to_subword_pos=character_pos_to_subword_pos, + ) + tag_preds = torch.argmax(logits, dim=2) + + # Update tag classification_report + for input_mask_seq, segment_seq, prediction_seq, label_seq, span_seq in zip( + input_mask.tolist(), segment_ids.tolist(), tag_preds.tolist(), labels.tolist(), spans.tolist() + ): + # Here we want to track whether the predicted output matches ground truth labels for each whole span. + # We construct the special input for classification report, for example: + # span_labels = [PLAIN, PLAIN, PLAIN, PLAIN, CUSTOM, CUSTOM] + # span_predictions = [PLAIN, WRONG, PLAIN, PLAIN, WRONG, CUSTOM] + # Note that the number of PLAIN and CUSTOM occurrences in the report is not comparable, + # because PLAIN is for characters, and CUSTOM is for phrases. + span_labels = [] + span_predictions = [] + plain_cid = self.semiotic_classes["PLAIN"] + wrong_cid = self.tag_classification_report.num_classes - 1 + + # First we loop through all predictions for input characters with label=0, they are regarded as separate spans with PLAIN class. + # It either stays as PLAIN if the model prediction is 0, or turns to WRONG. + for i in range(len(segment_seq)): + if input_mask_seq[i] == 0: + continue + if segment_seq[i] > 0: # token does not belong to ASR-hypothesis => it's over + break + if label_seq[i] == 0: + span_labels.append(plain_cid) + if prediction_seq[i] == 0: + span_predictions.append(plain_cid) + else: + span_predictions.append(wrong_cid) + # if label_seq[i] != 0 then it belongs to CUSTOM span and will be handled later + + # Second we loop through spans tensor which contains only spans for CUSTOM class. + # It stays as CUSTOM if all predictions for the whole span are equal to the labels, otherwise it turns to WRONG. + for cid, start, end in span_seq: + if cid == -1: + break + span_labels.append(cid) + if prediction_seq[start:end] == label_seq[start:end]: + span_predictions.append(cid) + else: + span_predictions.append(wrong_cid) + + if len(span_labels) != len(span_predictions): + raise ValueError( + "Length mismatch: len(span_labels)=" + + str(len(span_labels)) + + "; len(span_predictions)=" + + str(len(span_predictions)) + ) + self.tag_classification_report( + torch.tensor(span_predictions).to(self.device), torch.tensor(span_labels).to(self.device) + ) + + val_loss = self.loss_fn(logits=logits, labels=labels, loss_mask=labels_mask) + return {'val_loss': val_loss} + + def validation_epoch_end(self, outputs): + """ + Called at the end of validation to aggregate outputs. + :param outputs: list of individual outputs of each validation step. 
+ """ + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + + # Calculate metrics and classification report + # Note that in our task recall = accuracy, and the recall column is the per class accuracy + _, tag_accuracy, _, tag_report = self.tag_classification_report.compute() + + logging.info("Total tag accuracy: " + str(tag_accuracy)) + logging.info(tag_report) + + self.log('val_loss', avg_loss, prog_bar=True) + self.log('tag accuracy', tag_accuracy) + + self.tag_classification_report.reset() + + def test_step(self, batch, batch_idx): + """ + Lightning calls this inside the test loop with the data from the test dataloader + passed in as `batch`. + """ + return self.validation_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + """ + Called at the end of test to aggregate outputs. + :param outputs: list of individual outputs of each test step. + """ + return self.validation_epoch_end(outputs) + + # Functions for inference + + @torch.no_grad() + def infer(self, dataloader_cfg: DictConfig, input_name: str, output_name: str) -> None: + """ Main function for Inference + + Args: + dataloader_cfg: config for dataloader + input_name: Input file with tab-separated text records. Each record consists of 2 items: + - ASR hypothesis + - candidate phrases separated by semicolon + output_name: Output file with tab-separated text records. Each record consists of 2 items: + - ASR hypothesis + - candidate phrases separated by semicolon + - list of possible replacements with probabilities (start, pos, candidate_id, prob), separated by semicolon + - list of labels, predicted for each letter (for debug purposes) + + Returns: None + """ + mode = self.training + device = "cuda" if torch.cuda.is_available() else "cpu" + + try: + # Switch model to evaluation mode + self.eval() + self.to(device) + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + infer_datalayer = self._setup_infer_dataloader(dataloader_cfg, input_name) + + all_tag_preds = ( + [] + ) # list(size=number of sentences) of lists(size=number of letters) of tag predictions (best candidate_id for each letter) + all_possible_replacements = ( + [] + ) # list(size=number of sentences) of lists(size=number of potential replacements) of tuples(start, pos, candidate_id, prob) + for batch in iter(infer_datalayer): + ( + input_ids, + input_mask, + segment_ids, + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + character_pos_to_subword_pos, + fragment_indices, + ) = batch + + # tag_logits.shape = [batch_size, char_seq_len, num_labels]; num_labels=11: ids from 0 to 10; .dtype=float32 + tag_logits = self.forward( + input_ids=input_ids.to(self.device), + input_mask=input_mask.to(self.device), + segment_ids=segment_ids.to(self.device), + input_ids_for_subwords=input_ids_for_subwords.to(self.device), + input_mask_for_subwords=input_mask_for_subwords.to(self.device), + segment_ids_for_subwords=segment_ids_for_subwords.to(self.device), + character_pos_to_subword_pos=character_pos_to_subword_pos.to(self.device), + ) + + # fragment_indices.shape=[batsh_size, num_fragments, 3], where last dimension is [start, end, label], where label is candidate id from 1 to 10 + # Next we want to convert predictions for separate letters to probabilities for each whole fragment from fragment_indices. + # To achieve this we first sum the letter logits in each fragment and divide by its length. + # (We use .cumsum and then difference between end and start to get sum per fragment). 
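+                # For example (illustrative indices): once a zero row is prepended by the padding
+                # below, cumsum[k] holds the sum of tag_logits[0:k] for one example of the batch,
+                # so a fragment with start=5 and end=12 gets summed logits cumsum[12] - cumsum[5]
+                # and mean logits equal to that difference divided by 12 - 5 = 7 characters.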
+ # Then we convert logits to probs with softmax and for each fragment extract only the prob for given label. + # Finally we get a list of tuples (start, end, label, prob) + indices_len = fragment_indices.shape[1] + # this padding adds a row of zeros (size=num_labels) as first element of sequence in second dimension. This is needed for cumsum operations. + padded_logits = torch.nn.functional.pad(tag_logits, pad=(0, 0, 1, 0)) + ( + batch_size, + seq_len, + num_labels, + ) = padded_logits.shape # seq_len is +1 compared to that of tag_logits, because of padding + # cumsum.shape=[batch_size, seq_len, num_labels] + cumsum = padded_logits.cumsum(dim=1) + # the size -1 is inferred from other dimensions. We get rid of batch dimension. + cumsum_view = cumsum.view(-1, num_labels) + word_index = ( + torch.ones((batch_size, indices_len), dtype=torch.long) + * torch.arange(batch_size).reshape((-1, 1)) + * seq_len + ).view(-1) + lower_index = (fragment_indices[..., 0]).view(-1) + word_index + higher_index = (fragment_indices[..., 1]).view(-1) + word_index + d_index = (higher_index - lower_index).reshape((-1, 1)).to(self.device) # word lengths + dlog = cumsum_view[higher_index, :] - cumsum_view[lower_index, :] # sum of logits + # word_logits.shape=[batch_size, indices_len, num_labels] + word_logits = (dlog / d_index.float()).view(batch_size, indices_len, num_labels) + # convert logits to probs, same shape + word_probs = torch.nn.functional.softmax(word_logits, dim=-1).to(self.device) + # candidate_index.shape=[batch_size, indices_len] + candidate_index = fragment_indices[:, :, 2].to(self.device) + # candidate_probs.shape=[batch_size, indices_len] + candidate_probs = torch.take_along_dim(word_probs, candidate_index.unsqueeze(2), dim=-1).squeeze(2) + for i in range(batch_size): + possible_replacements = [] + for j in range(indices_len): + start, end, candidate_id = ( + int(fragment_indices[i][j][0]), + int(fragment_indices[i][j][1]), + int(fragment_indices[i][j][2]), + ) + if candidate_id == 0: # this is padding + continue + prob = round(float(candidate_probs[i][j]), 5) + if prob < 0.01: + continue + # -1 because in the output file we will not have a [CLS] token + possible_replacements.append( + str(start - 1) + " " + str(end - 1) + " " + str(candidate_id) + " " + str(prob) + ) + all_possible_replacements.append(possible_replacements) + + # torch.argmax(tag_logits, dim=-1) gives a tensor of best predicted labels with shape [batch_size, char_seq_len], .dtype = int64 + # character_preds is list of lists of predicted labels + character_preds = tensor2list(torch.argmax(tag_logits, dim=-1)) + all_tag_preds.extend(character_preds) + + if len(all_possible_replacements) != len(all_tag_preds) or len(all_possible_replacements) != len( + infer_datalayer.dataset.examples + ): + raise IndexError( + "number of sentences mismatch: len(all_possible_replacements)=" + + str(len(all_possible_replacements)) + + "; len(all_tag_preds)=" + + str(len(all_tag_preds)) + + "; len(infer_datalayer.dataset.examples)=" + + str(len(infer_datalayer.dataset.examples)) + ) + # save results to file + with open(output_name, "w", encoding="utf-8") as out: + for i in range(len(infer_datalayer.dataset.examples)): + hyp, ref = infer_datalayer.dataset.hyps_refs[i] + num_letters = hyp.count(" ") + 1 + tag_pred_str = " ".join(list(map(str, all_tag_preds[i][1 : (num_letters + 1)]))) + possible_replacements_str = ";".join(all_possible_replacements[i]) + out.write(hyp + "\t" + ref + "\t" + possible_replacements_str + "\t" + tag_pred_str + "\n") + + 
except Exception as e: + raise ValueError("Error processing file " + input_name) + + finally: + # set mode back to its original value + self.train(mode=mode) + logging.set_verbosity(logging_level) + + # Functions for processing data + def setup_training_data(self, train_data_config: Optional[DictConfig]): + if not train_data_config or not train_data_config.data_path: + logging.info( + f"Dataloader config or file_path for the train is missing, so no data loader for train is created!" + ) + self._train_dl = None + return + self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config, data_split="train") + + def setup_validation_data(self, val_data_config: Optional[DictConfig]): + if not val_data_config or not val_data_config.data_path: + logging.info( + f"Dataloader config or file_path for the validation is missing, so no data loader for validation is created!" + ) + self._validation_dl = None + return + self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config, data_split="val") + + def setup_test_data(self, test_data_config: Optional[DictConfig]): + if not test_data_config or test_data_config.data_path is None: + logging.info( + f"Dataloader config or file_path for the test is missing, so no data loader for test is created!" + ) + self._test_dl = None + return + self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config, data_split="test") + + def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str): + start_time = perf_counter() + logging.info(f'Creating {data_split} dataset') + if cfg.get("use_tarred_dataset", False): + dataset = TarredSpellcheckingAsrCustomizationDataset( + cfg.data_path, + shuffle_n=cfg.get("tar_shuffle_n", 100), + global_rank=self.global_rank, + world_size=self.world_size, + pad_token_id=self.builder._pad_id, + ) + else: + input_file = cfg.data_path + dataset = SpellcheckingAsrCustomizationDataset(input_file=input_file, example_builder=self.builder) + dl = torch.utils.data.DataLoader( + dataset=dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle, collate_fn=dataset.collate_fn + ) + running_time = perf_counter() - start_time + logging.info(f'Took {running_time} seconds') + return dl + + def _setup_infer_dataloader(self, cfg: DictConfig, input_name: str) -> 'torch.utils.data.DataLoader': + """ + Setup function for a infer data loader. + Args: + cfg: config dictionary containing data loader params like batch_size, num_workers and pin_memory + input_name: path to input file. + Returns: + A pytorch DataLoader. 
+ """ + dataset = SpellcheckingAsrCustomizationTestDataset(input_name, example_builder=self.builder) + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=cfg["batch_size"], + shuffle=False, + num_workers=cfg.get("num_workers", 0), + pin_memory=cfg.get("pin_memory", False), + drop_last=False, + collate_fn=dataset.collate_fn, + ) + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + return None diff --git a/scripts/dataset_processing/spoken_wikipedia/run.sh b/scripts/dataset_processing/spoken_wikipedia/run.sh index 2894eb1dc55e..5ae447c9a1a4 100644 --- a/scripts/dataset_processing/spoken_wikipedia/run.sh +++ b/scripts/dataset_processing/spoken_wikipedia/run.sh @@ -102,7 +102,7 @@ ${NEMO_PATH}/tools/ctc_segmentation/run_segmentation.sh \ --MODEL_NAME_OR_PATH=${MODEL_FOR_SEGMENTATION} \ --DATA_DIR=${INPUT_DIR}_prepared \ --OUTPUT_DIR=${OUTPUT_DIR} \ ---MIN_SCORE=${MIN_SCORE} +--MIN_SCORE=${THRESHOLD} # Thresholds for filtering CER_THRESHOLD=20 diff --git a/tests/collections/nlp/test_spellchecking_asr_customization.py b/tests/collections/nlp/test_spellchecking_asr_customization.py new file mode 100644 index 000000000000..8e4d6e9a7b8f --- /dev/null +++ b/tests/collections/nlp/test_spellchecking_asr_customization.py @@ -0,0 +1,1102 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
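Stepping back to the inference code above: per-character logits are averaged over each candidate fragment with a prefix-sum (cumsum) trick and then turned into probabilities with softmax. The toy sketch below re-derives that trick on random tensors as a reading aid; the shapes, names and values are illustrative only and are not taken from the model.

```python
# Toy sketch of the prefix-sum pooling used in the inference code above:
# average per-character logits over a fragment [start, end) without a Python loop.
import torch

batch_size, seq_len, num_labels = 1, 6, 3
tag_logits = torch.randn(batch_size, seq_len, num_labels)

# Prepend a zero row so that cumsum[:, k, :] holds the sum of the first k positions.
padded = torch.nn.functional.pad(tag_logits, pad=(0, 0, 1, 0))
cumsum = padded.cumsum(dim=1)

start, end = 2, 5  # fragment covering characters 2, 3, 4
fragment_mean = (cumsum[:, end, :] - cumsum[:, start, :]) / (end - start)
fragment_probs = torch.softmax(fragment_mean, dim=-1)

# Sanity check against the naive average over the same positions.
assert torch.allclose(fragment_mean, tag_logits[:, start:end, :].mean(dim=1), atol=1e-5)
print(fragment_probs)
```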
+ +import pytest +from transformers import AutoTokenizer + +from nemo.collections.nlp.data.spellchecking_asr_customization.bert_example import BertExampleBuilder +from nemo.collections.nlp.data.spellchecking_asr_customization.utils import ( + apply_replacements_to_text, + substitute_replacements_in_text, +) + + +@pytest.mark.unit +def test_substitute_replacements_in_text(): + text = "we began the further diversification of our revenue base with the protterra supply agreement and the navastar joint development agreement" + replacements = [(66, 75, 'pro-terra', 0.99986), (101, 109, 'navistar', 0.996)] + gold_text = "we began the further diversification of our revenue base with the pro-terra supply agreement and the navistar joint development agreement" + corrected_text = substitute_replacements_in_text(text, replacements, replace_hyphen_to_space=False) + assert corrected_text == gold_text + + gold_text_no_hyphen = "we began the further diversification of our revenue base with the pro terra supply agreement and the navistar joint development agreement" + corrected_text = substitute_replacements_in_text(text, replacements, replace_hyphen_to_space=True) + assert corrected_text == gold_text_no_hyphen + + +@pytest.mark.unit +def test_apply_replacements_to_text(): + + # min_prob = 0.5 + # dp_data = None, + # min_dp_score_per_symbol: float = -99.9 + + # test more than one fragment to replace, test multiple same replacements + text = "we began the further diversification of our revenue base with the protterra supply agreement and the navastar joint development agreement" + replacements = [ + (66, 75, 'proterra', 0.99986), + (66, 75, 'proterra', 0.9956), + (101, 109, 'navistar', 0.93), + (101, 109, 'navistar', 0.91), + (101, 109, 'navistar', 0.92), + ] + gold_text = "we began the further diversification of our revenue base with the proterra supply agreement and the navistar joint development agreement" + corrected_text = apply_replacements_to_text( + text, replacements, min_prob=0.5, replace_hyphen_to_space=False, dp_data=None + ) + assert corrected_text == gold_text + + # test that min_prob works + gold_text = "we began the further diversification of our revenue base with the proterra supply agreement and the navastar joint development agreement" + corrected_text = apply_replacements_to_text( + text, replacements, min_prob=0.95, replace_hyphen_to_space=False, dp_data=None + ) + assert corrected_text == gold_text + + +@pytest.fixture() +def bert_example_builder(): + tokenizer = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_6L_768D") + label_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "10": 10} + semiotic_classes = {"PLAIN": 0, "CUSTOM": 1} + max_seq_len = 256 + builder = BertExampleBuilder(label_map, semiotic_classes, tokenizer, max_seq_len) + return builder + + +@pytest.mark.skip("Doesn't work download when testing on github, for unknown reason") +@pytest.mark.with_downloads +@pytest.mark.unit +def test_creation(bert_example_builder): + assert bert_example_builder._tokenizer is not None + + +@pytest.mark.skip("Doesn't work download when testing on github, for unknown reason") +@pytest.mark.with_downloads +@pytest.mark.unit +def test_builder_get_spans(bert_example_builder): + span_info_parts = ["CUSTOM 37 41", "CUSTOM 47 52", "CUSTOM 42 46", "CUSTOM 0 7"] + gold_sorted_spans = [(1, 1, 8), (1, 38, 42), (1, 43, 47), (1, 48, 53)] + spans = bert_example_builder._get_spans(span_info_parts) + spans.sort() + assert spans == gold_sorted_spans + + 
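The `substitute_replacements_in_text` and `apply_replacements_to_text` tests above exercise the utilities purely through their inputs and outputs. As a reading aid, here is a minimal sketch (not the NeMo implementation) of how such character-offset replacements can be applied safely: process the tuples right-to-left so that earlier offsets are not shifted by later substitutions. The text and offsets are taken from the test case above.

```python
def apply_offset_replacements(text, replacements, replace_hyphen_to_space=False):
    # replacements are (start, end, new_fragment, probability) tuples with end-exclusive offsets
    corrected = text
    for start, end, fragment, _prob in sorted(replacements, key=lambda r: r[0], reverse=True):
        if replace_hyphen_to_space:
            fragment = fragment.replace("-", " ")
        corrected = corrected[:start] + fragment + corrected[end:]
    return corrected


text = (
    "we began the further diversification of our revenue base with the "
    "protterra supply agreement and the navastar joint development agreement"
)
print(apply_offset_replacements(text, [(66, 75, "pro-terra", 0.99986), (101, 109, "navistar", 0.996)]))
```

With `replace_hyphen_to_space=True` the same call yields "pro terra" instead of "pro-terra", matching the second assertion in the test above.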
+@pytest.mark.skip("Doesn't work download when testing on github, for unknown reason") +@pytest.mark.with_downloads +@pytest.mark.unit +def test_builder_get_fragment_indices(bert_example_builder): + hyp = "a b o u t _ o u r _ s h i p e r s _ b u t _ y o u _ k n o w" + targets = [1] + # a b o u t _ o u r _ s h i p e r s _ b u t _ y o u _ k n o w + # 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 + span_info_parts = ["CUSTOM 8 17"] + gold_sorted_fragment_indices = [(7, 18, 1), (11, 18, 1)] + fragment_indices = bert_example_builder._get_fragment_indices(hyp, targets, span_info_parts) + fragment_indices.sort() + assert fragment_indices == gold_sorted_fragment_indices + + # a b o u t _ o u r _ s h i p e r s _ b u t _ y o u _ k n o w + # 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 + span_info_parts = ["CUSTOM 10 16"] + gold_sorted_fragment_indices = [(11, 18, 1)] + fragment_indices = bert_example_builder._get_fragment_indices(hyp, targets, span_info_parts) + fragment_indices.sort() + assert fragment_indices == gold_sorted_fragment_indices + + +@pytest.mark.skip("Doesn't work download when testing on github, for unknown reason") +@pytest.mark.with_downloads +@pytest.mark.unit +def test_builder_get_input_features(bert_example_builder): + hyp = "a s t r o n o m e r s _ d i d i e _ s o m o n _ a n d _ t r i s t i a n _ g l l o" + ref = "d i d i e r _ s a u m o n;a s t r o n o m i e;t r i s t a n _ g u i l l o t;t r i s t e s s e;m o n a d e;c h r i s t i a n;a s t r o n o m e r;s o l o m o n;d i d i d i d i d i;m e r c y" + targets = [1, 3] + span_info_parts = ["CUSTOM 12 23", "CUSTOM 28 41"] + + gold_tags = [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + ] + gold_input_ids = [ + 101, + 1037, + 1055, + 1056, + 1054, + 1051, + 1050, + 1051, + 1049, + 1041, + 1054, + 1055, + 1035, + 1040, + 1045, + 1040, + 1045, + 1041, + 1035, + 1055, + 1051, + 1049, + 1051, + 1050, + 1035, + 1037, + 1050, + 1040, + 1035, + 1056, + 1054, + 1045, + 1055, + 1056, + 1045, + 1037, + 1050, + 1035, + 1043, + 1048, + 1048, + 1051, + 102, + 1040, + 1045, + 1040, + 1045, + 1041, + 1054, + 1035, + 1055, + 1037, + 1057, + 1049, + 1051, + 1050, + 102, + 1037, + 1055, + 1056, + 1054, + 1051, + 1050, + 1051, + 1049, + 1045, + 1041, + 102, + 1056, + 1054, + 1045, + 1055, + 1056, + 1037, + 1050, + 1035, + 1043, + 1057, + 1045, + 1048, + 1048, + 1051, + 1056, + 102, + 1056, + 1054, + 1045, + 1055, + 1056, + 1041, + 1055, + 1055, + 1041, + 102, + 1049, + 1051, + 1050, + 1037, + 1040, + 1041, + 102, + 1039, + 1044, + 1054, + 1045, + 1055, + 1056, + 1045, + 1037, + 1050, + 102, + 1037, + 1055, + 1056, + 1054, + 1051, + 1050, + 1051, + 1049, + 1041, + 1054, + 102, + 1055, + 1051, + 1048, + 1051, + 1049, + 1051, + 1050, + 102, + 1040, + 1045, + 1040, + 1045, + 1040, + 1045, + 1040, + 1045, + 1040, + 1045, + 102, + 1049, + 1041, + 1054, + 1039, + 1061, + 102, + ] + gold_input_mask = [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 
1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + gold_segment_ids = [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 10, + 10, + 10, + 10, + 10, + 10, + ] + gold_labels_mask = [ + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + gold_input_ids_for_subwords = [ + 101, + 26357, + 2106, + 2666, + 2061, + 8202, + 1998, + 13012, + 16643, + 2319, + 1043, + 7174, + 102, + 2106, + 3771, + 7842, + 2819, + 2239, + 102, + 28625, + 3630, + 9856, + 102, + 9822, + 26458, + 7174, + 2102, + 102, + 13012, + 13473, + 11393, + 102, + 13813, + 3207, + 102, + 3017, + 102, + 15211, + 102, + 9168, + 102, + 2106, + 28173, + 4305, + 4305, + 102, + 8673, + 102, + ] + gold_input_mask_for_subwords = [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ] + gold_segment_ids_for_subwords = [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 6, + 6, + 7, + 7, + 8, + 8, + 9, + 9, + 9, + 9, + 9, + 10, + 10, + ] + gold_character_pos_to_subword_pos = [ + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 3, + 3, + 3, + 4, + 4, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 7, + 7, + 7, + 8, + 8, + 8, + 9, + 9, + 9, + 10, + 11, + 11, + 11, + 12, + 13, + 13, + 13, + 14, + 14, + 14, + 14, + 15, + 15, + 16, + 16, + 17, + 17, + 18, + 19, + 19, + 19, + 19, + 19, + 20, + 20, + 21, + 21, + 21, + 22, + 23, + 23, + 23, + 23, + 23, + 23, + 23, + 23, + 24, + 24, + 24, + 25, + 25, + 25, + 26, + 27, + 28, + 28, + 28, + 29, + 29, + 29, + 30, + 30, + 30, + 31, + 32, + 32, + 32, + 32, + 33, + 33, + 34, + 35, + 35, + 35, + 35, + 35, + 35, + 35, + 35, + 35, + 36, + 37, + 37, + 37, + 37, + 37, + 37, + 37, + 37, + 37, + 37, + 38, + 39, + 39, + 39, + 39, + 39, + 39, + 39, + 
40, + 41, + 41, + 41, + 42, + 42, + 42, + 43, + 43, + 44, + 44, + 45, + 46, + 46, + 46, + 46, + 46, + 47, + ] + + tags = [0 for _ in hyp.split()] + for p, t in zip(span_info_parts, targets): + c, start, end = p.split(" ") + start = int(start) + end = int(end) + tags[start:end] = [t for i in range(end - start)] + + # get input features for characters + (input_ids, input_mask, segment_ids, labels_mask, labels, _, _,) = bert_example_builder._get_input_features( + hyp=hyp, ref=ref, tags=tags + ) + + # get input features for words + hyp_with_words = hyp.replace(" ", "").replace("_", " ") + ref_with_words = ref.replace(" ", "").replace("_", " ") + ( + input_ids_for_subwords, + input_mask_for_subwords, + segment_ids_for_subwords, + _, + _, + _, + _, + ) = bert_example_builder._get_input_features(hyp=hyp_with_words, ref=ref_with_words, tags=None) + + character_pos_to_subword_pos = bert_example_builder._map_characters_to_subwords(input_ids, input_ids_for_subwords) + + assert tags == gold_tags + assert input_ids == gold_input_ids + assert input_mask == gold_input_mask + assert segment_ids == gold_segment_ids + assert labels_mask == gold_labels_mask + assert input_ids_for_subwords == gold_input_ids_for_subwords + assert input_mask_for_subwords == gold_input_mask_for_subwords + assert segment_ids_for_subwords == gold_segment_ids_for_subwords + assert character_pos_to_subword_pos == gold_character_pos_to_subword_pos diff --git a/tools/ctc_segmentation/scripts/prepare_data.py b/tools/ctc_segmentation/scripts/prepare_data.py index 429b642d5ba0..c6ea024273fb 100644 --- a/tools/ctc_segmentation/scripts/prepare_data.py +++ b/tools/ctc_segmentation/scripts/prepare_data.py @@ -151,7 +151,7 @@ def split_text( ) # end of quoted speech - to be able to split sentences by full stop - transcript = re.sub(r"([\.\?\!])([\"\'])", r"\g<2>\g<1> ", transcript) + transcript = re.sub(r"([\.\?\!])([\"\'”])", r"\g<2>\g<1> ", transcript) # remove extra space transcript = re.sub(r" +", " ", transcript) diff --git a/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb b/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb new file mode 100644 index 000000000000..189ac958d377 --- /dev/null +++ b/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb @@ -0,0 +1,1403 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PiRuohn_FQco" + }, + "source": [ + "# Overview\n", + "This tutorial demonstrates how to run inference with SpellMapper - a model for Spellchecking ASR (Automatic Speech Recognition) Customization.\n", + "\n", + "Estimated time: 10-15 min.\n", + "\n", + "SpellMapper is a non-autoregressive (NAR) model based on transformer architecture ([BERT](https://arxiv.org/pdf/1810.04805.pdf) with multiple separators).\n", + "It gets as input a single ASR hypothesis (text) and a **custom vocabulary** and predicts which fragments in the ASR hypothesis should be replaced by which custom words/phrases if any.\n", + "\n", + "This model is an alternative to word boosting/shallow fusion approaches:\n", + " - does not require retraining ASR model;\n", + " - does not require beam-search/language model(LM);\n", + " - can be applied on top of any English ASR model output;" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qm5wmxVEGXgH" + }, + "source": [ + "## What is custom vocabulary?\n", + "**Custom vocabulary** is a list of words/phrases that are important for a particular user. For example, user's contact names, playlist, selected terminology and so on. 
The size of the custom vocabulary can vary from several hundreds to **several thousand entries** - but this is not an equivalent to ngram language model.\n", + "\n", + "![Scope of customization with user vocabulary](images/spellmapper_customization_vocabulary.png)\n", + "\n", + "Note that unlike traditional spellchecking approaches, which aim to correct known words using language models, the goal of contextual spelling correction is to correct highly specific user terms, most of which can be 1) out-of-vocabulary (OOV) words, 2) spelling variations (e.g., \"John Koehn\", \"Jon Cohen\") and language models cannot help much with that." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D5_XwuXDOKho" + }, + "source": [ + "## Tutorial Plan\n", + "\n", + "1. Create a sample custom vocabulary using some medical terminology.\n", + "2. Study what customization does - a detailed analysis of a small example.\n", + "3. Run a bigger example:\n", + " * Create sample ASR results by running TTS (text-to-speech synthesis) + ASR on some medical paper abstracts.\n", + " * Run SpellMapper inference and show how it can improve ASR results using custom vocabulary.\n", + "\n", + "TL;DR We reduce WER from `14.3%` to `11.4%` by correcting medical terms, e.g.\n", + "* `puramesin` => `puromycin`\n", + "* `parromsin` => `puromycin`\n", + "* `and hydrod` => `anhydride`\n", + "* `lesh night and` => `lesch-nyhan`\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "agz8B2CxXBBG" + }, + "source": [ + "# Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "koRPpYISNPuH" + }, + "source": [ + "## Installing NeMo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HCnnz3cgVc4Q" + }, + "outputs": [], + "source": [ + "# Install NeMo library. If you are running locally (rather than on Google Colab), comment out the below lines\n", + "# and instead follow the instructions at https://github.com/NVIDIA/NeMo#Installation\n", + "GITHUB_ACCOUNT = \"bene-ges\"\n", + "BRANCH = \"spellchecking_asr_customization_double_bert\"\n", + "!python -m pip install git+https://github.com/{GITHUB_ACCOUNT}/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]\n", + "\n", + "# Download local version of NeMo scripts. If you are running locally and want to use your own local NeMo code,\n", + "# comment out the below lines and set NEMO_DIR to your local path.\n", + "NEMO_DIR = 'nemo'\n", + "!git clone -b {BRANCH} https://github.com/{GITHUB_ACCOUNT}/NeMo.git $NEMO_DIR" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_M92gCn_NW1_" + }, + "source": [ + "## Additional installs\n", + "We will use `sentence_splitter` to split abstracts to sentences." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ddyJA3NtGl9C" + }, + "outputs": [], + "source": [ + "!pip install sentence_splitter" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qVa91rGkeFje" + }, + "source": [ + "Clone the SpellMapper model from HuggingFace.\n", + "Note that we will need not only the checkpoint itself, but also the ngram mapping vocabulary `replacement_vocab_filt.txt` from the same folder." 
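If `git-lfs` is not available in your environment, the same files can also be fetched with the `huggingface_hub` client instead of the `git clone` cell that follows. This is an optional alternative and assumes a recent `huggingface_hub` package is installed (`pip install huggingface_hub`); the local directory name is chosen to match the paths used later in this tutorial.

```python
# Optional alternative to the git clone cell below: download the SpellMapper checkpoint
# and the n-gram mapping vocabulary with the huggingface_hub client.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="bene-ges/spellmapper_asr_customization_en",
    local_dir="spellmapper_asr_customization_en",
)
print("Files downloaded to:", local_dir)
```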
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JiI9dkEm5cpW" + }, + "outputs": [], + "source": [ + "!git clone https://huggingface.co/bene-ges/spellmapper_asr_customization_en" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8saqFOePVfFf" + }, + "source": [ + "## Imports\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tAJyiYn_VnrF" + }, + "outputs": [], + "source": [ + "import IPython.display as ipd\n", + "import json\n", + "import random\n", + "import re\n", + "import soundfile as sf\n", + "import torch\n", + "\n", + "from collections import Counter, defaultdict\n", + "from difflib import SequenceMatcher\n", + "from matplotlib.pyplot import imshow\n", + "from matplotlib import pyplot as plt\n", + "from sentence_splitter import SentenceSplitter\n", + "from typing import List, Set, Tuple\n", + "\n", + "from nemo.collections.tts.models import FastPitchModel\n", + "from nemo.collections.tts.models import HifiGanModel\n", + "\n", + "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest\n", + "\n", + "from nemo.collections.nlp.data.spellchecking_asr_customization.utils import (\n", + " get_all_candidates_coverage,\n", + " get_index,\n", + " load_ngram_mappings,\n", + " search_in_index,\n", + " get_candidates,\n", + " read_spellmapper_predictions,\n", + " apply_replacements_to_text,\n", + " load_ngram_mappings_for_dp,\n", + " get_alignment_by_dp,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mfAaOdAWUGUV" + }, + "source": [ + "Use seed to get a reproducible behaviour." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UlGnNKTuT_6A" + }, + "outputs": [], + "source": [ + "random.seed(0)\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RPPHI7Zd_fDz" + }, + "source": [ + "## Download data\n", + "\n", + "File `pubmed23n0009.xml` taken from public ftp server of https://www.ncbi.nlm.nih.gov/pmc/ contains information about 5593 medical papers, from which we extract only their abstracts. We will feed sentences from there to TTS + ASR to get initial ASR results.\n", + "\n", + "File `wordlist.txt` contains 100k **single-word** medical terms.\n", + "\n", + "File `valid_adam.txt` contains 24k medical abbreviations with their full forms. We will use those full forms as examples of **multi-word** medical terms.\n", + "\n", + "File `count_1w.txt` contains 330k single words with their frequencies from Google Ngrams corpus. 
We will use this file to filter out frequent words from our custom vocabulary.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mX6cvE8xw2n1" + }, + "outputs": [], + "source": [ + "!wget https://ftp.ncbi.nlm.nih.gov/pubmed/baseline/pubmed23n0009.xml.gz\n", + "!gunzip pubmed23n0009.xml.gz\n", + "!grep \"AbstractText\" pubmed23n0009.xml > abstract.txt\n", + "\n", + "!wget https://raw.githubusercontent.com/McGill-NLP/medal/master/toy_data/valid_adam.txt\n", + "!wget https://raw.githubusercontent.com/glutanimate/wordlist-medicalterms-en/master/wordlist.txt\n", + "!wget https://norvig.com/ngrams/count_1w.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mBm9BeqNaRlC" + }, + "source": [ + "## Auxiliary functions\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kVUKhSh48Ypi" + }, + "outputs": [], + "source": [ + "CHARS_TO_IGNORE_REGEX = re.compile(r\"[\\.\\,\\?\\:!;()«»…\\]\\[/\\*–‽+&_\\\\½√>€™$•¼}{~—=“\\\"”″‟„]\")\n", + "\n", + "\n", + "def get_medical_vocabulary() -> Tuple[Set[str], Set[str]]:\n", + " \"\"\"This function builds a vocabulary of medical terms using downloaded sources:\n", + " wordlist.txt - 100k single-word medical terms.\n", + " valid_adam.txt - 24k medical abbreviations with their full forms. We use those full forms as examples of multi-word medical terms.\n", + " count_1w.txt - 330k single words with their frequencies from Google Ngrams corpus. We will use this file to filter out frequent words from our custom vocabulary.\n", + " \"\"\"\n", + " common_words = set()\n", + " with open(\"count_1w.txt\", \"r\", encoding=\"utf-8\") as f:\n", + " for line in f:\n", + " word, freq = line.strip().casefold().split(\"\\t\")\n", + " if int(freq) < 500000:\n", + " break\n", + " common_words.add(word)\n", + " print(\"Size of common words vocabulary:\", len(common_words))\n", + "\n", + " abbreviations = defaultdict(set)\n", + " medical_vocabulary = set()\n", + " with open(\"valid_adam.txt\", \"r\", encoding=\"utf-8\") as f:\n", + " lines = f.readlines()\n", + " # first line is header\n", + " for line in lines[1:]:\n", + " abbrev, _, phrase = line.strip().split(\"\\t\")\n", + " # skip phrases longer than 3 words because some of them are long explanations\n", + " if phrase.count(\" \") > 2:\n", + " continue\n", + " if phrase in common_words:\n", + " continue\n", + " medical_vocabulary.add(phrase)\n", + " abbrev = abbrev.lower()\n", + " abbreviations[abbrev].add(phrase)\n", + "\n", + " with open(\"wordlist.txt\", \"r\", encoding=\"utf-8\") as f:\n", + " for line in f:\n", + " word = line.strip().casefold()\n", + " # skip words contaning digits\n", + " if re.match(r\".*\\d.*\", word):\n", + " continue\n", + " if re.match(r\".*[\\[\\]\\(\\)\\+\\,\\.].*\", word):\n", + " continue\n", + " if word in common_words:\n", + " continue\n", + " medical_vocabulary.add(word)\n", + "\n", + " print(\"Size of medical vocabulary:\", len(medical_vocabulary))\n", + " print(\"Size of abbreviation vocabulary:\", len(abbreviations))\n", + " return medical_vocabulary, abbreviations\n", + "\n", + "\n", + "def read_abstracts(medical_vocabulary: Set[str]) -> Tuple[List[str], Set[str], Set[str]]:\n", + " \"\"\"This function reads the downloaded medical abstracts, and extracts sentences containing any word/phrase from the medical vocabulary.\n", + " Args:\n", + " medical_vocabulary: set of known medical words or phrases\n", + " Returns:\n", + " sentences: list of extracted sentences\n", + " 
all_found_singleword: set of single words from medical vocabulary that occurred at least in one sentence\n", + " all_found_multiword: set of multi-word phrases from medical vocabulary that occurred at least in one sentence\n", + " \"\"\"\n", + " splitter = SentenceSplitter(language='en')\n", + "\n", + " all_sentences = []\n", + " all_found_singleword = set()\n", + " all_found_multiword = set()\n", + " with open(\"abstract.txt\", \"r\", encoding=\"utf-8\") as f:\n", + " for line in f:\n", + " text = line.strip().replace(\"\", \"\").replace(\"\", \"\")\n", + " sents = splitter.split(text)\n", + " found_singleword = set()\n", + " found_multiword = set()\n", + " for sent in sents:\n", + " # remove anything in brackets from text\n", + " sent = re.sub(r\"\\(.+\\)\", r\"\", sent)\n", + " # remove quotes from text\n", + " sent = sent.replace(\"\\\"\", \"\")\n", + " # skip sentences contaning digits because normalization is out of scope of this tutorial\n", + " if re.match(r\".*\\d.*\", sent):\n", + " continue\n", + " # skip sentences contaning abbreviations with period inside the sentence (for the same reason)\n", + " if \". \" in sent:\n", + " continue\n", + " # skip long sentences as they may cause OOM issues\n", + " if len(sent) > 150:\n", + " continue\n", + " # replace all punctuation to space and convert to lowercase\n", + " sent_clean = CHARS_TO_IGNORE_REGEX.sub(\" \", sent).lower()\n", + " sent_clean = \" \".join(sent_clean.split(\" \"))\n", + " words = sent_clean.split(\" \")\n", + "\n", + " found_phrases = set()\n", + " for begin in range(len(words)):\n", + " for end in range(begin + 1, min(begin + 4, len(words))):\n", + " phrase = \" \".join(words[begin:end])\n", + " if phrase in medical_vocabulary:\n", + " found_phrases.add(phrase)\n", + " if end - begin == 1:\n", + " found_singleword.add(phrase)\n", + " else:\n", + " found_multiword.add(phrase)\n", + " if len(found_phrases) > 0:\n", + " all_sentences.append((sent, \";\".join(found_phrases)))\n", + " all_found_singleword = all_found_singleword.union(found_singleword)\n", + " all_found_multiword = all_found_multiword.union(found_multiword)\n", + "\n", + " print(\"Sentences:\", len(all_sentences))\n", + " print(\"Unique single-word terms found:\", len(all_found_singleword))\n", + " print(\"Unique multi-word terms found:\", len(all_found_multiword))\n", + " print(\"Examples of multi-word terms\", str(list(all_found_multiword)[0:10]))\n", + " \n", + " return all_sentences, all_found_singleword, all_found_multiword" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XU3xeCBVpWOL" + }, + "outputs": [], + "source": [ + "def get_fragments(i_words: List[str], j_words: List[str]) -> List[Tuple[str, str, str, int, int, int, int]]:\n", + " \"\"\"This function is used to compare two word sequences to find minimal fragments that differ.\n", + " Args:\n", + " i_words: list of words in first sequence\n", + " j_words: list of words in second sequence\n", + " Returns:\n", + " list of tuples (difference_type, fragment1, fragment2, begin_of_fragment1, end_of_fragment1, begin_of_fragment2, end_of_fragment2)\n", + " \"\"\"\n", + " s = SequenceMatcher(None, i_words, j_words)\n", + " result = []\n", + " for tag, i1, i2, j1, j2 in s.get_opcodes():\n", + " result.append((tag, \" \".join(i_words[i1:i2]), \" \".join(j_words[j1:j2]), i1, i2, j1, j2))\n", + " result = sorted(result, key=lambda x: x[3])\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2ydXp_pFYmYu" + }, + "source": [ + "## 
Read medical data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WAeauax0SV1-" + }, + "outputs": [], + "source": [ + "medical_vocabulary, _ = get_medical_vocabulary()\n", + "sentences, found_singleword, found_multiword = read_abstracts(medical_vocabulary)\n", + "# in case if we need random candidates from a big sample - we will use full medical vocabulary for that purpose.\n", + "big_sample = list(medical_vocabulary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FRli7-Kx7sOO" + }, + "outputs": [], + "source": [ + "for sent, phrases in sentences[0:10]:\n", + " print(sent, \"\\t\", phrases)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rL1VqH2_dk93" + }, + "source": [ + "# SpellMapper ASR Customization\n", + "\n", + "SpellMapper model relies on two offline preparation steps:\n", + "1. Collecting n-gram mappings from a large corpus (this mappings vocabulary had been collected once on a large corpus and is supplied with the model).\n", + "2. Indexing of user vocabulary by n-grams.\n", + "\n", + "![Offline data preparation](images/spellmapper_data_preparation.png)\n", + "\n", + "At inference time we take as input an ASR hypothesis and an n-gram-indexed user vocabulary and perform following steps:\n", + "1. Retrieve the top 10 candidate phrases from the user vocabulary that are likely to be contained in the given ASR-hypothesis, possibly in a misspelled form.\n", + "2. Run the neural model that tags the input characters with correct candidate labels or 0 if no match is found.\n", + "3. Do post-processing to combine results.\n", + "\n", + "![Inference pipeline](images/spellmapper_inference_pipeline.png)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OeJpsMwslmrd" + }, + "source": [ + "## N-gram mappings\n", + "Note that n-gram mappings vocabulary had been collected from a large corpus and is supplied with the model. It is supposed to be \"universal\" for English language.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uH6p0mOd12pi" + }, + "source": [ + "Let's see what n-gram mappings are like, for example, for an n-gram `l u c`.\n", + "Note that n-grams in `replacement_vocab_filt.txt` preserve one-to-one correspondence between original letters and misspelled fragments (this additional markup is handled during loading). \n", + "* `+` means that adjacent letters are concatenated and correspond to a single source letter. \n", + "* `` means that the original letter is deleted. \n", + "This auxiliary markup will be removed automatically during loading.\n", + "\n", + "`_` is used instead of real space symbol.\n", + "\n", + "Last three columns are:\n", + "* joint frequency\n", + "* frequency of original n-gram\n", + "* frequency of misspelled n-gram\n", + "\n", + "$$\\frac{JointFrequency}{SourceFrequency}=TranslationProbability$$\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qul163dB1sKp" + }, + "outputs": [], + "source": [ + "!awk 'BEGIN {FS=\"\\t\"} ($1==\"l u c\"){print $0}' < spellmapper_asr_customization_en/replacement_vocab_filt.txt | sort -t$'\\t' -k3nr" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eWxcrVWZ3Pfq" + }, + "source": [ + "Now we read n-gram mappings from the file. Parameter `max_misspelled_freq` controls maximum frequency of misspelled n-grams. N-grams more frequent than that are put in the list of banned n-grams and won't be used in indexing." 
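To make the frequency columns concrete, the toy computation below shows how a translation probability falls out of a single mapping record, and how the `max_misspelled_freq` threshold from the next cell would treat it. The record and its counts are invented for illustration and are not taken from `replacement_vocab_filt.txt`.

```python
# Invented counts for one hypothetical mapping record: original n-gram "l u c",
# misspelled fragment "l u k".
joint_freq, orig_freq, misspelled_freq = 120, 48000, 350000

# TranslationProbability = JointFrequency / SourceFrequency
translation_prob = joint_freq / orig_freq
print(f"P(misspelled n-gram | original n-gram) = {translation_prob:.5f}")

# With max_misspelled_freq=125000 (used in the next cell), this misspelled n-gram
# would be banned from indexing because its frequency exceeds the threshold.
print("banned from indexing:", misspelled_freq > 125000)
```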
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WHKhE945-N7o" + }, + "outputs": [], + "source": [ + "print(\"load n-gram mappings...\")\n", + "ngram_mapping_vocab, ban_ngram = load_ngram_mappings(\"spellmapper_asr_customization_en/replacement_vocab_filt.txt\", max_misspelled_freq=125000)\n", + "# CAUTION: entries in ban_ngram end with a space and can contain \"+\" \"=\"\n", + "print(\"Size of ngram mapping vocabulary:\", len(ngram_mapping_vocab))\n", + "print(\"Size of banned ngrams:\", len(ban_ngram))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "49IcMBfllvXN" + }, + "source": [ + "## Indexing of custom vocabulary" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b1K6paeee2Iu" + }, + "source": [ + "As we mentioned earlier, this model pipeline is intended to work with custom vocabularies up to several thousand entries. Since the whole medical vocabulary contains 110k entries, we restrict our custom vocabulary to 5000+ terms that occured in given corpus of abstracts.\n", + "\n", + "The goal of indexing our custom vocabulary is to build an index where key is a letter n-gram and value is the whole phrase. The keys are n-grams in the given user phrase and their misspelled variants taken from our collection of n-\n", + "gram mappings (see Index of custom vocabulary in Fig. 1)\n", + "\n", + "*Though it is possible to index and search the whole 110k vocabulary, it will require additional optimizations and is beyond the scope of this tutorial.*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xWb0jGqw6Woi" + }, + "outputs": [], + "source": [ + "custom_phrases = []\n", + "for phrase in medical_vocabulary:\n", + " if phrase not in found_singleword and phrase not in found_multiword:\n", + " continue\n", + " custom_phrases.append(\" \".join(list(phrase.replace(\" \", \"_\"))))\n", + "print(\"Size of customization vocabulary:\", len(custom_phrases))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UHWor5pD2Eyb" + }, + "source": [ + "Now we build the index for our custom phrases.\n", + "\n", + "Parameter `min_log_prob` controls minimum log probability, after which we stop growing this n-gram.\n", + "\n", + "Parameter `max_phrases_per_ngram` controls maximum number of phrases that can be indexed by one ngram. 
N-grams exceeding this limit are also banned and not used in indexing.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hs4RDXj0-xW9" + }, + "outputs": [], + "source": [ + "phrases, ngram2phrases = get_index(custom_phrases, ngram_mapping_vocab, ban_ngram, min_log_prob=-4.0, max_phrases_per_ngram=600)\n", + "print(\"Size of phrases:\", len(phrases))\n", + "print(\"Size of ngram2phrases:\", len(ngram2phrases))\n", + "\n", + "# Save index to file - later we will use it in other script\n", + "with open(\"index.txt\", \"w\", encoding=\"utf-8\") as out:\n", + " for ngram in ngram2phrases:\n", + " for phrase_id, begin, size, logprob in ngram2phrases[ngram]:\n", + " phrase = phrases[phrase_id]\n", + " out.write(ngram + \"\\t\" + phrase + \"\\t\" + str(begin) + \"\\t\" + str(size) + \"\\t\" + str(logprob) + \"\\n\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RV1sdQ9rvar8" + }, + "source": [ + "## Small detailed example\n", + "\n", + "Let's consider, for example, one custom phrase `thoracic aorta` and an incorrect ASR-hypothesis `the tarasic oorda is a part of the aorta located in the thorax`, containing a misspelled phrase `tarasic_oorda`. \n", + "\n", + "We will see \n", + "1. How this custom phrase is indexed.\n", + "2. How candidate retrieval works, given ASR-hypothesis.\n", + "3. How inference and post-processing work.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kGBTTJXixnrG" + }, + "source": [ + "### N-grams in index" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ryfUlqNMl4vQ" + }, + "source": [ + "Let's look, for example, by what n-grams a custom phrase `thoracic aorta` is indexed. \n", + "Columns: \n", + "1. n-gram\n", + "2. beginning position in the phrase\n", + "3. length\n", + "4. log probability\n", + "\n", + "Note that many n-grams are not from n-gram mappings file. Those are derived by growing previous n-grams with new replacements. In this case log probabilities are summed up. Growing stops, when minimum log prob is exceeded.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x0ZVsXGBo8pt" + }, + "outputs": [], + "source": [ + "for ngram in ngram2phrases:\n", + " for phrase_id, b, length, lprob in ngram2phrases[ngram]:\n", + " if phrases[phrase_id] == \"t h o r a c i c _ a o r t a\":\n", + " print(ngram.ljust(16) + \"\\t\" + str(b).rjust(4) + \"\\t\" + str(length).rjust(4) + \"\\t\" + str(lprob))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "20ov23ze4xeQ" + }, + "source": [ + "### Candidate retrieval\n", + "Candidate retrieval tasks are:\n", + " - Given an input sentence and an index of custom vocabulary find all n-grams from the index matching the sentence. \n", + " - Find which sentence fragments and which custom phrases have most \"hits\" - potential candidates.\n", + " - Find approximate starting position for each candidate phrase. \n", + "\n", + "\n", + "Let's look at the hits, that phrase \"thoracic aorta\" gets by searching all ngrams in the input text. We can see some hits in different part of the sentence, but a moving window can find a fragment with most hits." 
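As a toy illustration of that moving window (independent of the NeMo utilities used in the next cells), the sketch below slides a window of the candidate phrase's length over a made-up per-position hit vector and keeps the best-covered starting position. It uses a plain hit fraction, without the smoothing applied in the real implementation.

```python
import numpy as np

# Made-up 0/1 hits of one candidate phrase against a 20-character hypothesis.
hits = np.array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0], dtype=float)
phrase_len = 9  # length of the candidate phrase in characters

coverage = [hits[start:start + phrase_len].mean() for start in range(len(hits) - phrase_len + 1)]
best_start = int(np.argmax(coverage))
print(f"best window starts at position {best_start} with coverage {coverage[best_start]:.2f}")
```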
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t_rhKQ3Xqa8A" + }, + "outputs": [], + "source": [ + "sent = \"the_tarasic_oorda_is_a_part_of_the_aorta_located_in_the_thorax\"\n", + "phrases2positions, position2ngrams = search_in_index(ngram2phrases, phrases, sent)\n", + "print(\" \".join(list(sent)))\n", + "print(\" \".join(list(map(str, phrases2positions[phrases.index(\"t h o r a c i c _ a o r t a\")].astype(int)))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "orkRapbjF4aZ" + }, + "source": [ + "`phrases2positions` is a matrix of size (len(phrases), len(ASR_hypothesis)).\n", + "It is filled with 1.0 (hits) on intersection of letter n-grams and phrases that are indexed by these n-grams, 0.0 - elsewhere.\n", + "It is used to find phrases with many hits within a contiguous window - potential matching candidates.\n", + "\n", + "`position2ngrams` is a list of sets of ngrams. List index is the starting position in the ASR-hypothesis.\n", + "It is used later to check how well each found candidate is covered by n-grams (to avoid cases where some repeating n-gram gives many hits to a phrase, but the phrase itself is not well covered)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JF7u4_iiHLyI" + }, + "outputs": [], + "source": [ + "candidate2coverage, candidate2position = get_all_candidates_coverage(phrases, phrases2positions)\n", + "print(\"Coverage=\", candidate2coverage[phrases.index(\"t h o r a c i c _ a o r t a\")])\n", + "print(\"Starting position=\", candidate2position[phrases.index(\"t h o r a c i c _ a o r t a\")])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "45mvKg8ZyNbr" + }, + "source": [ + "`candidate2coverage` is a list of size len(phrases) containing coverage (0.0 to 1.0) in best window.\n", + "Coverage is a smoothed percentage of hits in the window of size of the given phrase.\n", + "\n", + "`candidate2position` is a list of size len(phrases) containing starting position of best window.\n", + "\n", + "Starting position is approximate, it's ok. If it is not at the beginning of some word, SpellMapper will try to adjust it later. In this particular example we get 5 as starting position instead of 4, missing the first letter." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Sjyn9I98udL9" + }, + "source": [ + "### Inference\n", + "\n", + "Now let's generate input for SpellMapper inference. \n", + "An input line should consist of 4 tab-separated columns:\n", + " - text of ASR-hypothesis\n", + " - texts of 10 candidates separated by semicolon\n", + " - 1-based ids of non-dummy candidates\n", + " - approximate start/end coordinates of non-dummy candidates (correspond to ids)\n", + "Note that candidate retrieval is done inside the function `get_candidates`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cJnusVfBRhRX" + }, + "outputs": [], + "source": [ + "out = open(\"spellmapper_input.txt\", \"w\", encoding=\"utf-8\")\n", + "letters = list(sent)\n", + "candidates = get_candidates(ngram2phrases, phrases, letters, big_sample)\n", + "# We add two columns with targets and span_info. 
\n", + "# They have same format as during training, but start and end positions are APPROXIMATE, they will be adjusted when constructing BertExample.\n", + "targets = []\n", + "span_info = []\n", + "for idx, c in enumerate(candidates):\n", + " if c[1] == -1:\n", + " continue\n", + " targets.append(str(idx + 1)) # targets are 1-based\n", + " start = c[1]\n", + " end = min(c[1] + c[2], len(letters)) # ensure that end is not outside sentence length (it can happen because c[2] is candidate length used as approximation)\n", + " span_info.append(\"CUSTOM \" + str(start) + \" \" + str(end))\n", + "\n", + "out.write(\" \".join(letters) + \"\\t\" + \";\".join([x[0] for x in candidates]) + \"\\t\" + \" \".join(targets) + \"\\t\" + \";\".join(span_info) + \"\\n\")\n", + "out.close()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Qpei5o89SmaU" + }, + "outputs": [], + "source": [ + "!cat spellmapper_input.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9rAmO15SS6go" + }, + "outputs": [], + "source": [ + "!python nemo/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py \\\n", + " pretrained_model=spellmapper_asr_customization_en/training_10m_5ep.nemo \\\n", + " model.max_sequence_len=512 \\\n", + " inference.from_file=spellmapper_input.txt \\\n", + " inference.out_file=spellmapper_output.txt \\\n", + " inference.batch_size=16 \\\n", + " lang=en\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wd2aq4T1N5cs" + }, + "source": [ + "Each line in SpellMapper output is tab-separated and consists of 4 columns:\n", + "1. ASR-hypothesis (same as in input)\n", + "2. 10 candidates separated with semicolon (same as in input)\n", + "3. fragment predictions, separated with semicolon, each prediction is a tuple (start, end, candidate_id, probability)\n", + "4. letter predictions - candidate_id predicted for each letter (this is only for debug purposes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ravgEX8cTFty" + }, + "outputs": [], + "source": [ + "!cat spellmapper_output.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "az26364-PHb2" + }, + "source": [ + "We can use some utility functions to apply found replacements and get actual corrected text." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lPtFa_EhK8pb" + }, + "outputs": [], + "source": [ + "spellmapper_results = read_spellmapper_predictions(\"spellmapper_output.txt\")\n", + "text, replacements, _ = spellmapper_results[0]\n", + "corrected_text = apply_replacements_to_text(text, replacements, replace_hyphen_to_space=False)\n", + "print(\"Text before correction:\\n\", text)\n", + "print(\"Text after correction:\\n\", corrected_text)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "efF7O-D91FLX" + }, + "source": [ + "# Bigger customization example\n", + "\n", + "Let's test customization on more data. 
The plan is\n", + " * Get baseline ASR transcriptions by running TTS + ASR on some medical paper abstracts.\n", + " * Run SpellMapper inference and show how it can improve ASR results using custom vocabulary.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r_EFPnyDcXZt" + }, + "source": [ + "## Run TTS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "i9F5SBhmr8rk" + }, + "outputs": [], + "source": [ + "# create a folder for wav files (TTS output)\n", + "!rm -r audio\n", + "!mkdir audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JMbkNVt7YBAO" + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " device = \"cuda\"\n", + "else:\n", + " device = \"cpu\"\n", + "\n", + "# Load FastPitch from HuggingFace\n", + "spectrogram_generator = FastPitchModel.from_pretrained(\"nvidia/tts_en_fastpitch\").eval().to(device)\n", + "# Load HifiGan vocoder from HuggingFace\n", + "vocoder = HifiGanModel.from_pretrained(model_name=\"nvidia/tts_hifigan\").eval().to(device)\n", + "\n", + "# Write sentences that we want to feed to TTS\n", + "with open(\"tts_input.txt\", \"w\", encoding=\"utf-8\") as out:\n", + " for sent, _ in sentences[0:100]:\n", + " out.write(sent + \"\\n\")\n", + "\n", + "out_manifest = open(\"manifest.json\", \"w\", encoding=\"utf-8\")\n", + "i = 0\n", + "with open(\"tts_input.txt\", \"r\", encoding=\"utf-8\") as inp:\n", + " for line in inp:\n", + " text = line.strip()\n", + " text_clean = CHARS_TO_IGNORE_REGEX.sub(\" \", text).lower() #replace all punctuation to space and convert to lowercase\n", + " text_clean = \" \".join(text_clean.split())\n", + "\n", + " parsed = spectrogram_generator.parse(text, normalize=True)\n", + "\n", + " spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)\n", + " audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)\n", + "\n", + " # Note that vocoder return a batch of audio. In this example, we just take the first and only sample.\n", + " filename = \"audio/\" + str(i) + \".wav\"\n", + " sf.write(filename, audio.to('cpu').detach().numpy()[0], 16000)\n", + " out_manifest.write(\n", + " \"{\\\"audio_filepath\\\": \\\"\" + filename + \"\\\", \\\"text\\\": \\\"\" + text_clean + \"\\\", \\\"orig_text\\\": \\\"\" + text + \"\\\"}\\n\"\n", + " )\n", + " i += 1\n", + "\n", + " # display some examples\n", + " if i < 10:\n", + " print(f'\"{text}\"\\n')\n", + " ipd.display(ipd.Audio(audio.to('cpu').detach(), rate=22050))\n", + "\n", + "out_manifest.close()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9T3CZcCAmxCz" + }, + "source": [ + "Now we have a folder with generated audios `audio/*.wav` and a nemo manifest with json records like `{\"audio_filepath\": \"audio/0.wav\", \"text\": \"no renal auditory or vestibular toxicity was observed\", \"orig_text\": \"No renal, auditory, or vestibular toxicity was observed.\"}`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pR_T1HnttVjm" + }, + "outputs": [], + "source": [ + "lines = []\n", + "with open(\"manifest.json\", \"r\", encoding=\"utf-8\") as f:\n", + " lines = f.readlines()\n", + "\n", + "for line in lines:\n", + " try:\n", + " data = json.loads(line.strip())\n", + " except:\n", + " print(line)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Free GPU memory to avoid OOM." 
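Before freeing GPU memory, one side note on the manifest writing above: the records are assembled by manual string concatenation, which is easy to break if `orig_text` contains quotes. A sketch of the same write using `json.dumps`, which handles quoting and escaping automatically (field names follow the manifest shown above; the file name here is illustrative):

```python
import json

# One record in the same shape as manifest.json above.
record = {
    "audio_filepath": "audio/0.wav",
    "text": "no renal auditory or vestibular toxicity was observed",
    "orig_text": "No renal, auditory, or vestibular toxicity was observed.",
}
with open("manifest_example.json", "a", encoding="utf-8") as out_manifest:
    out_manifest.write(json.dumps(record) + "\n")
```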
+ ], + "metadata": { + "id": "bt2TMLLvdUHm" + } + }, + { + "cell_type": "code", + "source": [ + "del spectrogram_generator\n", + "del vocoder\n", + "torch.cuda.empty_cache()" + ], + "metadata": { + "id": "ZwEpAOCaRH7s" + }, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HrensakWdLkt" + }, + "source": [ + "## Run baseline ASR" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IQNIo2M_mqJc" + }, + "source": [ + "Next we transcribe our .wav files with a general domain [ASR model](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_conformer_ctc_large). It will generate an output file `ctc_baseline_transcript.json` where the predicted transcriptions are stored in the field `pred_text` of each record.\n", + "\n", + "Note that this ASR model was not trained or fine-tuned on medical domain, so we expect it to make mistakes on medical terms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NMN63ux1mJiG" + }, + "outputs": [], + "source": [ + "!python nemo/examples/asr/transcribe_speech.py \\\n", + " pretrained_name=\"stt_en_conformer_ctc_large\" \\\n", + " dataset_manifest=manifest.json \\\n", + " output_filename=ctc_baseline_transcript_tmp.json \\\n", + " batch_size=2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L3swQ8uqqgnp" + }, + "source": [ + "ATTENTION: SpellMapper relies on words to be separated by _single_ space\n", + "\n", + "There is a bug with multiple space, observed in ASR results produced by Conformer-CTC, probably connected to this issue: https://github.com/NVIDIA/NeMo/issues/4034.\n", + "\n", + "So we need to correct the manifests to ensure that all spaces are single." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "z17sxkmXrXpJ" + }, + "outputs": [], + "source": [ + "test_data = read_manifest(\"ctc_baseline_transcript_tmp.json\")\n", + "\n", + "for i in range(len(test_data)):\n", + " # if there are multiple spaces in the string they will be merged to one\n", + " test_data[i][\"pred_text\"] = \" \".join(test_data[i][\"pred_text\"].split())\n", + "\n", + "with open(\"ctc_baseline_transcript.json\", \"w\", encoding=\"utf-8\") as out:\n", + " for d in test_data:\n", + " line = json.dumps(d)\n", + " out.write(line + \"\\n\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PuKtfhbVkVJY" + }, + "outputs": [], + "source": [ + "!head -n 4 ctc_baseline_transcript.json" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aCJw9NEXqRg8" + }, + "source": [ + "### Calculating WER of baseline transcript\n", + "We use the standard script from NeMo to calculate WER and CER of our baseline transcript. Internally it compares the text in `pred_text` (predicted transcript) to `text` (reference transcript). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZmNEGVWQsGo2" + }, + "outputs": [], + "source": [ + "!python nemo/examples/asr/speech_to_text_eval.py \\\n", + " dataset_manifest=ctc_baseline_transcript.json \\\n", + " only_score_manifest=True\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AvPwJr0ZqdkN" + }, + "source": [ + "### See fragments that differ\n", + "We use SequenceMatcher to see fragments that differ. 
(Another option is to use a more powerful analytics tool [Speech Data Explorer](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/tools/speech_data_explorer.html))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RAeaVCpMv78y" + }, + "outputs": [], + "source": [ + "test_data = read_manifest(\"ctc_baseline_transcript.json\")\n", + "pred_text = [data['pred_text'] for data in test_data]\n", + "ref_text = [data['text'] for data in test_data]\n", + "audio_filepath = [data['audio_filepath'] for data in test_data]\n", + "\n", + "diff_vocab = Counter()\n", + "\n", + "for i in range(len(test_data)):\n", + " ref_sent = \" \" + ref_text[i] + \" \"\n", + " pred_sent = \" \" + pred_text[i] + \" \"\n", + "\n", + " pred_words = pred_sent.strip().split()\n", + " ref_words = ref_sent.strip().split()\n", + "\n", + " for tag, hyp_fragment, ref_fragment, i1, i2, j1, j2 in get_fragments(pred_words, ref_words):\n", + " if tag != \"equal\":\n", + " diff_vocab[(tag, hyp_fragment, ref_fragment)] += 1\n", + "\n", + "sum_ = 0\n", + "print(\"PRED vs REF\")\n", + "for k, v in diff_vocab.most_common(1000000):\n", + " sum_ += v\n", + " print(k, v, \"sum=\", sum_)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dUSOF7iD1w_9" + }, + "source": [ + "## Run SpellMapper" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x39BQhYB6_Fr" + }, + "source": [ + "Now we run retrieval on our input manifest and prepare input for SpellMapper inference. Note that we use index of custom vocabulary (file `index.txt` that we saved earlier)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "y8x-yT5WqfFz" + }, + "outputs": [], + "source": [ + "!python nemo/examples/nlp/spellchecking_asr_customization/prepare_input_from_manifest.py \\\n", + " --manifest ctc_baseline_transcript.json \\\n", + " --custom_vocab_index index.txt \\\n", + " --big_sample spellmapper_asr_customization_en/big_sample.txt \\\n", + " --short2full_name short2full.txt \\\n", + " --output_name spellmapper_input.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ueq_JAPWGs_Y" + }, + "source": [ + "Run the inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zgkqiiZtJjcB" + }, + "outputs": [], + "source": [ + "!python nemo/examples/nlp/spellchecking_asr_customization/spellchecking_asr_customization_infer.py \\\n", + " pretrained_model=spellmapper_asr_customization_en/training_10m_5ep.nemo \\\n", + " model.max_sequence_len=512 \\\n", + " inference.from_file=spellmapper_input.txt \\\n", + " inference.out_file=spellmapper_output.txt \\\n", + " inference.batch_size=16 \\\n", + " lang=en\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RPQWJX8dFLfX" + }, + "source": [ + "Now we postprocess SpellMapper output and create output corrected manifest." 
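The postprocessing script below consumes `spellmapper_output.txt` in the 4-column format described earlier. For orientation, here is a minimal sketch of how one such line can be parsed into replacement tuples; the line is shortened to two candidates and is purely illustrative (it reuses the "Jon Cohen" / "John Koehn" example from the overview), not actual model output.

```python
# Illustrative, shortened output line:
# hypothesis \t candidates \t fragment predictions \t letter predictions
line = "j o n _ c o h e n\tj o h n _ k o e h n;j o a n\t0 9 1 0.99321\t1 1 1 1 1 1 1 1 1"
hyp, candidates_str, fragment_preds, _letter_preds = line.split("\t")
candidates = candidates_str.split(";")

for pred in fragment_preds.split(";"):
    start, end, candidate_id, prob = pred.split(" ")
    # candidate ids are 1-based; convert the character-spaced candidate back to words
    replacement = candidates[int(candidate_id) - 1].replace(" ", "").replace("_", " ")
    print(f"replace characters [{start}:{end}) with '{replacement}' (p={prob})")
```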
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3eFU515yKvXP" + }, + "outputs": [], + "source": [ + "!python nemo/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py \\\n", + " --input_manifest ctc_baseline_transcript.json \\\n", + " --short2full_name short2full.txt \\\n", + " --output_manifest ctc_corrected_transcript.json \\\n", + " --spellmapper_result spellmapper_output.txt \\\n", + " --replace_hyphen_to_space \\\n", + " --field_name pred_text \\\n", + " --ngram_mappings \"\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hRoIhhGh17tp" + }, + "source": [ + "### Calculating WER of corrected transcript." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qIT957bGo9AY" + }, + "outputs": [], + "source": [ + "!python nemo/examples/asr/speech_to_text_eval.py \\\n", + " dataset_manifest=ctc_corrected_transcript.json \\\n", + " only_score_manifest=True\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NYXIPusupqOQ" + }, + "outputs": [], + "source": [ + "test_data = read_manifest(\"ctc_corrected_transcript.json\")\n", + "pred_text = [data['pred_text'] for data in test_data]\n", + "ref_text = [data['pred_text_before_correction'] for data in test_data]\n", + "\n", + "diff_vocab = Counter()\n", + "\n", + "for i in range(len(test_data)):\n", + " ref_sent = \" \" + ref_text[i] + \" \"\n", + " pred_sent = \" \" + pred_text[i] + \" \"\n", + "\n", + " pred_words = pred_sent.strip().split()\n", + " ref_words = ref_sent.strip().split()\n", + "\n", + " for tag, hyp_fragment, ref_fragment, i1, i2, j1, j2 in get_fragments(pred_words, ref_words):\n", + " if tag != \"equal\":\n", + " diff_vocab[(tag, hyp_fragment, ref_fragment)] += 1\n", + "\n", + "sum_ = 0\n", + "print(\"Corrected vs baseline\")\n", + "for k, v in diff_vocab.most_common(1000000):\n", + " sum_ += v\n", + " print(k, v, \"sum=\", sum_)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DJtXlqXbTD6M" + }, + "source": [ + "### Filtering by Dynamic Programming(DP) score\n", + "\n", + "What else can be done?\n", + "Given a fragment and its potential replacement, we can apply **dynamic programming** to find the most probable \"translation\" path between them. We will use the same n-gram mapping vocabulary, because its frequencies give us \"translation probability\" of each n-gram pair. The final path score can be calculated as maximum sum of log probalities of matching n-grams along this path.\n", + "Let's look at an example. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "05Qf9wgHU_UR" + }, + "outputs": [], + "source": [ + "joint_vocab, orig_vocab, misspelled_vocab, max_len = load_ngram_mappings_for_dp(\"spellmapper_asr_customization_en/replacement_vocab_filt.txt\")\n", + "\n", + "fragment = \"and hydrod\"\n", + "replacement = \"anhydride\"\n", + "fragment_spaced = \" \".join(list(fragment.replace(\" \", \"_\")))\n", + "replacement_spaced = \" \".join(list(replacement.replace(\" \", \"_\")))\n", + "path = get_alignment_by_dp(\n", + " replacement_spaced,\n", + " fragment_spaced,\n", + " dp_data=(joint_vocab, orig_vocab, misspelled_vocab, max_len)\n", + ")\n", + "print(\"Dynamic Programming path:\")\n", + "for fragment_ngram, replacement_ngram, score, sum_score, joint_freq, orig_freq, misspelled_freq in path:\n", + " print(\n", + " \"\\t\",\n", + " \"frag=\",\n", + " fragment_ngram,\n", + " \"; repl=\",\n", + " replacement_ngram,\n", + " \"; score=\",\n", + " score,\n", + " \"; sum_score=\",\n", + " sum_score,\n", + " \"; joint_freq=\",\n", + " joint_freq,\n", + " \"; orig_freq=\",\n", + " orig_freq,\n", + " \"; misspelled_freq=\",\n", + " misspelled_freq,\n", + " )\n", + "\n", + "print(\"Final path score is in path[-1][3]: \", path[-1][3])\n", + "print(\"Dynamic programming(DP) score per symbol is final score divided by len(fragment): \", path[-1][3] / (len(fragment)))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hgfKPKckaLnc" + }, + "source": [ + "The idea is that we can skip replacements whose average DP score per symbol is below some predefined minimum, say -1.5.\n", + "Note that dynamic programming works slow because of quadratic complexity, but it allows to get rid of some false positives. Let's apply it on the same test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UhSXh7ht_JRn" + }, + "outputs": [], + "source": [ + "!python nemo/examples/nlp/spellchecking_asr_customization/postprocess_and_update_manifest.py \\\n", + " --input_manifest ctc_baseline_transcript.json \\\n", + " --short2full_name short2full.txt \\\n", + " --output_manifest ctc_corrected_transcript_dp.json \\\n", + " --spellmapper_result spellmapper_output.txt \\\n", + " --replace_hyphen_to_space \\\n", + " --field_name pred_text \\\n", + " --use_dp \\\n", + " --ngram_mappings spellmapper_asr_customization_en/replacement_vocab_filt.txt \\\n", + " --min_dp_score_per_symbol -1.5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u8R5YHB3vPC8" + }, + "outputs": [], + "source": [ + "!python nemo/examples/asr/speech_to_text_eval.py \\\n", + " dataset_manifest=ctc_corrected_transcript_dp.json \\\n", + " only_score_manifest=True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "upvTbkFAeYtR" + }, + "source": [ + "# Final notes\n", + "1. Our paper...\n", + "\n", + "2. To reproduce evaluation experiments from this paper see these scripts:\n", + " - [test_on_kensho.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_kensho.sh)\n", + " - [test_on_userlibri.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_kensho.sh)\n", + " - [test_on_spoken_wikipedia.sh](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/evaluation/test_on_kensho.sh)\n", + "\n", + "3. 
To reproduce training see [README.md](https://github.com/bene-ges/nemo_compatible/blob/main/scripts/nlp/en_spellmapper/README.md)\n", + "\n", + "4. Promising future research directions would be:\n", + " - add a simple trainable classifier on top of SpellMapper predictions instead of using multiple thresholds\n", + " - retrain with adding more various false positives to the training data" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "toc_visible": true, + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/tutorials/nlp/images/spellmapper_customization_vocabulary.png b/tutorials/nlp/images/spellmapper_customization_vocabulary.png new file mode 100644 index 000000000000..1ecd7ab5add5 Binary files /dev/null and b/tutorials/nlp/images/spellmapper_customization_vocabulary.png differ diff --git a/tutorials/nlp/images/spellmapper_data_preparation.png b/tutorials/nlp/images/spellmapper_data_preparation.png new file mode 100644 index 000000000000..24df8a8e0525 Binary files /dev/null and b/tutorials/nlp/images/spellmapper_data_preparation.png differ diff --git a/tutorials/nlp/images/spellmapper_inference_pipeline.png b/tutorials/nlp/images/spellmapper_inference_pipeline.png new file mode 100644 index 000000000000..07d85d2e2295 Binary files /dev/null and b/tutorials/nlp/images/spellmapper_inference_pipeline.png differ