Skip to content

Commit

Permalink
fix for test
Browse files Browse the repository at this point in the history
  • Loading branch information
YosukeHiguchi committed May 18, 2022
1 parent 4de7aa5 commit 8846560
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions espnet/nets/pytorch_backend/e2e_asr_maskctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, idim, odim, args, ignore_id=-1):
self.odim = odim

self.intermediate_ctc_weight = args.intermediate_ctc_weight
self.intermediate_ctc_layers = []
self.intermediate_ctc_layers = None
if args.intermediate_ctc_layer != "":
self.intermediate_ctc_layers = [
int(i) for i in args.intermediate_ctc_layer.split(",")
Expand Down Expand Up @@ -124,7 +124,10 @@ def forward(self, xs_pad, ilens, ys_pad):
# 1. forward encoder
xs_pad = xs_pad[:, : max(ilens)] # for data parallel
src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
if self.intermediate_ctc_layers:
hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
else:
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad

# 2. forward decoder
Expand Down

0 comments on commit 8846560

Please sign in to comment.