
Merge branch 'espnet:master' into master
roshansh-cmu authored Feb 23, 2022
2 parents 23a537e + 9c24b3a commit 58aec43
Showing 9 changed files with 332 additions and 4 deletions.
3 changes: 3 additions & 0 deletions egs2/librispeech_100/asr1/conf/tuning/decode_ctc_bs1.yaml
@@ -0,0 +1,3 @@
lm_weight: 0.0
ctc_weight: 1.0
beam_size: 1
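With lm_weight: 0.0, ctc_weight: 1.0, and beam_size: 1, joint CTC/attention decoding effectively degenerates to greedy (best-path) CTC decoding: pick the argmax token per frame, collapse repeats, and drop blanks. A minimal sketch of that procedure, assuming a blank index of 0 and a dummy posterior tensor (both illustrative, not taken from this commit):

# Greedy (best-path) CTC decoding sketch; blank_id = 0 is an assumption.
import torch

def ctc_greedy_decode(log_probs: torch.Tensor, blank_id: int = 0) -> list:
    """Argmax per frame, collapse repeated tokens, drop blanks."""
    best_path = log_probs.argmax(dim=-1).tolist()
    tokens, prev = [], None
    for tok in best_path:
        if tok != prev and tok != blank_id:
            tokens.append(tok)
        prev = tok
    return tokens

# Toy usage: 6 frames over a 4-symbol vocabulary (index 0 = blank).
dummy = torch.randn(6, 4).log_softmax(dim=-1)
print(ctc_greedy_decode(dummy))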
61 changes: 61 additions & 0 deletions egs2/librispeech_100/asr1/conf/tuning/train_conformer_ctc.yaml
@@ -0,0 +1,61 @@
batch_type: numel
batch_bins: 2000000
accum_grad: 16
max_epoch: 60
patience: none
init: none
best_model_criterion:
- - valid
- cer_ctc
- min
keep_nbest_models: 10

encoder: conformer
encoder_conf:
output_size: 256
attention_heads: 4
linear_units: 1024
num_blocks: 18
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: conv2d
normalize_before: true
macaron_style: true
rel_pos_type: latest
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
use_cnn_module: true
cnn_module_kernel: 31

model_conf:
ctc_weight: 1.0
lsm_weight: 0.1
length_normalized_loss: false

optim: adam
optim_conf:
lr: 0.002
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 15000

num_att_plot: 0

specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 27
num_freq_mask: 2
apply_time_mask: true
time_mask_width_ratio_range:
- 0.
- 0.05
num_time_mask: 5
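The config above trains a pure-CTC Conformer: ctc_weight: 1.0 removes the attention-decoder term from the hybrid CTC/attention loss interpolation. A tiny sketch of that interpolation with placeholder loss values:

import torch

ctc_weight = 1.0                     # from train_conformer_ctc.yaml
loss_ctc = torch.tensor(42.0)        # placeholder CTC loss
loss_att = torch.tensor(37.0)        # placeholder attention-decoder loss

# With ctc_weight == 1.0 the attention branch contributes nothing to training.
loss = ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att
print(loss)  # tensor(42.)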
@@ -0,0 +1,63 @@
batch_type: numel
batch_bins: 2000000
accum_grad: 16
max_epoch: 60
patience: none
init: none
best_model_criterion:
- - valid
- cer_ctc
- min
keep_nbest_models: 10

encoder: conformer
encoder_conf:
output_size: 256
attention_heads: 4
linear_units: 1024
num_blocks: 18
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: conv2d
normalize_before: true
macaron_style: true
rel_pos_type: latest
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
use_cnn_module: true
cnn_module_kernel: 31
interctc_layer_idx: [6,12]

model_conf:
ctc_weight: 1.0
interctc_weight: 0.66
lsm_weight: 0.1
length_normalized_loss: false

optim: adam
optim_conf:
lr: 0.002
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 15000

num_att_plot: 0

specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 27
num_freq_mask: 2
apply_time_mask: true
time_mask_width_ratio_range:
- 0.
- 0.05
num_time_mask: 5
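This second training config (its file header is not shown above) additionally sets interctc_layer_idx: [6,12] and interctc_weight: 0.66, attaching auxiliary CTC losses to encoder layers 6 and 12. A sketch of how such intermediate losses are typically averaged and mixed with the final-layer CTC loss; the exact weighting scheme is an assumption and the values are placeholders:

import torch

interctc_weight = 0.66                                   # from the config
loss_ctc_final = torch.tensor(40.0)                      # placeholder final-layer CTC loss
intermediate_losses = [torch.tensor(55.0),               # placeholder layer-6 CTC loss
                       torch.tensor(47.0)]               # placeholder layer-12 CTC loss

loss_interctc = sum(intermediate_losses) / len(intermediate_losses)
loss_ctc = (1.0 - interctc_weight) * loss_ctc_final + interctc_weight * loss_interctc
print(loss_ctc)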
64 changes: 64 additions & 0 deletions egs2/librispeech_100/asr1/conf/tuning/train_conformer_scctc.yaml
@@ -0,0 +1,64 @@
batch_type: numel
batch_bins: 2000000
accum_grad: 16
max_epoch: 60
patience: none
init: none
best_model_criterion:
- - valid
- cer_ctc
- min
keep_nbest_models: 10

encoder: conformer
encoder_conf:
output_size: 256
attention_heads: 4
linear_units: 1024
num_blocks: 18
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: conv2d
normalize_before: true
macaron_style: true
rel_pos_type: latest
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
use_cnn_module: true
cnn_module_kernel: 31
interctc_layer_idx: [6,12]
interctc_use_conditioning: true

model_conf:
ctc_weight: 1.0
interctc_weight: 0.66
lsm_weight: 0.1
length_normalized_loss: false

optim: adam
optim_conf:
lr: 0.002
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 15000

num_att_plot: 0

specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 27
num_freq_mask: 2
apply_time_mask: true
time_mask_width_ratio_range:
- 0.
- 0.05
num_time_mask: 5
10 changes: 10 additions & 0 deletions espnet2/asr/ctc.py
@@ -156,6 +156,16 @@ def forward(self, hs_pad, hlens, ys_pad, ys_lens):

return loss

    def softmax(self, hs_pad):
        """softmax of frame activations
        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.softmax(self.ctc_lo(hs_pad), dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations
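The ctc.py change adds a softmax counterpart to the existing log_softmax, returning per-frame posteriors over the output vocabulary; the encoder changes below call it to build the self-conditioning signal. A small usage sketch; the vocabulary size, encoder dimension, and tensor shapes are illustrative assumptions:

import torch
from espnet2.asr.ctc import CTC

ctc = CTC(odim=500, encoder_output_size=256)   # assumed sizes for illustration
hs_pad = torch.randn(4, 100, 256)              # (B, Tmax, eprojs), dummy encoder states

posteriors = ctc.softmax(hs_pad)               # (B, Tmax, odim), each frame sums to 1
frame_ids = posteriors.argmax(dim=-1)          # most likely token id per frame
print(posteriors.shape, frame_ids.shape)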
43 changes: 42 additions & 1 deletion espnet2/asr/encoder/conformer_encoder.py
@@ -3,6 +3,7 @@

"""Conformer encoder definition."""

from typing import List
from typing import Optional
from typing import Tuple

@@ -39,6 +40,7 @@
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8
from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError
from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.abs_encoder import AbsEncoder


@@ -101,6 +103,8 @@ def __init__(
zero_triu: bool = False,
cnn_module_kernel: int = 31,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
):
assert check_argument_types()
super().__init__()
@@ -262,6 +266,12 @@ def __init__(
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
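Note that conditioning_layer is only declared here (as None); for self-conditioned CTC it has to be assigned by whatever owns the encoder, projecting CTC posteriors (vocabulary size) back to the encoder dimension. A sketch of that wiring under the assumption that it happens in the ASR model after both modules are built; the sizes are illustrative:

import torch

vocab_size = 500               # illustrative
encoder_output_size = 256      # matches output_size in the configs above

# Assumed wiring, performed outside this file when interctc_use_conditioning is true:
conditioning_layer = torch.nn.Linear(vocab_size, encoder_output_size)
# encoder.conditioning_layer = conditioning_layer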

def output_size(self) -> int:
return self._output_size

@@ -270,6 +280,7 @@ def forward(
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
@@ -303,11 +314,41 @@
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)

                        if isinstance(xs_pad, tuple):
                            x, pos_emb = xs_pad
                            x = x + self.conditioning_layer(ctc_out)
                            xs_pad = (x, pos_emb)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
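With interctc_layer_idx set, forward now returns (xs_pad, intermediate_outs) instead of a bare tensor, where intermediate_outs is a list of (layer_index, hidden_states) pairs. A self-contained sketch of the caller-side unpacking, using dummy tensors in place of real encoder output; the shapes and the downstream loss computation are assumptions:

import torch

# Dummy stand-ins for the new return value when interctc_layer_idx = [6, 12].
final_states = torch.randn(4, 100, 256)                  # (B, Tmax, output_size)
intermediate_outs = [(6, torch.randn(4, 100, 256)),      # (layer index, hidden states)
                     (12, torch.randn(4, 100, 256))]
encoder_out = (final_states, intermediate_outs)

# Caller-side unpacking keeps compatibility with the plain-tensor return path.
if isinstance(encoder_out, tuple):
    encoder_out, intermediate_outs = encoder_out
else:
    intermediate_outs = []

for layer_idx, hidden in intermediate_outs:
    print(f"layer {layer_idx}: hidden states {tuple(hidden.shape)}")
    # an auxiliary CTC loss would be computed on `hidden` here (cf. ctc.py above)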
38 changes: 36 additions & 2 deletions espnet2/asr/encoder/transformer_encoder.py
@@ -1,7 +1,9 @@
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder definition."""
"""Transformer encoder definition."""

from typing import List
from typing import Optional
from typing import Tuple

@@ -25,6 +27,7 @@
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8
from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError
from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.abs_encoder import AbsEncoder


@@ -70,6 +73,8 @@ def __init__(
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
):
assert check_argument_types()
super().__init__()
@@ -144,6 +149,12 @@ def __init__(
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None

def output_size(self) -> int:
return self._output_size

@@ -152,6 +163,7 @@ def forward(
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -181,9 +193,31 @@
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
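The Transformer encoder gains the same two constructor arguments, so it can be exercised directly with intermediate taps enabled. A small instantiation sketch; only interctc_layer_idx and interctc_use_conditioning come from this diff, the remaining sizes are an assumed toy setup:

import torch
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(
    input_size=80,                    # assumed feature dimension
    output_size=256,
    num_blocks=12,
    interctc_layer_idx=[6],           # tap layer 6 (must be < num_blocks)
    interctc_use_conditioning=False,  # no CTC module needed in this mode
)

feats = torch.randn(2, 200, 80)       # (B, T, feature_dim), dummy features
feat_lens = torch.tensor([200, 180])
(xs, intermediate), olens, _ = encoder(feats, feat_lens)
print(xs.shape, [(idx, h.shape) for idx, h in intermediate])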
