Skip to content

Commit

Permalink
Fix ONNX export of the latest streaming zipformer model. (#1148)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jun 27, 2023
1 parent 219bba1 commit 968ebd2
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
from icefall.utils import str2bool


def get_parser():
Expand Down Expand Up @@ -218,7 +218,7 @@ def forward(
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)

src_key_padding_mask = make_pad_mask(x_lens)
src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)

# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
Expand Down Expand Up @@ -272,6 +272,7 @@ def get_init_states(
states = self.encoder.get_init_states(batch_size, device)

embed_states = self.encoder_embed.get_init_states(batch_size, device)

states.append(embed_states)

processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
Expand Down

0 comments on commit 968ebd2

Please sign in to comment.