Timestamps to transcribe (#10950)
* initial version

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Support for RNNT, TDT, Hybrid Models

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* move change of decoder strategy from mixin to individual model class

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

* update transcribe_speech.py

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* uncomment

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

* add docs

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* fix docs

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

* codeql fixes

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* unit tests

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* minor rebase fix

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

* add None case to restore the state set outside using decoding_strategy()

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

* remove ipdb traces

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* update docs for transcription.py

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* remove preserve alignment for AED models as they don't support it

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* lint warnings

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: nithinraok <nithinraok@users.noreply.github.com>
2 people authored and yashaswikarnati committed Nov 21, 2024
1 parent ffddec9 commit 1661d21
Showing 20 changed files with 623 additions and 253 deletions.
35 changes: 32 additions & 3 deletions docs/source/asr/intro.rst
@@ -16,10 +16,39 @@ After :ref:`installing NeMo<installation>`, you can transcribe an audio file as
asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_transducer_large")
transcript = asr_model.transcribe(["path/to/audio_file.wav"])
Obtain word/segment timestamps
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Obtain timestamps
^^^^^^^^^^^^^^^^^

You can also obtain timestamps for each word or segment in the transcription as follows:
Obtaining char (token), word, or segment timestamps is also possible with NeMo ASR models.

Currently, timestamps are available for Parakeet models with all decoder types (CTC/RNNT/TDT). Support for AED models will be added soon.

There are two ways to obtain timestamps:
1. Use the `timestamps=True` flag in the `transcribe` method.
2. For more control, update the decoding config to specify the type of timestamps (char, word, segment) and the segment separators or word separator used for segment- and word-level timestamps.

With the `timestamps=True` flag, you can obtain char, word, and segment timestamps for the transcription as follows:

.. code-block:: python
# import nemo_asr and instantiate asr_model as above
import nemo.collections.asr as nemo_asr
asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt_ctc-110m")
# specify flag `timestamps=True`
hypotheses = asr_model.transcribe(["path/to/audio_file.wav"], timestamps=True)
# by default, timestamps are enabled for char, word and segment level
word_timestamps = hypotheses[0][0].timestep['word'] # word level timestamps for first sample
segment_timestamps = hypotheses[0][0].timestep['segment'] # segment level timestamps
char_timestamps = hypotheses[0][0].timestep['char'] # char level timestamps
for stamp in segment_timestamps:
    print(f"{stamp['start']}s - {stamp['end']}s : {stamp['segment']}")
# segment-level timestamps (if the model supports punctuation and capitalization, segments are split on punctuation; otherwise the complete transcription is treated as a single segment)
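# word-level timestamps can be iterated in the same way; the 'start'/'end'/'word'
# keys below are assumed to mirror the segment entries shown above
for stamp in word_timestamps:
    print(f"{stamp['start']}s - {stamp['end']}s : {stamp['word']}")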
For more control over the timestamps, you can update the decoding config to specify the type of timestamps (char, word, segment) and the segment separators or word separator for segment- and word-level timestamps, as follows:

.. code-block:: python
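# Minimal sketch of the decoding-config route. This block is an illustration:
# field names such as `compute_timestamps`, `segment_seperators`, and `word_seperator`
# are assumptions and may differ per model/decoder type.
from omegaconf import open_dict

# import nemo_asr and instantiate asr_model as above
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.preserve_alignments = True              # keep frame-level alignments
    decoding_cfg.compute_timestamps = True               # enable timestamp computation
    decoding_cfg.segment_seperators = [".", "?", "!"]    # characters that end a segment
    decoding_cfg.word_seperator = " "                    # character that separates words
asr_model.change_decoding_strategy(decoding_cfg)

# timestamps are then read from the returned hypotheses as in the example above
hypotheses = asr_model.transcribe(["path/to/audio_file.wav"], return_hypotheses=True)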
@@ -13,11 +13,13 @@
# limitations under the License.

"""
This script chunks long audios into non-overlapping segments of `chunk_len_in_secs` seconds and performs inference on each
This script chunks long audios into non-overlapping segments of `chunk_len_in_secs`
seconds and performs inference on each
segment individually. The results are then concatenated to form the final output.
Below is an example of how to run this script with the Canary-1b model.
It's recommended to use manifest input, otherwise the model will perform English ASR with punctuations and capitalizations.
It's recommended to use manifest input, otherwise the model will perform English ASR
with punctuations and capitalizations.
An example manifest line:
{
"audio_filepath": "/path/to/audio.wav", # path to the audio file
@@ -41,11 +43,10 @@
"""

import contextlib
import copy
import glob
import os
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass
from typing import Optional

import pytorch_lightning as pl
@@ -67,6 +68,10 @@

@dataclass
class TranscriptionConfig:
"""
Transcription config
"""

# Required configs
model_path: Optional[str] = None # Path to a .nemo file
pretrained_name: Optional[str] = None # Name of a pretrained model
@@ -116,6 +121,10 @@ class TranscriptionConfig:

@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
"""
Transcribes the input audio and can be used to infer long audio files by chunking
them into smaller segments.
"""
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

@@ -160,7 +169,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:

if model_cfg.preprocessor.normalize != "per_feature":
logging.error(
"Only EncDecMultiTaskModel models trained with per_feature normalization are supported currently"
"Only EncDecMultiTaskModel models trained with per_feature normalization are supported \
currently"
)

# Disable config overwriting
@@ -206,7 +216,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
)

output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False
)
logging.info(f"Finished writing predictions to {output_filename}!")

@@ -35,12 +35,11 @@
You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
predictions of the model, and ground-truth text if presents in manifest.
"""
import contextlib
import copy
import glob
import math
import os
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass
from typing import Optional

import pytorch_lightning as pl
@@ -65,6 +64,10 @@

@dataclass
class TranscriptionConfig:
"""
Transcription Configuration for buffered inference.
"""

# Required configs
model_path: Optional[str] = None # Path to a .nemo file
pretrained_name: Optional[str] = None # Name of a pretrained model
@@ -114,6 +117,10 @@ class TranscriptionConfig:

@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
"""
Transcribes the input audio and can be used to infer long audio files by chunking
them into smaller segments.
"""
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

@@ -221,7 +228,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
filepaths,
)
output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False
)
logging.info(f"Finished writing predictions to {output_filename}!")

@@ -61,7 +61,7 @@
import glob
import math
import os
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass
from typing import Optional

import pytorch_lightning as pl
@@ -87,6 +90,10 @@

@dataclass
class TranscriptionConfig:
"""
Transcription Configuration for buffered inference.
"""

# Required configs
model_path: Optional[str] = None # Path to a .nemo file
pretrained_name: Optional[str] = None # Name of a pretrained model
@@ -143,6 +147,10 @@ class TranscriptionConfig:

@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
"""
Transcribes the input audio and can be used to infer long audio files by chunking
them into smaller segments.
"""
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
torch.set_grad_enabled(False)

@@ -274,7 +282,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
)

output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False
)
logging.info(f"Finished writing predictions to {output_filename}!")

12 changes: 10 additions & 2 deletions examples/asr/speech_translation/translate_speech.py
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import json
import os
from dataclasses import dataclass, is_dataclass
@@ -65,13 +64,19 @@

@dataclass
class ModelChangeConfig:
"""
Sub-config for changes specific to the Conformer Encoder
"""

# Sub-config for changes specific to the Conformer Encoder
conformer: ConformerChangeConfig = ConformerChangeConfig()


@dataclass
class TranslationConfig:
"""
Translation Configuration for audio to text translation.
"""

# Required configs
model_path: Optional[str] = None # Path to a .nemo file
pretrained_name: Optional[str] = None # Name of a pretrained model
@@ -106,6 +111,9 @@ class TranslationConfig:

@hydra_runner(config_name="TranslationConfig", schema=TranslationConfig)
def main(cfg: TranslationConfig) -> Union[TranslationConfig, List[str]]:
"""
Main function to translate audio to text using a pretrained/finetuned model.
"""
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

for key in cfg: