From 49100e4f1b3fc389c5672dc2ca17973525c4bf02 Mon Sep 17 00:00:00 2001 From: Yosuke Higuchi Date: Thu, 19 May 2022 05:03:29 +0900 Subject: [PATCH] fix bug for returning intermediate states --- espnet/nets/pytorch_backend/e2e_asr_transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/espnet/nets/pytorch_backend/e2e_asr_transformer.py b/espnet/nets/pytorch_backend/e2e_asr_transformer.py index b13c7e452b6..07a3203d38b 100644 --- a/espnet/nets/pytorch_backend/e2e_asr_transformer.py +++ b/espnet/nets/pytorch_backend/e2e_asr_transformer.py @@ -190,7 +190,10 @@ def forward(self, xs_pad, ilens, ys_pad): # 1. forward encoder xs_pad = xs_pad[:, : max(ilens)] # for data parallel src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) - hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask) + if self.intermediate_ctc_layers: + hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask) + else: + hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad # 2. forward decoder