From 4917ac8bab2e5dc0021f17249f58b7a827a83af9 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Mon, 18 Mar 2024 11:43:29 +0100 Subject: [PATCH] allow export of onnx-streaming-models with other than 80dim input features (#1556) --- egs/librispeech/ASR/zipformer/export-onnx-streaming.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 6bc9b18586..6320e51ca8 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -333,6 +333,7 @@ def export_encoder_model_onnx( encoder_model: OnnxEncoder, encoder_filename: str, opset_version: int = 11, + feature_dim: int = 80, ) -> None: encoder_model.encoder.__class__.forward = ( encoder_model.encoder.__class__.streaming_forward @@ -343,7 +344,7 @@ def export_encoder_model_onnx( # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling T = decode_chunk_len + encoder_model.pad_length - x = torch.rand(1, T, 80, dtype=torch.float32) + x = torch.rand(1, T, feature_dim, dtype=torch.float32) init_state = encoder_model.get_init_states() num_encoders = len(encoder_model.encoder.encoder_dim) logging.info(f"num_encoders: {num_encoders}") @@ -724,6 +725,7 @@ def main(): encoder, encoder_filename, opset_version=opset_version, + feature_dim=params.feature_dim, ) logging.info(f"Exported encoder to {encoder_filename}")