[WIP] Pruned transducer stateless2 with streaming conformer for WenetSpeech #449

Closed
1 change: 0 additions & 1 deletion egs/wenetspeech/ASR/pruned_transducer_stateless2/conformer.py

This file was deleted.

1,553 changes: 1,553 additions & 0 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/conformer.py

Large diffs are not rendered by default.

78 changes: 75 additions & 3 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -48,11 +48,28 @@
        --beam 4 \
        --max-contexts 4 \
        --max-states 8

(4) decode in streaming mode (using greedy search as an example)
./pruned_transducer_stateless2/decode.py \
        --epoch 10 \
        --avg 2 \
        --simulate-streaming 1 \
        --causal-convolution 1 \
        --decode-chunk-size 16 \
        --left-context 64 \
        --exp-dir ./pruned_transducer_stateless2/exp \
        --lang-dir data/lang_char \
        --max-duration 600 \
        --decoding-method greedy_search \
        --beam 20.0 \
        --max-contexts 8 \
        --max-states 64
"""


import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -68,7 +85,7 @@
    greedy_search_batch,
    modified_beam_search,
)
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import (
    average_checkpoints,
Expand All @@ -80,9 +97,12 @@
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)

LOG_EPS = math.log(1e-10)
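# Note (added comment): LOG_EPS is used in decode_one_batch below as the fill
# value when right-padding features; log(1e-10) is about -23.0, which acts
# like near-silence in the log-mel domain.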


def get_parser():
    parser = argparse.ArgumentParser(
@@ -204,6 +224,30 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--simulate-streaming",
        type=str2bool,
        default=False,
        help="""Whether to simulate streaming during decoding; this is a good
        way to test a streaming model.
        """,
    )

    parser.add_argument(
        "--decode-chunk-size",
        type=int,
        default=16,
        help="The chunk size for decoding (in frames after subsampling)",
    )

    parser.add_argument(
        "--left-context",
        type=int,
        default=64,
        help="How much left context can be seen during decoding "
        "(in frames after subsampling)",
    )
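As a rough latency check (an editorial sketch, not part of the PR; it assumes the recipe's usual 10 ms fbank frame shift and a subsampling factor of 4, neither of which is stated in this diff), the defaults above correspond to 640 ms chunks with about 2.5 s of left context:

# Rough latency arithmetic for the streaming-decoding defaults above.
frame_shift_ms = 10        # assumed fbank frame shift
subsampling_factor = 4     # assumed Conformer subsampling factor
decode_chunk_size = 16     # frames after subsampling (default above)
left_context = 64          # frames after subsampling (default above)

chunk_ms = decode_chunk_size * subsampling_factor * frame_shift_ms
context_ms = left_context * subsampling_factor * frame_shift_ms
print(f"chunk: {chunk_ms} ms, left context: {context_ms} ms")
# -> chunk: 640 ms, left context: 2560 ms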

    add_model_arguments(parser)
    return parser


@@ -250,9 +294,27 @@ def decode_one_batch(
    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    feature_lens += params.left_context
+    feature = torch.nn.functional.pad(
+        feature,
+        pad=(0, 0, 0, params.left_context),
+        value=LOG_EPS,
+    )
+
+    if params.simulate_streaming:
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            states=[],
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )

    hyps = []

    if params.decoding_method == "fast_beam_search":
@@ -459,6 +521,11 @@ def main():
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.simulate_streaming:
        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
        params.suffix += f"-left-context-{params.left_context}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -482,6 +549,11 @@
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    if params.simulate_streaming:
        assert (
            params.causal_convolution
        ), "Decoding in streaming mode requires causal convolution"

    logging.info(params)

    logging.info("About to create model")
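For orientation, here is a self-contained sketch of the calling convention introduced above. The StubEncoder below only mimics the streaming_forward signature used in this diff; it is not the real streaming Conformer, and it ignores chunk_size/left_context, so treat it purely as an illustration of the interface:

import math
from typing import List, Tuple

import torch


class StubEncoder(torch.nn.Module):
    """Stands in for the streaming Conformer; crudely subsamples by 4."""

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: List[torch.Tensor],
        chunk_size: int,
        left_context: int,
        simulate_streaming: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        # The real encoder attends chunk by chunk; this stub only matches
        # the signature and the rough output shape.
        out = x[:, ::4, :]
        return out, x_lens // 4, states


LOG_EPS = math.log(1e-10)
decode_chunk_size, left_context = 16, 64  # frames after subsampling

encoder = StubEncoder()
feature = torch.randn(1, 1000, 80)  # (N, T, C): hypothetical fbank features
feature_lens = torch.tensor([1000])

# Right-pad by left_context with LOG_EPS, mirroring decode_one_batch above.
feature_lens = feature_lens + left_context
feature = torch.nn.functional.pad(
    feature, pad=(0, 0, 0, left_context), value=LOG_EPS
)

encoder_out, encoder_out_lens, _ = encoder.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    states=[],
    chunk_size=decode_chunk_size,
    left_context=left_context,
    simulate_streaming=True,
)
print(encoder_out.shape, encoder_out_lens)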
126 changes: 126 additions & 0 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_stream.py
@@ -0,0 +1,126 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple

import k2
import torch

from icefall.utils import AttributeDict


class DecodeStream(object):
    def __init__(
        self,
        params: AttributeDict,
        initial_states: List[torch.Tensor],
        decoding_graph: Optional[k2.Fsa] = None,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """
        Args:
          initial_states:
            Initial decode states of the model, e.g. the return value of
            `get_init_state` in conformer.py
          decoding_graph:
            Decoding graph used for decoding; may be a TrivialGraph or an HLG.
            Used only when decoding_method is fast_beam_search.
          device:
            The device to run this stream on.
        """
        if decoding_graph is not None:
            assert device == decoding_graph.device

        self.params = params
        self.LOG_EPS = math.log(1e-10)

        self.states = initial_states

        # A 2-D tensor containing the feature frames.
        self.features: torch.Tensor = None

        self.num_frames: int = 0
        # How many frames have been processed (before subsampling).
        # We only modify this value in `get_feature_frames`.
        self.num_processed_frames: int = 0

        self._done: bool = False

        # The transcript of the current utterance.
        self.ground_truth: str = ""

        # The decoding result (partial or final) of the current utterance.
        self.hyp: List = []

        # How many frames have been processed after subsampling (i.e., a
        # cumulative sum of the second return value of
        # encoder.streaming_forward).
        self.done_frames: int = 0

        self.pad_length = (
            params.right_context + 2
        ) * params.subsampling_factor + 3
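        # (Added note, not in the PR) The padding compensates for frames
        # consumed by the convolutional subsampling front end: with
        # subsampling_factor=4, roughly 4 * T' + 3 input frames are needed
        # to produce T' subsampled frames, so this keeps `right_context`
        # subsampled frames available beyond each chunk.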

        if params.decoding_method == "greedy_search":
            self.hyp = [params.blank_id] * params.context_size
        elif params.decoding_method == "fast_beam_search":
            # The rnnt_decoding_stream for fast_beam_search.
            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                k2.RnntDecodingStream(decoding_graph)
            )
        else:
            assert (
                False
            ), f"Decoding method {params.decoding_method} is not supported."

    @property
    def done(self) -> bool:
        """Return True if all the features are processed."""
        return self._done

    def set_features(
        self,
        features: torch.Tensor,
    ) -> None:
        """Set the features tensor of the current utterance."""
        assert features.dim() == 2, features.dim()
        self.features = torch.nn.functional.pad(
            features,
            (0, 0, 0, self.pad_length),
            mode="constant",
            value=self.LOG_EPS,
        )
        self.num_frames = self.features.size(0)

    def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
        """Consume chunk_size frames of features."""
        chunk_length = chunk_size + self.pad_length

        ret_length = min(
            self.num_frames - self.num_processed_frames, chunk_length
        )

        ret_features = self.features[
            self.num_processed_frames : self.num_processed_frames  # noqa
            + ret_length
        ]

        self.num_processed_frames += chunk_size
        if self.num_processed_frames >= self.num_frames:
            self._done = True

        return ret_features, ret_length
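A brief usage sketch for DecodeStream (editorial, not from the PR). The params fields mirror those referenced in __init__ above; the values are hypothetical stand-ins for what the recipe would normally derive from the model and lexicon:

import torch

from decode_stream import DecodeStream
from icefall.utils import AttributeDict

params = AttributeDict(
    {
        "right_context": 4,       # hypothetical value
        "subsampling_factor": 4,  # hypothetical value
        "decoding_method": "greedy_search",
        "blank_id": 0,            # hypothetical value
        "context_size": 2,        # hypothetical value
    }
)

stream = DecodeStream(params=params, initial_states=[])

# 3 seconds of hypothetical 80-dim fbank features (100 frames per second).
stream.set_features(torch.randn(300, 80))

# Consume the stream chunk by chunk, as a streaming decoder would.
chunk_size = 16 * params.subsampling_factor  # frames before subsampling
while not stream.done:
    chunk, valid_frames = stream.get_feature_frames(chunk_size)
    # ...feed `chunk` to encoder.streaming_forward(...) here...
    print(chunk.shape, valid_frames)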
15 changes: 14 additions & 1 deletion egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -46,7 +46,7 @@
from pathlib import Path

import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
@@ -107,6 +107,16 @@ def get_parser():
"2 means tri-gram",
)

parser.add_argument(
"--streaming-model",
type=str2bool,
default=False,
help="""Whether to export a streaming model, if the models in exp-dir
are streaming model, this should be True.
""",
)

add_model_arguments(parser)
return parser


@@ -128,6 +138,9 @@ def main():
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    if params.streaming_model:
        assert params.causal_convolution

    logging.info(params)

    logging.info("About to create model")
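To sanity-check an exported model, one can inspect the saved file. This sketch assumes export.py writes a dict with a "model" key holding the state_dict (the convention of icefall's export scripts) and uses a hypothetical path:

import torch

# Hypothetical export location; adjust to the actual --exp-dir.
ckpt = torch.load(
    "pruned_transducer_stateless2/exp/pretrained.pt", map_location="cpu"
)
state_dict = ckpt["model"]  # assumed key, following icefall's export scripts

# List a few parameter names and shapes as a quick inspection.
for name in list(state_dict)[:5]:
    print(name, tuple(state_dict[name].shape))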
1 change: 0 additions & 1 deletion egs/wenetspeech/ASR/pruned_transducer_stateless2/joiner.py

This file was deleted.

69 changes: 69 additions & 0 deletions egs/wenetspeech/ASR/pruned_transducer_stateless2/joiner.py
@@ -0,0 +1,69 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from scaling import ScaledLinear


class Joiner(nn.Module):
    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
        self.output_linear = ScaledLinear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
          project_input:
            If true, apply input projections encoder_proj and decoder_proj.
            If this is false, it is the user's responsibility to do this
            manually.
        Returns:
          Return a tensor of shape (N, T, s_range, C).
        """

        assert encoder_out.ndim == decoder_out.ndim
        assert encoder_out.ndim in (2, 4)
        assert encoder_out.shape == decoder_out.shape

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
                decoder_out
            )
        else:
            logit = encoder_out + decoder_out

        logit = self.output_linear(torch.tanh(logit))

        return logit
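A toy shape-check for the new Joiner (editorial sketch; the dimensions are hypothetical). Note that the asserts above imply encoder_dim must equal decoder_dim when project_input is True, since the unprojected inputs must have identical shapes:

import torch

from joiner import Joiner

# Hypothetical dimensions; the recipe normally sets these from the CLI.
joiner = Joiner(
    encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500
)

N, T, s_range = 2, 50, 5
encoder_out = torch.randn(N, T, s_range, 512)
decoder_out = torch.randn(N, T, s_range, 512)

logit = joiner(encoder_out, decoder_out)  # projections applied internally
print(logit.shape)  # torch.Size([2, 50, 5, 500])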