Skip to content

Commit

Permalink
allow export of onnx-streaming-models with other than 80dim input fea…
Browse files Browse the repository at this point in the history
…tures (#1556)
  • Loading branch information
KarelVesely84 committed Mar 18, 2024
1 parent eec12f0 commit 4917ac8
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion egs/librispeech/ASR/zipformer/export-onnx-streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
feature_dim: int = 80,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
Expand All @@ -343,7 +344,7 @@ def export_encoder_model_onnx(
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length

x = torch.rand(1, T, 80, dtype=torch.float32)
x = torch.rand(1, T, feature_dim, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
Expand Down Expand Up @@ -724,6 +725,7 @@ def main():
encoder,
encoder_filename,
opset_version=opset_version,
feature_dim=params.feature_dim,
)
logging.info(f"Exported encoder to {encoder_filename}")

Expand Down

0 comments on commit 4917ac8

Please sign in to comment.