Skip to content

Commit

Permalink
Merge pull request espnet#4374 from YosukeHiguchi/master
Browse files: browse the repository at this point in the history
Minor fixes for the intermediate loss usage and Mask-CTC decoding
  • Loading branch information
kan-bayashi authored May 19, 2022
2 parents 9ca49ca + 8846560 commit c54b585
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
7 changes: 5 additions & 2 deletions espnet/nets/pytorch_backend/e2e_asr_maskctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self, idim, odim, args, ignore_id=-1):
self.odim = odim

self.intermediate_ctc_weight = args.intermediate_ctc_weight
- self.intermediate_ctc_layers = []
+ self.intermediate_ctc_layers = None
if args.intermediate_ctc_layer != "":
self.intermediate_ctc_layers = [
int(i) for i in args.intermediate_ctc_layer.split(",")
Expand Down Expand Up @@ -123,7 +123,10 @@ def forward(self, xs_pad, ilens, ys_pad):
# 1. forward encoder
xs_pad = xs_pad[:, : max(ilens)] # for data parallel
src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
- hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+ if self.intermediate_ctc_layers:
+     hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+ else:
+     hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad

# 2. forward decoder
Expand Down
9 changes: 6 additions & 3 deletions espnet/nets/pytorch_backend/e2e_asr_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, idim, odim, args, ignore_id=-1):
self.ctc = None

self.intermediate_ctc_weight = args.intermediate_ctc_weight
- self.intermediate_ctc_layers = []
+ self.intermediate_ctc_layers = None
if args.intermediate_ctc_layer != "":
self.intermediate_ctc_layers = [
int(i) for i in args.intermediate_ctc_layer.split(",")
Expand Down Expand Up @@ -189,7 +189,10 @@ def forward(self, xs_pad, ilens, ys_pad):
# 1. forward encoder
xs_pad = xs_pad[:, : max(ilens)] # for data parallel
src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
- hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+ if self.intermediate_ctc_layers:
+     hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
+ else:
+     hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad

# 2. forward decoder
Expand Down Expand Up @@ -291,7 +294,7 @@ def encode(self, x):
"""
self.eval()
x = torch.as_tensor(x).unsqueeze(0)
- enc_output, _, _ = self.encoder(x, None)
+ enc_output, *_ = self.encoder(x, None)
return enc_output.squeeze(0)

def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
Expand Down
5 changes: 4 additions & 1 deletion espnet2/asr/maskctc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,10 @@ def forward(self, enc_out: torch.Tensor) -> List[Hypothesis]:
confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
mask_num = len(mask_idx)

- y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
+ y_in = (
+     torch.zeros(1, len(y_idx), dtype=torch.long).to(enc_out.device)
+     + self.mask_token
+ )
y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
Expand Down

0 comments on commit c54b585

Please sign in to comment.