Prompt formatter API and canary transcribe tensor input support #9206

Merged: 22 commits merged on Jun 1, 2024
Changes from 18 commits
Commits (22)
bfbacdc
Apply CanaryPromptFormatter in dataset/inference
pzelasko May 15, 2024
967776b
Working inference with CanaryPromptFormatter
pzelasko May 15, 2024
04fdba9
Minimum working example of Canary.transcribe() with tensors
pzelasko May 15, 2024
1f902ff
training fix
pzelasko May 15, 2024
e86362e
Update to the new 'chat' based prompt formatting API
pzelasko May 21, 2024
41f96f1
Prompt formatters for popular models and partial unit test coverage
pzelasko May 21, 2024
06ff96d
Updated documentation
pzelasko May 21, 2024
71b9191
Improved test coverage + proper preamble support
pzelasko May 22, 2024
5555a4c
Fix usage of PromptFormatter for MT-AED class + fix tokenization/form…
pzelasko May 22, 2024
2350356
Move some canary hacks to canary prompt formatter, improve validation…
pzelasko May 22, 2024
30713b8
aed_model.transcribe(**slots) support, rename all slots to lowercase …
pzelasko May 23, 2024
9334a88
truly generic version
pzelasko May 23, 2024
2f7cd7a
making transcribe_speech.py work prompt slots + syntactic sugar
pzelasko May 23, 2024
3a533ae
update streaming_utils.py
pzelasko May 23, 2024
d6f75f0
Merge branch 'main' into prompt-formatter-and-canary-tensor-dataset
pzelasko May 23, 2024
61f92d8
fix
pzelasko May 23, 2024
9fe28cb
code review: partial
pzelasko May 24, 2024
3f60244
Accept multi-turn, single-turn, and legacy prompt format in transcrib…
pzelasko May 29, 2024
9e13c2e
Address code reviews
pzelasko May 31, 2024
3f9453b
Add support for SPE special tokens bos/eos in prompt templates and en…
pzelasko May 31, 2024
55ac422
Fix tests and add llama2 prompt formatter tests
pzelasko May 31, 2024
43ec9ad
Fix tests
pzelasko May 31, 2024
13 changes: 12 additions & 1 deletion examples/asr/transcribe_speech.py
@@ -16,7 +16,7 @@
import glob
import json
import os
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass, field, is_dataclass
from tempfile import NamedTemporaryFile
from typing import List, Optional, Union

@@ -25,6 +25,7 @@
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel
from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
@@ -169,6 +170,14 @@ class TranscriptionConfig:

# Decoding strategy for AED models
multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig()
# Prompt slots for prompted models, e.g. Canary-1B. Examples of acceptable prompt inputs:
# Implicit single-turn assuming default role='user' (works with Canary-1B)
# +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes
# Explicit single-turn prompt:
# +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes
# Explicit multi-turn prompt:
# +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]'
prompt: dict = field(default_factory=dict)

# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
decoder_type: Optional[str] = None
@@ -411,6 +420,8 @@ def autocast(dtype=None):
override_cfg.augmentor = augmentor
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name
if hasattr(override_cfg, "prompt"):
override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt))
transcriptions = asr_model.transcribe(
audio=filepaths,
override_config=override_cfg,
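The config comment above enumerates three accepted prompt shapes. As a quick reference, here is a small sketch (not part of the PR) of the same three shapes written as plain Python dicts; slot values are example data, and the normalized output of parse_multitask_prompt is assumed, not verified, to be a list of role/slots turns.

# Illustrative only: the three prompt shapes accepted by the new `prompt` config field.
implicit_single_turn = {"source_lang": "en", "target_lang": "es", "task": "asr", "pnc": "yes"}

explicit_single_turn = {
    "role": "user",
    "slots": {"source_lang": "en", "target_lang": "es", "task": "s2t_translation", "pnc": "yes"},
}

explicit_multi_turn = {
    "turns": [
        {"role": "user", "slots": {"source_lang": "en", "target_lang": "es", "task": "asr", "pnc": "yes"}},
    ]
}

# In transcribe_speech.py, whichever shape is provided is converted and normalized via:
#   override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt))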
45 changes: 24 additions & 21 deletions nemo/collections/asr/data/audio_to_text.py
@@ -75,7 +75,9 @@ def _speech_collate_fn(batch, pad_id):
has_audio = audio_lengths[0] is not None
if has_audio:
max_audio_len = max(audio_lengths).item()
max_tokens_len = max(tokens_lengths).item()
has_tokens = tokens_lengths[0] is not None
if has_tokens:
max_tokens_len = max(tokens_lengths).item()

audio_signal, tokens = [], []
for b in batch:
@@ -89,19 +91,24 @@
pad = (0, max_audio_len - sig_len)
sig = torch.nn.functional.pad(sig, pad)
audio_signal.append(sig)
tokens_i_len = tokens_i_len.item()
if tokens_i_len < max_tokens_len:
pad = (0, max_tokens_len - tokens_i_len)
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
tokens.append(tokens_i)
if has_tokens:
tokens_i_len = tokens_i_len.item()
if tokens_i_len < max_tokens_len:
pad = (0, max_tokens_len - tokens_i_len)
tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
tokens.append(tokens_i)

if has_audio:
audio_signal = torch.stack(audio_signal)
audio_lengths = torch.stack(audio_lengths)
else:
audio_signal, audio_lengths = None, None
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
if has_tokens:
tokens = torch.stack(tokens)
tokens_lengths = torch.stack(tokens_lengths)
else:
tokens = None
tokens_lengths = None
if sample_ids is None:
return audio_signal, audio_lengths, tokens, tokens_lengths
else:
@@ -256,8 +263,7 @@ def cache_datastore_manifests(
if num_datastore_manifests > 0:
# Local utility function
def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
"""Cache manifests and audio data from object store.
"""
"""Cache manifests and audio data from object store."""
# Determine the number of workers to use
if num_workers is None:
num_workers = os.cpu_count() - 1
@@ -421,8 +427,7 @@ class _AudioTextDataset(Dataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -546,8 +551,7 @@ class AudioToCharDataset(_AudioTextDataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -640,8 +644,7 @@ class AudioToBPEDataset(_AudioTextDataset):

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -910,8 +913,7 @@ def __next__(self):
return TarredAudioFilter(self.manifest_processor.collection)

def _loop_offsets(self, iterator):
"""This function is used to iterate through utterances with different offsets for each file.
"""
"""This function is used to iterate through utterances with different offsets for each file."""

class TarredAudioLoopOffsets:
def __init__(self, collection):
@@ -944,8 +946,7 @@ def _collate_fn(self, batch):
return _speech_collate_fn(batch, self.pad_id)

def _build_sample(self, tup):
"""Builds the training sample by combining the data from the WebDataset with the manifest info.
"""
"""Builds the training sample by combining the data from the WebDataset with the manifest info."""
audio_bytes, audio_filename, offset_id = tup

# Grab manifest entry from self.manifest_preprocessor.collection
@@ -1316,7 +1317,9 @@ class BucketingDataset(IterableDataset):
"""

def __init__(
self, dataset: IterableDataset, bucketing_batch_size: int,
self,
dataset: IterableDataset,
bucketing_batch_size: int,
):
self.wrapped_dataset = dataset
self.bucketing_batch_size = bucketing_batch_size
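The key behavioral change in _speech_collate_fn above is that token collation is now optional, so inference-only batches that carry no transcripts no longer fail. A minimal self-contained sketch of that idea, with illustrative names that are not part of the PR:

import torch

def collate_optional_tokens(tokens, token_lens, pad_id=0):
    # Mirrors the new has_tokens branch: if the batch carries no text,
    # return None instead of trying to pad and stack.
    if token_lens[0] is None:
        return None, None
    max_len = max(token_lens).item()
    padded = [
        torch.nn.functional.pad(t, (0, max_len - l.item()), value=pad_id)
        for t, l in zip(tokens, token_lens)
    ]
    return torch.stack(padded), torch.stack(token_lens)

# Example: a two-sample batch with transcripts of length 3 and 2.
tokens, lens = collate_optional_tokens(
    [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
    [torch.tensor(3), torch.tensor(2)],
)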
153 changes: 45 additions & 108 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -13,15 +13,16 @@
# limitations under the License.
from typing import Callable, Sequence

import omegaconf
import torch.utils.data
from lhotse import CutSet
from lhotse.cut import MixedCut, MonoCut
from lhotse.dataset import AudioSamples
from lhotse.dataset.collation import collate_vectors

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
@@ -57,21 +58,21 @@ def __init__(
def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
audio, audio_lens, cuts = self.load_audio(cuts)

tokens, prompt_tokens = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)
prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)

tokens = [torch.as_tensor(t) for t in tokens]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=self.padding_value)
prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers]
prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)
prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)

if self.inference:
prompt_tokens = [torch.as_tensor(t) for t in prompt_tokens]
prompt_token_lens = torch.tensor([t.size(0) for t in prompt_tokens], dtype=torch.long)
prompt_tokens = collate_vectors(prompt_tokens, padding_value=self.padding_value)
prompts = [torch.as_tensor(t) for t in prompts]
prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
prompts = collate_vectors(prompts, padding_value=self.padding_value)
else:
prompt_tokens = None
prompt_token_lens = None
prompts = None
prompts_lens = None

return audio, audio_lens, tokens, token_lens, prompt_tokens, prompt_token_lens
return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens


# Mapping from a string name to a known prompt formatter function.
@@ -105,7 +106,9 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool]]


@registered_prompt_format_fn
def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -> Sequence[Sequence[int]]:
def canary(
cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Prepend and append control tokens to the token sequence as per Canary format.

@@ -132,116 +135,50 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False) -
assert isinstance(
tokenizer._tokenizer, CanaryTokenizer
), "To use 'canary' prompt format, you must use the CanaryTokenizer."
tokenizer = tokenizer._tokenizer
formatter = CanaryPromptFormatter(tokenizer._tokenizer)

tokens, prompts = [], []
prompts_with_answers, prompts = [], []
for cut in cuts:
if isinstance(cut, MixedCut):
cut = cut._first_non_padding_cut
assert isinstance(cut, MonoCut), "Expected MonoCut."
stevehuang52 (Collaborator) commented on May 29, 2024:
Better change to raising TypeError and saying something like "expected input audio to have single channel", since users might not know what "MonoCut" means.

A collaborator replied: +1
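A possible shape of the suggested change, sketched as a hypothetical helper (_ensure_mono is not a name used in this PR, and the exact message wording is up to the author):

from lhotse.cut import MixedCut, MonoCut

def _ensure_mono(cut):
    # Unwrap padding mixtures first, then require a single-channel (mono) cut.
    if isinstance(cut, MixedCut):
        cut = cut._first_non_padding_cut
    if not isinstance(cut, MonoCut):
        raise TypeError(f"Expected input audio to have a single channel, got {type(cut)}")
    return cut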


# first, validate the utterance
missing_keys = [k for k in ("source_lang", "target_lang", "taskname", "pnc") if k not in cut.custom]
expected_slots = set(formatter.get_slots("user"))
missing_keys = expected_slots - set(cut.custom)
if "task" in missing_keys and "taskname" in cut.custom:
# Compatibility with "old" Canary manifest format.
# For compatibility with inference options, this slot is now called "task".
cut.custom["task"] = cut.custom["taskname"]
missing_keys.remove("task")
if missing_keys:
raise RuntimeError(
f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}"
f"Please ensure that every utterance in the input manifests contains these keys."
)

# Actual tokenization. If a cut has multiple supervisions, we'll stitch their tokenized texts together.
texts = [sup.text for sup in cut.supervisions]
langs = [sup.language for sup in cut.supervisions]
taskname = cut.custom['taskname']
pnc = cut.custom['pnc']
source_lang = cut.custom['source_lang']
target_lang = cut.custom['target_lang']

tokens.append(canary_prompt(tokenizer, texts, langs, source_lang, target_lang, taskname, pnc))
if inference:
prompts.append(canary_prompt(tokenizer, None, None, source_lang, target_lang, taskname, pnc))
return tokens, prompts


def canary_prompt(
tokenizer: CanaryTokenizer,
text: str | list[str] | None,
language: str | list[str] | None,
source_language: str,
target_language: str,
taskname: str,
pnc: str,
) -> list[int]:
if isinstance(text, str):
text = [text]
if isinstance(language, str):
language = [language]

if text is not None:
try:
tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
except omegaconf.errors.KeyValidationError as e:
raise ProbablyIncorrectLanguageKeyError(
"We couldn't select the right tokenizer, which could be due to issues with reading "
"the language from the manifest. "
"If you're training, try setting lang_field='' to a different value (probably 'target_lang' or 'lang'). "
"If you're using model.transcribe() directly, please use override_config kwarg to set this. "
"If you're using transcribe_speech.py, use option gt_lang_attr_name='...' "
) from e
else:
tokens = None # create prompt for inference

# bos
prompted_tokens = [tokenizer.bos_id]

if tokens is not None and len(tokens) == 0:
# no speech token
prompted_tokens.append(tokenizer.nospeech_id)
else:
# first, validate the utterance
if source_language is None or target_language is None or taskname is None or pnc is None:
raise RuntimeError(
f"Missing keys provided to prompt: "
f"source_langauge={source_language},\n"
f"target_language={target_language},\n"
f"taskname={taskname},\n"
f"pnc={pnc}\n"
f"Please ensure that every utterance in the input manifests contains these keys."
)

# src_lang_id/no_speech
src_lang_id = tokenizer.spl_token_to_id(source_language)
prompted_tokens.append(src_lang_id)

# task
task = taskname
if task == 'asr' or task == "transcribe":
prompted_tokens.append(tokenizer.spl_token_to_id("transcribe"))
elif task == 's2t_translation' or task == 'ast' or task == "translate":
prompted_tokens.append(tokenizer.spl_token_to_id("translate"))
else:
raise ValueError(f"Unknown task: {task}")

# tgt_lang_id
tgt_lang_id = tokenizer.spl_token_to_id(target_language)
prompted_tokens.append(tgt_lang_id)

# PnC
pnc = f"{pnc}".lower().strip() # to account for bool or str
if pnc in {'yes', 'true'}:
prompted_tokens.append(tokenizer.spl_token_to_id("pnc"))
elif pnc in {'no', 'false'}:
prompted_tokens.append(tokenizer.spl_token_to_id("nopnc"))
else:
raise ValueError(f"Unknown value for key 'pnc': {pnc}")

# text (only in training)
if tokens is not None:
prompted_tokens.extend(tokens)
encoded = formatter.encode_dialog(
turns=[
dict(
role="user",
slots={
**{slot: cut.custom[slot] for slot in expected_slots},
formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER,
},
),
dict(
role="assistant",
slots={
"text": ' '.join(s.text for s in cut.supervisions),
formatter.PROMPT_LANGUAGE_SLOT: cut.custom["target_lang"],
},
),
]
)
prompts_with_answers.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])

# eos (only in training)
if tokens is not None:
prompted_tokens.append(tokenizer.eos_id)
return prompted_tokens
return prompts_with_answers, prompts


class ProbablyIncorrectLanguageKeyError(RuntimeError):
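For readers new to the prompt formatter API introduced here, a hedged usage sketch of the encode_dialog call that canary() now relies on; the tokenizer instance and slot values are placeholders, while the turn structure and the "input_ids"/"context_ids" keys mirror the code in this diff.

from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER

# canary_tokenizer is assumed to be an already-constructed CanaryTokenizer instance.
formatter = CanaryPromptFormatter(canary_tokenizer)
encoded = formatter.encode_dialog(
    turns=[
        {
            "role": "user",
            "slots": {
                "source_lang": "en",
                "target_lang": "es",
                "task": "asr",
                "pnc": "yes",
                formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER,
            },
        },
        {
            "role": "assistant",
            "slots": {"text": "hola mundo", formatter.PROMPT_LANGUAGE_SLOT: "es"},
        },
    ]
)
prompts_with_answers = encoded["input_ids"]   # prompt + transcript tokens (training targets)
prompts = encoded["context_ids"]              # prompt tokens only (used at inference)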