Skip to content

Commit

Permalink
Fix cache aware hybrid bugs (#6466) (#6484)
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] authored Apr 25, 2023
1 parent 45d619c commit 2b2c9f5
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@
You may drop the '--debug_mode' and '--compare_vs_offline' to speedup the streaming evaluation.
If compare_vs_offline is not used, then significantly larger batch_size can be used.
Setting `--pad_and_drop_preencoded` would perform the caching for all steps including the first step.
It may result in slightly different outputs from the sub-sampling module compared to offline mode for some techniques like striding and sw_striding.
Enabling it would make it easier to export the model to ONNX.
# Hybrid ASR models
For Hybrid ASR models which have two decoders, you may select the decoder by --set_decoder DECODER_TYPE, where DECODER_TYPE can be "ctc" or "rnnt".
If decoder is not set, then the default decoder would be used which is the RNNT decoder for Hybrid ASR models.
To best compare output with offline output (i.e. `--compare_vs_offline` is set) `--pad-and-drop-preencoded` should also be set.
## Evaluate a model trained with full context for offline mode
Expand Down Expand Up @@ -126,6 +132,7 @@ def perform_streaming(
transcribed_texts,
cache_last_channel_next,
cache_last_time_next,
cache_last_channel_len,
best_hyp,
) = asr_model.conformer_stream_step(
processed_signal=processed_signal,
Expand Down Expand Up @@ -254,9 +261,16 @@ def main():
"--output_path", type=str, help="path to output file when manifest is used as input", default=None
)
parser.add_argument(
"--pad-and-drop-preencoded",
"--pad_and_drop_preencoded",
action="store_true",
help="Enables padding the audio input and then dropping the extra steps after the pre-encoding for the first step. It makes the outputs of the downsampling exactly as the offline mode for some techniques like striding.",
help="Enables padding the audio input and then dropping the extra steps after the pre-encoding for all the steps including the the first step. It may make the outputs of the downsampling slightly different from offline mode for some techniques like striding or sw_striding.",
)

parser.add_argument(
"--set_decoder",
choices=["ctc", "rnnt"],
default=None,
help="Selects the decoder for Hybrid ASR models which has both the CTC and RNNT decoder. Supported decoders are ['ctc', 'rnnt']",
)

args = parser.parse_args()
Expand All @@ -273,6 +287,11 @@ def main():
asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.asr_model)

logging.info(asr_model.encoder.streaming_cfg)
if args.set_decoder is not None:
if hasattr(asr_model, "cur_decoder"):
asr_model.change_decoding_strategy(decoder_type=args.set_decoder)
else:
raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.")

global autocast
if (
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
)

# setting the RNNT decoder as the default one
self.use_rnnt_decoder = True
self.cur_decoder = "rnnt"

def _setup_dataloader_from_config(self, config: Optional[Dict]):
dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config(
Expand Down Expand Up @@ -375,7 +375,7 @@ def change_vocabulary(

logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")

def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = None):
def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None):
"""
Changes decoding strategy used during RNNT decoding process.
Args:
Expand Down Expand Up @@ -446,7 +446,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str =
with open_dict(self.cfg.aux_ctc.decoding):
self.cfg.aux_ctc.decoding = decoding_cfg

self.use_rnnt_decoder = False
self.cur_decoder = "ctc"
logging.info(
f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}"
)
Expand Down
14 changes: 9 additions & 5 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
)

# setting the RNNT decoder as the default one
self.use_rnnt_decoder = True
self.cur_decoder = "rnnt"

# setting up interCTC loss (from InterCTCMixin)
self.setup_interctc(decoder_name='ctc_decoder', loss_name='ctc_loss', wer_name='ctc_wer')
Expand Down Expand Up @@ -125,7 +125,11 @@ def transcribe(
* A list of greedy transcript texts / Hypothesis
* An optional list of beam search transcript texts / Hypothesis / NBestHypothesis.
"""
if self.use_rnnt_decoder:
if self.cur_decoder not in ["ctc", "rnnt"]:
raise ValueError(
f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']"
)
if self.cur_decoder == "rnnt":
return super().transcribe(
paths2audio_files=paths2audio_files,
batch_size=batch_size,
Expand Down Expand Up @@ -307,7 +311,7 @@ def change_vocabulary(

logging.info(f"Changed the tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")

def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = None):
def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None):
"""
Changes decoding strategy used during RNNT decoding process.
Expand All @@ -319,7 +323,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str =
used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model.
"""
if decoder_type is None or decoder_type == 'rnnt':
self.use_rnnt_decoder = True
self.cur_decoder = "rnnt"
return super().change_decoding_strategy(decoding_cfg=decoding_cfg)

assert decoder_type == 'ctc' and hasattr(self, 'ctc_decoder')
Expand All @@ -346,7 +350,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str =
with open_dict(self.cfg.aux_ctc):
self.cfg.aux_ctc.decoding = decoding_cfg

self.use_rnnt_decoder = False
self.cur_decoder = "ctc"
logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}")

# PTL-specific methods
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def forward_for_export(

def streaming_post_process(self, rets, keep_all_outputs=True):
if len(rets) == 2:
return rets
return rets[0], rets[1], None, None, None

(encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) = rets

Expand Down
15 changes: 12 additions & 3 deletions nemo/collections/asr/parts/mixins/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,17 @@ def conformer_stream_step(
drop_extra_pre_encoded=drop_extra_pre_encoded,
)

if isinstance(self, asr_models.EncDecCTCModel):
log_probs = self.decoder(encoder_output=encoded)
if isinstance(self, asr_models.EncDecCTCModel) or (
isinstance(self, asr_models.EncDecHybridRNNTCTCModel) and self.cur_decoder == "ctc"
):
if hasattr(self, "ctc_decoder"):
decoding = self.ctc_decoding
decoder = self.ctc_decoder
else:
decoding = self.decoding
decoder = self.decoder

log_probs = decoder(encoder_output=encoded)
predictions_tensor = log_probs.argmax(dim=-1, keepdim=False)

# Concatenate the previous predictions with the current one to have the full predictions.
Expand All @@ -517,7 +526,7 @@ def conformer_stream_step(

# TODO: make decoding more efficient by avoiding the decoding process from the beginning
if return_transcription:
decoded_out = self.decoding.ctc_decoder_predictions_tensor(
decoded_out = decoding.ctc_decoder_predictions_tensor(
decoder_outputs=greedy_predictions_concat.unsqueeze(0),
decoder_lengths=encoded_len[preds_idx : preds_idx + 1],
return_hypotheses=False,
Expand Down
74 changes: 52 additions & 22 deletions nemo/collections/asr/parts/submodules/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,42 +126,72 @@ def __init__(
self._kernel_size = 3
self._ceil_mode = False

self._left_padding = (self._kernel_size - 1) // 2
self._right_padding = (self._kernel_size - 1) // 2
if self.is_causal:
self._left_padding = self._kernel_size - 1
self._right_padding = self._stride - 1
self._max_cache_len = subsampling_factor + 1
else:
self._left_padding = (self._kernel_size - 1) // 2
self._right_padding = (self._kernel_size - 1) // 2
self._max_cache_len = 0

# Layer 1
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
if self.is_causal:
layers.append(
CausalConv2D(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=None,
)
)
else:
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
)
)
)
in_channels = conv_channels
layers.append(activation)

for i in range(self._sampling_num - 1):
layers.extend(
[
torch.nn.Conv2d(
if self.is_causal:
layers.append(
CausalConv2D(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
padding=None,
groups=in_channels,
),
)
)
else:
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
),
]
out_channels=in_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
groups=in_channels,
)
)

layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
)
)
layers.append(activation)
in_channels = conv_channels
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,9 +1367,10 @@ def __iter__(self):
)

if self.buffer_idx == 0 and isinstance(self.streaming_cfg.shift_size, list):
shift_size = self.streaming_cfg.shift_size[0]
if self.pad_and_drop_preencoded:
shift_size = self.streaming_cfg.shift_size[1]
else:
shift_size = self.streaming_cfg.shift_size[0]
else:
shift_size = (
self.streaming_cfg.shift_size[1]
Expand All @@ -1394,9 +1395,10 @@ def __iter__(self):
# if there is not enough frames to be used as the pre-encoding cache, zeros would be added
zeros_pads = None
if self.buffer_idx == 0 and isinstance(self.streaming_cfg.pre_encode_cache_size, list):
cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[0]
if self.pad_and_drop_preencoded:
cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[1]
else:
cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[0]
cache_pre_encode = torch.zeros(
(audio_chunk.size(0), self.input_features, cache_pre_encode_num_frames),
device=audio_chunk.device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,4 +306,4 @@ def test_decoding_change(self, hybrid_asr_model):

assert hybrid_asr_model.ctc_decoding.preserve_alignments is True
assert hybrid_asr_model.ctc_decoding.compute_timestamps is True
assert hybrid_asr_model.use_rnnt_decoder is False
assert hybrid_asr_model.cur_decoder == "ctc"

0 comments on commit 2b2c9f5

Please sign in to comment.