Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More streaming conformer export fixes #6578

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,19 @@ def input_types(self):
}
)

@property
def input_types_for_export(self):
    """Describe the input ports used when this module is exported.

    Returns:
        An ``OrderedDict`` mapping each input name to its ``NeuralType``. The
        audio signal and its length come first, followed by the three cache
        inputs, which are declared optional.
    """
    export_inputs = OrderedDict()
    export_inputs["audio_signal"] = NeuralType(('B', 'D', 'T'), SpectrogramType())
    export_inputs["length"] = NeuralType(('B',), LengthsType())
    # The cache ports are all marked optional.
    export_inputs["cache_last_channel"] = NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True)
    export_inputs["cache_last_time"] = NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True)
    export_inputs["cache_last_channel_len"] = NeuralType(('B',), LengthsType(), optional=True)
    return export_inputs

@property
def output_types(self):
"""Returns definitions of module output ports."""
Expand All @@ -196,6 +209,19 @@ def output_types(self):
}
)

@property
def output_types_for_export(self):
    """Describe the output ports used when this module is exported.

    Returns:
        An ``OrderedDict`` mapping each output name to its ``NeuralType``. The
        encoded outputs and their lengths come first, followed by the three
        "next" cache outputs, which are declared optional.
    """
    export_outputs = OrderedDict()
    export_outputs["outputs"] = NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())
    export_outputs["encoded_lengths"] = NeuralType(('B',), LengthsType())
    # The cache ports are all marked optional.
    export_outputs["cache_last_channel_next"] = NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True)
    export_outputs["cache_last_time_next"] = NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True)
    export_outputs["cache_last_channel_next_len"] = NeuralType(('B',), LengthsType(), optional=True)
    return export_outputs

@property
def disabled_deployment_input_names(self):
if not self.export_cache_support:
Expand Down Expand Up @@ -489,6 +515,8 @@ def forward_for_export(
rets = self.streaming_post_process(rets, keep_all_outputs=False)
if len(rets) == 2:
return rets
elif rets[2] is None and rets[3] is None and rets[4] is None:
return (rets[0], rets[1])
else:
return (
rets[0],
Expand Down Expand Up @@ -549,6 +577,7 @@ def forward_internal(
audio_signal = self.pre_encode(audio_signal)
else:
audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
length = length.to(torch.int64)
# self.streaming_cfg is set by setup_streaming_cfg(), called in the init
if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None:
audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :]
Expand Down
16 changes: 12 additions & 4 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def _export(
elif format == ExportFormat.ONNX:
# dynamic axis is a mapping from input/output_name => list of "dynamic" indices
if dynamic_axes is None:
dynamic_axes = get_dynamic_axes(self.input_module.input_types, input_names)
dynamic_axes.update(get_dynamic_axes(self.output_module.output_types, output_names))
dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names)
dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names))
torch.onnx.export(
jitted_model,
input_example,
Expand Down Expand Up @@ -273,11 +273,19 @@ def _export_teardown(self):

@property
def input_names(self):
    """Return the input names used for export.

    The names are produced by ``get_io_names`` from the input module's
    ``input_types_for_export``, with ``self.disabled_deployment_input_names``
    passed as the second argument.
    """
    # Use the export-specific port definitions only. A preceding return that
    # read ``input_types`` would win and leave this lookup unreachable.
    return get_io_names(self.input_module.input_types_for_export, self.disabled_deployment_input_names)

@property
def output_names(self):
    """Return the output names used for export.

    The names are produced by ``get_io_names`` from the output module's
    ``output_types_for_export``, with ``self.disabled_deployment_output_names``
    passed as the second argument.
    """
    # Use the export-specific port definitions only. A preceding return that
    # read ``output_types`` would win and leave this lookup unreachable.
    return get_io_names(self.output_module.output_types_for_export, self.disabled_deployment_output_names)

@property
def input_types_for_export(self):
    """Return the input port definitions to use for export.

    This implementation simply hands back ``self.input_types``.
    """
    export_inputs = self.input_types
    return export_inputs

@property
def output_types_for_export(self):
    """Return the output port definitions to use for export.

    This implementation simply hands back ``self.output_types``.
    """
    export_outputs = self.output_types
    return export_outputs

def get_export_subnet(self, subnet=None):
"""
Expand Down
1 change: 1 addition & 0 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import torch
from pytorch_lightning import Trainer

import nemo
from nemo.core import ModelPT
from nemo.core.classes import Exportable
from nemo.core.config.pytorch_lightning import TrainerConfig
Expand Down