Skip to content

Commit

Permalink
support gpu decoding for mask-ctc
Browse files Browse the repository at this point in the history
  • Loading branch information
YosukeHiguchi committed May 18, 2022
1 parent 49100e4 commit 9c83ddb
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion espnet2/asr/maskctc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ def forward(self, enc_out: torch.Tensor) -> List[Hypothesis]:
confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
mask_num = len(mask_idx)

y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
y_in = (
torch.zeros(1, len(y_idx), dtype=torch.long).to(enc_out.device)
+ self.mask_token
)
y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
Expand Down

0 comments on commit 9c83ddb

Please sign in to comment.