From 49100e4f1b3fc389c5672dc2ca17973525c4bf02 Mon Sep 17 00:00:00 2001
From: Yosuke Higuchi
Date: Thu, 19 May 2022 05:03:29 +0900
Subject: [PATCH 1/4] fix bug for returning intermediate states

---
 espnet/nets/pytorch_backend/e2e_asr_transformer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/espnet/nets/pytorch_backend/e2e_asr_transformer.py b/espnet/nets/pytorch_backend/e2e_asr_transformer.py
index b13c7e452b6..07a3203d38b 100644
--- a/espnet/nets/pytorch_backend/e2e_asr_transformer.py
+++ b/espnet/nets/pytorch_backend/e2e_asr_transformer.py
@@ -190,7 +190,10 @@ def forward(self, xs_pad, ilens, ys_pad):
         # 1. forward encoder
         xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
         src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
-        hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+        if self.intermediate_ctc_layers:
+            hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+        else:
+            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
         self.hs_pad = hs_pad

         # 2. forward decoder

From 9c83ddb46404334914764a8e4356ea8a4c3c806c Mon Sep 17 00:00:00 2001
From: Yosuke Higuchi
Date: Thu, 19 May 2022 05:05:01 +0900
Subject: [PATCH 2/4] support gpu decoding for mask-ctc

---
 espnet2/asr/maskctc_model.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/espnet2/asr/maskctc_model.py b/espnet2/asr/maskctc_model.py
index 2a95eec89ea..75c8e588ad7 100644
--- a/espnet2/asr/maskctc_model.py
+++ b/espnet2/asr/maskctc_model.py
@@ -314,7 +314,10 @@ def forward(self, enc_out: torch.Tensor) -> List[Hypothesis]:
         confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
         mask_num = len(mask_idx)

-        y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
+        y_in = (
+            torch.zeros(1, len(y_idx), dtype=torch.long).to(enc_out.device)
+            + self.mask_token
+        )
         y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

         logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

From 4de7aa562f74c596e5b616fd8278a50a707d0198 Mon Sep 17 00:00:00 2001
From: Yosuke Higuchi
Date: Thu, 19 May 2022 06:19:20 +0900
Subject: [PATCH 3/4] fix for test

---
 espnet/nets/pytorch_backend/e2e_asr_transformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/espnet/nets/pytorch_backend/e2e_asr_transformer.py b/espnet/nets/pytorch_backend/e2e_asr_transformer.py
index 07a3203d38b..c14bf5cfd7a 100644
--- a/espnet/nets/pytorch_backend/e2e_asr_transformer.py
+++ b/espnet/nets/pytorch_backend/e2e_asr_transformer.py
@@ -97,7 +97,7 @@ def __init__(self, idim, odim, args, ignore_id=-1):
             self.ctc = None
         self.intermediate_ctc_weight = args.intermediate_ctc_weight
-        self.intermediate_ctc_layers = []
+        self.intermediate_ctc_layers = None
         if args.intermediate_ctc_layer != "":
             self.intermediate_ctc_layers = [
                 int(i) for i in args.intermediate_ctc_layer.split(",")
             ]
@@ -295,7 +295,7 @@ def encode(self, x):
         """
         self.eval()
         x = torch.as_tensor(x).unsqueeze(0)
-        enc_output, _, _ = self.encoder(x, None)
+        enc_output, *_ = self.encoder(x, None)
         return enc_output.squeeze(0)

     def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):

From 88465607cf5e899b8ce1b93c5c9fe09b69a2ab83 Mon Sep 17 00:00:00 2001
From: Yosuke Higuchi
Date: Thu, 19 May 2022 07:05:29 +0900
Subject: [PATCH 4/4] fix for test

---
 espnet/nets/pytorch_backend/e2e_asr_maskctc.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/espnet/nets/pytorch_backend/e2e_asr_maskctc.py b/espnet/nets/pytorch_backend/e2e_asr_maskctc.py
index 7e7f6c3312d..54f640b6b6d 100644
--- a/espnet/nets/pytorch_backend/e2e_asr_maskctc.py
+++ b/espnet/nets/pytorch_backend/e2e_asr_maskctc.py
@@ -78,7 +78,7 @@ def __init__(self, idim, odim, args, ignore_id=-1):
         self.odim = odim
         self.intermediate_ctc_weight = args.intermediate_ctc_weight
-        self.intermediate_ctc_layers = []
+        self.intermediate_ctc_layers = None
         if args.intermediate_ctc_layer != "":
             self.intermediate_ctc_layers = [
                 int(i) for i in args.intermediate_ctc_layer.split(",")
             ]
@@ -124,7 +124,10 @@ def forward(self, xs_pad, ilens, ys_pad):
         # 1. forward encoder
         xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
         src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
-        hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+        if self.intermediate_ctc_layers:
+            hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+        else:
+            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
         self.hs_pad = hs_pad

         # 2. forward decoder
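
Editorial note (not part of the patch series): the four commits rest on two small behaviors, namely that the encoder returns an extra value only when intermediate CTC layers are configured, and that tensors created during mask-ctc decoding must live on the same device as the encoder output. The sketch below is a minimal, hypothetical illustration of both patterns; ToyEncoder and the mask_token value are made-up stand-ins, not the real espnet classes.

    # Illustrative sketch only: a toy encoder whose return arity depends on
    # whether intermediate layers are requested, as in the espnet encoder.
    import torch


    class ToyEncoder(torch.nn.Module):
        """Returns (hs, mask) or (hs, mask, intermediates) depending on config."""

        def __init__(self, intermediate_layers=None):
            super().__init__()
            self.intermediate_layers = intermediate_layers

        def forward(self, xs, masks):
            hs = xs  # identity "encoding", purely for illustration
            if self.intermediate_layers:
                return hs, masks, [hs]  # extra per-layer outputs
            return hs, masks


    enc = ToyEncoder(intermediate_layers=None)
    xs = torch.randn(1, 4, 8)

    # Branch on the config, as forward() does after PATCH 1/4 and 4/4.
    if enc.intermediate_layers:
        hs, masks, hs_intermediates = enc(xs, None)
    else:
        hs, masks = enc(xs, None)

    # Or absorb any optional extras, as encode() does after PATCH 3/4.
    hs, *_ = enc(xs, None)

    # Device-safe initialization, as in PATCH 2/4: build the masked input on the
    # encoder output's device so GPU decoding never mixes CPU and CUDA tensors.
    mask_token = 99  # hypothetical mask-token id
    y_in = torch.zeros(1, 6, dtype=torch.long).to(hs.device) + mask_token

With this convention, callers that never enable intermediate CTC can keep the two-value unpacking, while the starred assignment in encode() tolerates either arity without an explicit branch.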