From 968ebd236b4a03c95421d47dfb673aa718028080 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 27 Jun 2023 14:35:59 +0800
Subject: [PATCH] Fix ONNX export of the latest streaming zipformer model.
 (#1148)

---
 egs/librispeech/ASR/zipformer/export-onnx-streaming.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 80dc19b37a..ff3e46433f 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -86,7 +86,7 @@
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import make_pad_mask, str2bool
+from icefall.utils import str2bool
 
 
 def get_parser():
@@ -218,7 +218,7 @@ def forward(
         )
         assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
 
-        src_key_padding_mask = make_pad_mask(x_lens)
+        src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
 
         # processed_mask is used to mask out initial states
         processed_mask = torch.arange(left_context_len, device=x.device).expand(
@@ -272,6 +272,7 @@ def get_init_states(
 
         states = self.encoder.get_init_states(batch_size, device)
         embed_states = self.encoder_embed.get_init_states(batch_size, device)
+        states.append(embed_states)
 
        processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
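
Note (not part of the original patch): a minimal sketch of what the two changes amount to, assuming the names N, chunk_size, encoder, and encoder_embed from the surrounding export script; the helper names below are hypothetical and only illustrate the behavior.

import torch

def chunk_key_padding_mask(N: int, chunk_size: int) -> torch.Tensor:
    # The forward() above asserts x.size(1) == self.chunk_size, so every frame
    # in a streamed chunk is valid; the key-padding mask is therefore all False
    # and no longer needs make_pad_mask(x_lens).
    return torch.zeros(N, chunk_size, dtype=torch.bool)

def collect_init_states(encoder, encoder_embed, batch_size, device):
    # Hypothetical helper mirroring OnnxEncoder.get_init_states: the cached
    # states of encoder_embed are appended to the state list as well, so the
    # exported streaming model carries its full state.
    states = encoder.get_init_states(batch_size, device)
    embed_states = encoder_embed.get_init_states(batch_size, device)
    states.append(embed_states)
    processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
    return states, processed_lens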