Hybrid conformer export #6983

Merged 14 commits on Jul 7, 2023
10 changes: 10 additions & 0 deletions docs/source/asr/models.rst
@@ -215,6 +215,11 @@ It is recommended to train a model in streaming model with limited context for t

You may find FastConformer variants of cache-aware streaming models under ``<NeMo_git_root>/examples/asr/conf/fastconformer/``.

Note that cache-aware streaming models are exported without caching support by default.
To include caching support, call `model.set_export_config({'cache_support' : 'True'})` before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is used:
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True`
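
The same can be done programmatically; a minimal sketch, assuming a local ``cache_aware_conformer.nemo`` checkpoint as in the command above:

.. code-block:: Python

    import nemo.collections.asr as nemo_asr

    # Restore the cache-aware streaming model (checkpoint name is a placeholder)
    model = nemo_asr.models.ASRModel.restore_from("cache_aware_conformer.nemo")
    model.eval()

    # export_config values are plain strings
    model.set_export_config({'cache_support': 'True'})
    model.export("cache_aware_conformer.onnx")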

.. _LSTM-Transducer_model:

LSTM-Transducer
@@ -291,6 +296,11 @@ Similar example configs for FastConformer variants of Hybrid models can be found
``<NeMo_git_root>/examples/asr/conf/fastconformer/hybrid_transducer_ctc/``
``<NeMo_git_root>/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/``

Note that Hybrid models are exported as RNNT (encoder and decoder+joint parts) by default.
To export as CTC (single encoder+decoder graph), call `model.set_export_config({'decoder_type' : 'ctc'})` before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is used:
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc`
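
The equivalent programmatic call, as a minimal sketch assuming a local ``hybrid_transducer.nemo`` checkpoint:

.. code-block:: Python

    import nemo.collections.asr as nemo_asr

    # Restore a Hybrid RNNT-CTC model (checkpoint name is a placeholder)
    model = nemo_asr.models.ASRModel.restore_from("hybrid_transducer.nemo")

    # Export a single encoder+decoder CTC graph instead of the default RNNT parts
    model.set_export_config({'decoder_type': 'ctc'})
    model.export("hybrid_transducer.onnx")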

.. _Conformer-HAT_model:

Conformer-HAT (Hybrid Autoregressive Transducer)
31 changes: 31 additions & 0 deletions docs/source/core/export.rst
@@ -177,6 +177,37 @@ Another common requirement for models that are being exported is to run certain
# call base method for common set of modifications
Exportable._prepare_for_export(self, **kwargs)

Some models that require control flow need to be exported in multiple parts; typical examples are RNNT nets.
To facilitate that, the hooks below are provided. To export, for example, the 'encoder' and 'decoder' subnets of a model, overload ``list_export_subnets`` to return ``['encoder', 'decoder']``.

.. code-block:: Python

    def get_export_subnet(self, subnet=None):
        """
        Returns Exportable subnet model/module to export
        """


    def list_export_subnets(self):
        """
        Returns default set of subnet names exported for this model
        First goes the one receiving input (input_example)
        """

Some networks may be exported differently according to user-settable options (like ragged batch support for TTS or cache support for ASR). To facilitate that, `Exportable` provides the `set_export_config()` method, which sets key/value pairs in the predefined `model.export_config` dictionary to be used during export:

.. code-block:: Python

    def set_export_config(self, args):
        """
        Sets/updates export_config dictionary
        """

Also, if an action hook on setting the config is desired, `Exportable` descendants may overload this method to include one.
An example can be found in ``<NeMo_git_root>/nemo/collections/asr/models/rnnt_models.py``.
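
Such an overload might react to a recognized key before delegating to the base implementation; a sketch following the pattern used in ``rnnt_models.py``:

.. code-block:: Python

    def set_export_config(self, args):
        if 'decoder_type' in args:
            # Act on the option, e.g. switch the decoding branch used for export
            self.change_decoding_strategy(decoder_type=args['decoder_type'])
        # The base method merges args into self.export_config
        super().set_export_config(args)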

Here is an example of how the `set_export_config()` call is tied to command-line arguments in ``<NeMo_git_root>/scripts/export.py``:

.. code-block:: bash

    python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc
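
Under the hood, the script simply splits each ``KEY=VALUE`` pair and forwards the resulting dictionary; note that the values stay strings:

.. code-block:: Python

    # Roughly what scripts/export.py does with --config (illustrative values)
    config_args = ["decoder_type=ctc", "cache_support=True"]
    kv = dict(s.split('=') for s in config_args)
    # kv == {'decoder_type': 'ctc', 'cache_support': 'True'}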

Exportable Model Code
~~~~~~~~~~~~~~~~~~~~~
8 changes: 8 additions & 0 deletions nemo/collections/asr/models/asr_model.py
@@ -239,3 +239,11 @@ def disabled_deployment_input_names(self):
    @property
    def disabled_deployment_output_names(self):
        return self.encoder.disabled_deployment_output_names

    def set_export_config(self, args):
        if 'cache_support' in args:
            enable = bool(args['cache_support'])
            self.encoder.export_cache_support = enable
            logging.info(f"Caching support enabled: {enable}")
            self.encoder.setup_streaming_params()
        super().set_export_config(args)
14 changes: 14 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -645,6 +645,20 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        self.finalize_interctc_metrics(metrics, outputs, prefix="test_")
        return metrics

    # EncDecRNNTModel is exported in 2 parts
    def list_export_subnets(self):
        if self.cur_decoder == 'rnnt':
            return ['encoder', 'decoder_joint']
        else:
            return ['self']

    @property
    def output_module(self):
        if self.cur_decoder == 'rnnt':
            return self.decoder
        else:
            return self.ctc_decoder

    @classmethod
    def list_available_models(cls) -> Optional[PretrainedModelInfo]:
        """
12 changes: 10 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
@@ -28,7 +28,7 @@
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig
-from nemo.collections.asr.models.asr_model import ASRModel
+from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
@@ -39,7 +39,7 @@
from nemo.utils import logging


-class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable):
+class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel):
    """Base class for encoder decoder RNNT-based models."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
@@ -960,6 +960,14 @@ def list_export_subnets(self):
    def decoder_joint(self):
        return RNNTDecoderJoint(self.decoder, self.joint)

    def set_export_config(self, args):
        if 'decoder_type' in args:
            if hasattr(self, 'change_decoding_strategy'):
                self.change_decoding_strategy(decoder_type=args['decoder_type'])
            else:
                raise Exception("Model does not have decoder type option")
        super().set_export_config(args)

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
14 changes: 14 additions & 0 deletions nemo/core/classes/exportable.py
@@ -302,3 +302,17 @@ def list_export_subnets(self):
        First goes the one receiving input (input_example)
        """
        return ['self']

    def get_export_config(self):
        """
        Returns export_config dictionary
        """
        return getattr(self, 'export_config', {})

    def set_export_config(self, args):
        """
        Sets/updates export_config dictionary
        """
        ex_config = self.get_export_config()
        ex_config.update(args)
        self.export_config = ex_config
19 changes: 15 additions & 4 deletions scripts/export.py
@@ -62,6 +62,15 @@ def get_args(argv):
    )
    parser.add_argument("--device", default="cuda", help="Device to export for")
    parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
    parser.add_argument(
        "--config",
        metavar="KEY=VALUE",
        nargs='+',
        help="Set a number of key-value pairs to model.export_config dictionary "
        "(do not put spaces before or after the = sign). "
        "Note that values are always treated as strings.",
    )

    args = parser.parse_args(argv)
    return args

@@ -130,10 +139,12 @@ def nemo_export(argv):
in_args["max_dim"] = args.max_dim
max_dim = args.max_dim

if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"):
model.encoder.export_cache_support = True
logging.info("Caching support is enabled.")
model.encoder.setup_streaming_params()
if args.cache_support:
model.set_export_config({"cache_support": "True"})

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
model.set_export_config(kv)

autocast = nullcontext
if args.autocast: