diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 80dc19b37a..ff3e46433f 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -86,7 +86,7 @@ find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool +from icefall.utils import str2bool def get_parser(): @@ -218,7 +218,7 @@ def forward( ) assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) - src_key_padding_mask = make_pad_mask(x_lens) + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) # processed_mask is used to mask out initial states processed_mask = torch.arange(left_context_len, device=x.device).expand( @@ -272,6 +272,7 @@ def get_init_states( states = self.encoder.get_init_states(batch_size, device) embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)