Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid conformer export #6983

Merged
merged 14 commits into from
Jul 7, 2023
29 changes: 29 additions & 0 deletions docs/source/core/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,35 @@ Another common requirement for models that are being exported is to run certain
# call base method for common set of modifications
Exportable._prepare_for_export(self, **kwargs)

Some models that require control flow, need to be exported in multiple parts. Typical examples are RNNT nets.
To facilitate that, the hooks below are provided. To export, for example, 'encoder' and 'decoder' subnets of the model, overload list_export_subnets to return ['encoder', 'decoder'].

.. code-block:: Python

def get_export_subnet(self, subnet=None):
"""
Returns Exportable subnet model/module to export
"""


def list_export_subnets(self):
"""
Returns default set of subnet names exported for this model
First goes the one receiving input (input_example)
"""

Some nertworks may be exported differently according to user-settable options (like ragged batch support for TTS or cache support for ASR). To facilitate that - set_export_config() method is provided by Exportable to set key/value pairs to predefined model.export_config dictionary, to be used during the export. Also, if an action hook on setting config is desired, this method may be overloaded to include one.

.. code-block:: Python
def set_export_config(self, args):
"""
Sets/updates export_config dictionary
"""

Here is example on now set_export_config() is being tied to command line arguments in scripts/export.py :

.. code-block:: Python
python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=[ctc|rnnt]

Exportable Model Code
~~~~~~~~~~~~~~~~~~~~~
Expand Down
8 changes: 8 additions & 0 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,11 @@ def disabled_deployment_input_names(self):
@property
def disabled_deployment_output_names(self):
return self.encoder.disabled_deployment_output_names

def set_export_config(self, args):
if 'cache_support' in args:
enable = bool(args['cache_support'])
self.encoder.export_cache_support = enable
logging.info(f"Caching support enabled: {enable}")
self.encoder.setup_streaming_params()
super().set_export_config(args)
14 changes: 14 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,20 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
self.finalize_interctc_metrics(metrics, outputs, prefix="test_")
return metrics

# EncDecRNNTModel is exported in 2 parts
def list_export_subnets(self):
if self.cur_decoder == 'rnnt':
return ['encoder', 'decoder_joint']
else:
return ['self']

@property
def output_module(self):
if self.cur_decoder == 'rnnt':
return self.decoder
else:
return self.ctc_decoder

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
"""
Expand Down
12 changes: 10 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
Expand All @@ -39,7 +39,7 @@
from nemo.utils import logging


class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable):
class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel):
"""Base class for encoder decoder RNNT-based models."""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
Expand Down Expand Up @@ -960,6 +960,14 @@ def list_export_subnets(self):
def decoder_joint(self):
return RNNTDecoderJoint(self.decoder, self.joint)

def set_export_config(self, args):
if 'decoder_type' in args:
if hasattr(self, 'change_decoding_strategy'):
self.change_decoding_strategy(decoder_type=args['decoder_type'])
else:
raise Exception("Model does not have decoder type option")
super().set_export_config(args)

@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
Expand Down
14 changes: 14 additions & 0 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,17 @@ def list_export_subnets(self):
First goes the one receiving input (input_example)
"""
return ['self']

def get_export_config(self):
"""
Returns export_config dictionary
"""
return getattr(self, 'export_config', {})

def set_export_config(self, args):
"""
Sets/updates export_config dictionary
"""
ex_config = self.get_export_config()
ex_config.update(args)
self.export_config = ex_config
19 changes: 15 additions & 4 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def get_args(argv):
)
parser.add_argument("--device", default="cuda", help="Device to export for")
parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
parser.add_argument(
"--config",
metavar="KEY=VALUE",
nargs='+',
help="Set a number of key-value pairs to model.export_config dictionary "
"(do not put spaces before or after the = sign). "
"Note that values are always treated as strings.",
)

args = parser.parse_args(argv)
return args

Expand Down Expand Up @@ -130,10 +139,12 @@ def nemo_export(argv):
in_args["max_dim"] = args.max_dim
max_dim = args.max_dim

if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"):
model.encoder.export_cache_support = True
logging.info("Caching support is enabled.")
model.encoder.setup_streaming_params()
if args.cache_support:
model.set_export_config({"cache_support": "True"})

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
model.set_export_config(kv)

autocast = nullcontext
if args.autocast:
Expand Down