Canary: inference tokenization improvements; preserving custom keys when creating tarred manifests (#8432)

* Improvements for Canary:

- carry over custom keys when creating tarred manifests
- selectable text field in ASR eval
- get rid of prompt slicing, create proper inference prompts

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* set ensure_ascii=False in tarred conversion to avoid breaking tokenizers trained on UTF-8 encoding

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
pzelasko authored Feb 16, 2024
1 parent 87589a7 commit 5c1b8d1
Showing 6 changed files with 72 additions and 68 deletions.
16 changes: 8 additions & 8 deletions examples/asr/speech_to_text_eval.py
@@ -25,13 +25,13 @@
for full list of arguments >>
dataset_manifest: Required - path to dataset JSON manifest file (in NeMo format)
output_filename: Optional - output filename where the transcriptions will be written. (if scores_per_sample=True,
metrics per sample will be written there too)
use_cer: Bool, whether to compute CER or WER
use_punct_er: Bool, compute dataset Punctuation Error Rate (set the punctuation marks for metrics computation with
"text_processing.punctuation_marks")
tolerance: Float, minimum WER/CER required to pass some arbitrary tolerance.
only_score_manifest: Bool, when set will skip audio transcription and just calculate WER of provided manifest.
@@ -141,13 +141,13 @@ def main(cfg: EvaluationConfig):
for line in f:
data = json.loads(line)

if 'pred_text' not in data:
if "pred_text" not in data:
invalid_manifest = True
break

ground_truth_text.append(data['text'])
ground_truth_text.append(data[cfg.gt_text_attr_name])

predicted_text.append(data['pred_text'])
predicted_text.append(data["pred_text"])

pc = PunctuationCapitalization(cfg.text_processing.punctuation_marks)
if cfg.text_processing.separate_punctuation:
@@ -183,7 +183,7 @@ def main(cfg: EvaluationConfig):

samples_with_metrics = compute_metrics_per_sample(
manifest_path=cfg.dataset_manifest,
reference_field="text",
reference_field=cfg.gt_text_attr_name,
hypothesis_field="pred_text",
metrics=metrics_to_compute,
punctuation_marks=cfg.text_processing.punctuation_marks,
@@ -207,7 +207,7 @@ def main(cfg: EvaluationConfig):

logging.info(f'Got {metric_name} of {metric_value}. Tolerance was {cfg.tolerance}')

logging.info(f'Dataset WER/CER ' + str(round(100 * wer, 2)) + "%/" + str(round(100 * cer, 2)) + "%")
logging.info(f"Dataset WER/CER {wer:.2%}/{cer:.2%}")

if cfg.use_punct_er:
dper_obj.print()
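With these changes the reference text is read from a configurable manifest key (`cfg.gt_text_attr_name`) rather than the hard-coded `"text"`, and the summary WER/CER is logged via an f-string. A minimal sketch of what the selectable field means when scoring one manifest line; the `"answer"` key and the default of `"text"` are assumptions for illustration, not taken from the diff:

```python
import json

# Hypothetical manifest line that stores the reference under a custom key.
line = '{"audio_filepath": "utt1.wav", "answer": "hello world", "pred_text": "hello word"}'
data = json.loads(line)

gt_text_attr_name = "answer"   # assumed config override; the stock key would be "text"
ground_truth = data[gt_text_attr_name]
prediction = data["pred_text"]

# The new logging style from the diff, applied to a toy WER/CER pair.
wer, cer = 0.5, 0.1
print(f"Dataset WER/CER {wer:.2%}/{cer:.2%}")   # Dataset WER/CER 50.00%/10.00%
```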
48 changes: 34 additions & 14 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -41,18 +41,22 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
"""

def __init__(
self, tokenizer: TokenizerSpec, prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]
self,
tokenizer: TokenizerSpec,
prompt_format_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]],
inference: bool = False,
):
super().__init__()
self.tokenizer = TokenizerWrapper(tokenizer)
self.load_audio = AudioSamples(fault_tolerant=True)
self.padding_value = self.tokenizer._tokenizer.pad_id
self.prompt_format_fn = prompt_format_fn
self.inference = inference

def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
audio, audio_lens, cuts = self.load_audio(cuts)

tokens = self.prompt_format_fn(cuts, self.tokenizer)
tokens = self.prompt_format_fn(cuts, self.tokenizer, self.inference)
tokens = [torch.as_tensor(t) for t in tokens]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=self.padding_value)
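A brief usage sketch of the new `inference` flag (a sketch, not part of the diff: it assumes a `tokenizer` built elsewhere, e.g. taken from a trained Canary model, and mirrors how the multitask model wires the dataset up):

```python
from nemo.collections.asr.data.audio_to_text_lhotse_prompted import (
    PromptedAudioToTextLhotseDataset,
    get_prompt_format_fn,
)

# `tokenizer` is assumed to be a TokenizerSpec/CanaryTokenizer instance built elsewhere.
dataset = PromptedAudioToTextLhotseDataset(
    tokenizer=tokenizer,
    prompt_format_fn=get_prompt_format_fn("canary"),
    inference=True,   # prompts will omit the target text and <eos>
)
```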
@@ -64,7 +68,7 @@ def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.T
PROMPT_FORMAT_FNS = {}


def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]):
def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]):
"""
Decorator for registering prompt functions under a name.
@@ -82,7 +86,7 @@ def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper],
return prompt_fn


def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]:
def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]:
if name not in PROMPT_FORMAT_FNS:
raise ValueError(
f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}"
@@ -91,7 +95,7 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequ


@registered_prompt_format_fn
def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]:
def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -> Sequence[Sequence[int]]:
"""
Prepend and append control tokens to the token sequence as per Canary format.
@@ -135,8 +139,11 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]
)

# Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together.
texts = [sup.text for sup in cut.supervisions]
langs = [sup.language for sup in cut.supervisions]
if not inference:
texts = [sup.text for sup in cut.supervisions]
langs = [sup.language for sup in cut.supervisions]
else:
texts, langs = None, None
taskname = cut.custom['taskname']
pnc = cut.custom['pnc']
source_lang = cut.custom['source_lang']
@@ -149,18 +156,29 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> Sequence[Sequence[int]]
return canary_tokens


def canary_prompt(tokenizer: CanaryTokenizer, text, language, source_language, target_language, taskname, pnc):
def canary_prompt(
tokenizer: CanaryTokenizer,
text: str | list[str] | None,
language: str | list[str] | None,
source_language: str,
target_language: str,
taskname: str,
pnc: str,
) -> list[int]:
if isinstance(text, str):
text = [text]
if isinstance(language, str):
language = [language]

tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
if text is not None:
tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
else:
tokens = None # create prompt for inference

# bos
prompted_tokens = [tokenizer.bos_id]

if len(tokens) == 0:
if tokens is not None and len(tokens) == 0:
# no speech token
prompted_tokens.append(tokenizer.nospeech_id)
else:
@@ -201,9 +219,11 @@ def canary_prompt(tokenizer: CanaryTokenizer, text, language, source_language, t
else:
raise ValueError(f"Unknown value for key 'pnc': {pnc}")

# text
prompted_tokens.extend(tokens)
# text (only in training)
if tokens is not None:
prompted_tokens.extend(tokens)

# eos
prompted_tokens.append(tokenizer.eos_id)
# eos (only in training)
if tokens is not None:
prompted_tokens.append(tokenizer.eos_id)
return prompted_tokens
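The net effect on the Canary prompt, as a self-contained sketch (the helper and token IDs below are illustrative, not the real CanaryTokenizer API): in training the prompt carries the target text and <eos>; at inference it stops after the control tokens, so the decoder generates the text itself.

```python
def build_canary_prompt(bos_id, control_ids, text_ids, eos_id, inference):
    # <bos> + task/source-lang/target-lang/pnc control tokens, as in canary_prompt()
    prompt = [bos_id, *control_ids]
    if not inference:
        # training / scoring: append the tokenized target text and <eos>
        prompt += list(text_ids) + [eos_id]
    return prompt

# Hypothetical IDs: 1=<bos>, 2=<eos>, 10..13=control tokens, 42/43=text tokens.
train_prompt = build_canary_prompt(1, [10, 11, 12, 13], [42, 43], 2, inference=False)
infer_prompt = build_canary_prompt(1, [10, 11, 12, 13], [], 2, inference=True)
print(train_prompt)  # [1, 10, 11, 12, 13, 42, 43, 2]
print(infer_prompt)  # [1, 10, 11, 12, 13]
```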
33 changes: 11 additions & 22 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -202,8 +202,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
tokenizer=self.tokenizer,
)

self.context_len_for_AR_decoding = self.cfg.get("context_len_for_AR_decoding", 5)

# Define autoregressive CE loss
with open_dict(self.cfg.loss):
self.cfg.loss.pad_id = self.tokenizer.pad_id
@@ -442,7 +440,7 @@ def transcribe(

return super().transcribe(audio=audio, override_config=trcfg)

def _setup_dataloader_from_config(self, config: Optional[Dict]):
def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool = False):
assert config.get("use_lhotse", False), (
"Multi-task model only supports dataloading with Lhotse. "
"Please set config.{train,validation,test}_ds.use_lhotse=True"
@@ -452,7 +450,9 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
global_rank=self.global_rank,
world_size=self.world_size,
dataset=PromptedAudioToTextLhotseDataset(
tokenizer=self.tokenizer, prompt_format_fn=get_prompt_format_fn(self.prompt_format),
tokenizer=self.tokenizer,
prompt_format_fn=get_prompt_format_fn(self.prompt_format),
inference=inference,
),
)

@@ -496,7 +496,7 @@ def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict

# preserve config
self._update_dataset_config(dataset_name='validation', config=val_data_config)
self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)
self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, inference=True)

def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
"""
@@ -512,7 +512,7 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):

# preserve config
self._update_dataset_config(dataset_name='test', config=test_data_config)
self._test_dl = self._setup_dataloader_from_config(config=test_data_config)
self._test_dl = self._setup_dataloader_from_config(config=test_data_config, inference=True)

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
@@ -660,9 +660,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):
beam_hypotheses = self.decoding.decode_predictions_tensor(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=input_ids[:, : self.context_len_for_AR_decoding]
if self.context_len_for_AR_decoding > 0
else None,
decoder_input_ids=input_ids,
return_hypotheses=False,
)[0]

@@ -839,14 +837,7 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=batch[0], input_signal_length=batch[1]
)

decoder_input_ids = (
batch[2][:, : self.context_len_for_AR_decoding].to(trcfg._internal.device)
if self.context_len_for_AR_decoding > 0
else None
)
# decoder_input_ids = None

decoder_input_ids = batch[2].to(trcfg._internal.device)
output = dict(
log_probs=log_probs,
encoded_lengths=encoded_len,
@@ -881,7 +872,7 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
best_hypotheses, all_hypotheses = self.decoding.decode_predictions_tensor(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=decoder_input_ids if self.context_len_for_AR_decoding > 0 else None,
decoder_input_ids=decoder_input_ids,
return_hypotheses=trcfg.return_hypotheses,
)

@@ -933,7 +924,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'lang_field': 'target_lang',
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
return temporary_datalayer

def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig):
@@ -1022,9 +1013,7 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa
text = self.decoding.decode_predictions_tensor(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=transcript[:, : self.context_len_for_AR_decoding]
if self.context_len_for_AR_decoding > 0
else None,
decoder_input_ids=transcript,
return_hypotheses=False,
)[0]

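Since the validation, test, and transcribe dataloaders are now built with `inference=True`, the model no longer truncates the batched prompt to `context_len_for_AR_decoding` before decoding. A toy before/after comparison with made-up token IDs:

```python
import torch

# Previously, eval batches carried full training-style prompts (controls + text + <eos>),
# so the model sliced off everything past a fixed context length before decoding:
training_style_prompt = torch.tensor([[1, 10, 11, 12, 13, 42, 43, 2]])  # made-up IDs
context_len_for_AR_decoding = 5
old_decoder_input_ids = training_style_prompt[:, :context_len_for_AR_decoding]

# Now the dataloaders are built with inference=True, so the batch already contains
# a pure inference prompt and is passed to decode_predictions_tensor() unchanged:
inference_prompt = torch.tensor([[1, 10, 11, 12, 13]])
new_decoder_input_ids = inference_prompt
```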
4 changes: 2 additions & 2 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
@@ -90,7 +90,7 @@ def __iter__(self) -> Generator[Cut, None, None]:
recording_id=cut.recording_id,
start=0,
duration=cut.duration,
text=data[self.text_field],
text=data.get(self.text_field),
language=data.get(self.lang_field),
)
)
@@ -257,7 +257,7 @@ def __iter__(self) -> Generator[Cut, None, None]:
recording_id=cut.recording_id,
start=0,
duration=cut.duration,
text=data[self.text_field],
text=data.get(self.text_field),
language=data.get(self.lang_field),
)
)
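Switching `data[self.text_field]` to `data.get(self.text_field)` lets the Lhotse adapters iterate over manifests that carry no transcript at all, which is exactly the inference case. A minimal illustration (the manifest entry is hypothetical):

```python
# An inference-time manifest entry with no ground-truth text.
data = {"audio_filepath": "utt1.wav", "duration": 3.2}
text_field = "text"

text = data.get(text_field)   # None instead of raising KeyError
assert text is None           # matches the text=None path in canary_prompt()
```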
26 changes: 7 additions & 19 deletions scripts/speech_recognition/convert_to_tarred_audio_dataset.py
@@ -364,7 +364,7 @@ def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/",
new_manifest_shard_path = os.path.join(sharded_manifests_dir, f'manifest_{shard_id}.json')
with open(new_manifest_shard_path, 'w', encoding='utf-8') as m2:
for entry in manifest:
json.dump(entry, m2)
json.dump(entry, m2, ensure_ascii=False)
m2.write('\n')

# Flatten the list of list of entries to a list of entries
@@ -377,7 +377,7 @@ def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/",
new_manifest_path = os.path.join(target_dir, 'tarred_audio_manifest.json')
with open(new_manifest_path, 'w', encoding='utf-8') as m2:
for entry in new_entries:
json.dump(entry, m2)
json.dump(entry, m2, ensure_ascii=False)
m2.write('\n')

# Write metadata (default metadata for new datasets)
@@ -555,7 +555,7 @@ def create_concatenated_dataset(
new_manifest_shard_path = os.path.join(sharded_manifests_dir, f'manifest_{shard_id}.json')
with open(new_manifest_shard_path, 'w', encoding='utf-8') as m2:
for entry in manifest:
json.dump(entry, m2)
json.dump(entry, m2, ensure_ascii=False)
m2.write('\n')

# Flatten the list of list of entries to a list of entries
@@ -574,12 +574,12 @@ def create_concatenated_dataset(
with open(new_manifest_path, 'w', encoding='utf-8') as m2:
# First write all the entries of base manifest
for entry in base_entries:
json.dump(entry, m2)
json.dump(entry, m2, ensure_ascii=False)
m2.write('\n')

# Finally write the new entries
for entry in new_entries:
json.dump(entry, m2)
json.dump(entry, m2, ensure_ascii=False)
m2.write('\n')
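What `ensure_ascii=False` changes: non-ASCII characters are written as raw UTF-8 instead of `\uXXXX` escapes, so tokenizers trained on UTF-8 text see the same strings in tarred and original manifests. A small demonstration with a made-up entry:

```python
import json

entry = {"audio_filepath": "utt1.wav", "text": "żółć"}

print(json.dumps(entry))                      # ... "text": "\u017c\u00f3\u0142\u0107"}
print(json.dumps(entry, ensure_ascii=False))  # ... "text": "żółć"}
```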

# Preserve historical metadata
Expand Down Expand Up @@ -679,24 +679,12 @@ def _create_shard(self, entries, target_dir, shard_id, manifest_folder):
to_write = base + "-sub" + str(count[squashed_filename]) + ext
count[squashed_filename] += 1

# Carry over every key in the entry, override audio_filepath and shard_id
new_entry = {
**entry,
'audio_filepath': to_write,
'duration': entry['duration'],
'shard_id': shard_id, # Keep shard ID for recordkeeping
}

if 'label' in entry:
new_entry['label'] = entry['label']

if 'text' in entry:
new_entry['text'] = entry['text']

if 'offset' in entry:
new_entry['offset'] = entry['offset']

if 'lang' in entry:
new_entry['lang'] = entry['lang']

new_entries.append(new_entry)

tar.close()
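The shard-writing loop now copies every key from the source entry via dict unpacking and only overrides what it must, instead of whitelisting `label`/`text`/`offset`/`lang`. A sketch of the pattern with an illustrative entry (the `taskname`/`source_lang` custom keys are the kind of Canary metadata the old whitelist dropped):

```python
entry = {
    "audio_filepath": "/data/utt1.wav",
    "duration": 3.2,
    "text": "hello",
    "taskname": "asr",        # custom key that is now carried over
    "source_lang": "en",      # custom key that is now carried over
}

new_entry = {
    **entry,                       # carry over every key, including custom ones
    "audio_filepath": "utt1.wav",  # overridden to the name inside the tar shard
    "shard_id": 0,                 # kept for recordkeeping
}
```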