From caddcc7dd815365cd4a2b9aa13bc2dbe9d2a3628 Mon Sep 17 00:00:00 2001 From: Greg Clark Date: Fri, 5 May 2023 15:11:53 -0400 Subject: [PATCH] More streaming conformer export fixes (#6567) Signed-off-by: Greg Clark Co-authored-by: Vahid Noroozi --- .../asr/modules/conformer_encoder.py | 29 +++++++++++++++++++ nemo/core/classes/exportable.py | 16 +++++++--- scripts/export.py | 1 + 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 0fc0912a8921..9955e35444f4 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -183,6 +183,19 @@ def input_types(self): } ) + @property + def input_types_for_export(self): + """Returns definitions of module input ports.""" + return OrderedDict( + { + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + "cache_last_channel": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True), + "cache_last_time": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True), + "cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True), + } + ) + @property def output_types(self): """Returns definitions of module output ports.""" @@ -196,6 +209,19 @@ def output_types(self): } ) + @property + def output_types_for_export(self): + """Returns definitions of module output ports.""" + return OrderedDict( + { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "cache_last_channel_next": NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True), + "cache_last_time_next": NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True), + "cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True), + } + ) + @property def disabled_deployment_input_names(self): if not self.export_cache_support: @@ -489,6 +515,8 @@ def forward_for_export( rets = self.streaming_post_process(rets, keep_all_outputs=False) if len(rets) == 2: return rets + elif rets[2] is None and rets[3] is None and rets[4] is None: + return (rets[0], rets[1]) else: return ( rets[0], @@ -549,6 +577,7 @@ def forward_internal( audio_signal = self.pre_encode(audio_signal) else: audio_signal, length = self.pre_encode(x=audio_signal, lengths=length) + length = length.to(torch.int64) # self.streaming_cfg is set by setup_streaming_cfg(), called in the init if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None: audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :] diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index eb399b1c1d1d..38b8e1c1e31b 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -215,8 +215,8 @@ def _export( elif format == ExportFormat.ONNX: # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: - dynamic_axes = get_dynamic_axes(self.input_module.input_types, input_names) - dynamic_axes.update(get_dynamic_axes(self.output_module.output_types, output_names)) + dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names) + dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names)) torch.onnx.export( jitted_model, input_example, @@ -273,11 +273,19 @@ def _export_teardown(self): @property def input_names(self): - return get_io_names(self.input_module.input_types, self.disabled_deployment_input_names) + return get_io_names(self.input_module.input_types_for_export, self.disabled_deployment_input_names) @property def output_names(self): - return get_io_names(self.output_module.output_types, self.disabled_deployment_output_names) + return get_io_names(self.output_module.output_types_for_export, self.disabled_deployment_output_names) + + @property + def input_types_for_export(self): + return self.input_types + + @property + def output_types_for_export(self): + return self.output_types def get_export_subnet(self, subnet=None): """ diff --git a/scripts/export.py b/scripts/export.py index efb257d00447..80cbcf3dc666 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -32,6 +32,7 @@ import torch from pytorch_lightning import Trainer +import nemo from nemo.core import ModelPT from nemo.core.classes import Exportable from nemo.core.config.pytorch_lightning import TrainerConfig