diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index 3617bc369c..2deab04b96 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -22,8 +22,76 @@ ls -lh $repo/test_wavs/*.wav pushd $repo/exp ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt +ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt popd +log "Test exporting to ONNX format" + +./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --onnx 1 + +log "Export to torchscript model" +./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit-trace 1 + +ls -lh $repo/exp/*.onnx +ls -lh $repo/exp/*.pt + +log "Decode with ONNX models" + +./pruned_transducer_stateless3/onnx_check.py \ + --jit-filename $repo/exp/cpu_jit.pt \ + --onnx-encoder-filename $repo/exp/encoder.onnx \ + --onnx-decoder-filename $repo/exp/decoder.onnx \ + --onnx-joiner-filename $repo/exp/joiner.onnx + +./pruned_transducer_stateless3/onnx_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decode with models exported by torch.jit.trace()" + +./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_script.pt \ + --decoder-model-filename $repo/exp/decoder_jit_script.pt \ + --joiner-model-filename $repo/exp/joiner_jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do log "Greedy search with --max-sym-per-frame $sym" diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index 47976fc2c5..3b6e11a31f 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -35,7 +35,7 @@ on: jobs: run_librispeech_pruned_transducer_stateless3_2022_05_13: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index fb81238385..e95360d1dc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -155,7 +155,8 @@ def forward( # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 lengths = (((x_lens - 1) >> 1) - 1) >> 1 - assert x.size(0) == lengths.max().item() + if not torch.jit.is_tracing(): + assert x.size(0) == lengths.max().item() src_key_padding_mask = make_pad_mask(lengths) @@ -787,6 +788,14 @@ def __init__( ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() + if torch.jit.is_tracing(): + # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., + # It assumes that the maximum input won't have more than + # 10k frames. + # + # TODO(fangjun): Use torch.jit.script() for this module + max_len = 10000 + self.d_model = d_model self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None @@ -992,7 +1001,7 @@ def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: """Compute relative positional encoding. Args: - x: Input tensor (batch, head, time1, 2*time1-1). + x: Input tensor (batch, head, time1, 2*time1-1+left_context). time1 means the length of query vector. left_context (int): left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, @@ -1006,20 +1015,32 @@ def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: (batch_size, num_heads, time1, n) = x.shape time2 = time1 + left_context - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" - - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) + if not torch.jit.is_tracing(): + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" + + if torch.jit.is_tracing(): + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(time2) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, time1, time2) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) def multi_head_attention_forward( self, @@ -1090,13 +1111,15 @@ def multi_head_attention_forward( """ tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + if not torch.jit.is_tracing(): + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" + if not torch.jit.is_tracing(): + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" scaling = float(head_dim) ** -0.5 @@ -1209,7 +1232,7 @@ def multi_head_attention_forward( src_len = k.size(0) - if key_padding_mask is not None: + if key_padding_mask is not None and not torch.jit.is_tracing(): assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz ) @@ -1220,7 +1243,9 @@ def multi_head_attention_forward( q = q.transpose(0, 1) # (batch, time1, head, d_k) pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 + if not torch.jit.is_tracing(): + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) p = p.permute(0, 2, 3, 1) @@ -1255,11 +1280,12 @@ def multi_head_attention_forward( bsz * num_heads, tgt_len, -1 ) - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] + if not torch.jit.is_tracing(): + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] if attn_mask is not None: if attn_mask.dtype == torch.bool: @@ -1318,7 +1344,14 @@ def multi_head_attention_forward( ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + + if not torch.jit.is_tracing(): + assert list(attn_output.size()) == [ + bsz * num_heads, + tgt_len, + head_dim, + ] + attn_output = ( attn_output.transpose(0, 1) .contiguous() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 1ddfce0345..bd0df5d497 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -77,7 +79,9 @@ def __init__( # It is to support torch script self.conv = nn.Identity() - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + def forward( + self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True + ) -> torch.Tensor: """ Args: y: @@ -88,18 +92,24 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: Returns: Return a tensor of shape (N, U, decoder_dim). """ + if isinstance(need_pad, torch.Tensor): + # This is for torch.jit.trace(), which cannot handle the case + # when the input argument is not a tensor. + need_pad = bool(need_pad) + y = y.to(torch.int64) embedding_out = self.embedding(y) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: + if need_pad: embedding_out = F.pad( embedding_out, pad=(self.context_size - 1, 0) ) else: # During inference time, there is no need to do extra padding # as we only need one output - assert embedding_out.size(-1) == self.context_size + if not torch.jit.is_tracing(): + assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) embedding_out = F.relu(embedding_out) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index b916addf0f..b2d6ed0f27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -52,10 +52,10 @@ def forward( Returns: Return a tensor of shape (N, T, s_range, C). """ - - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape + if not torch.jit.is_tracing(): + assert encoder_out.ndim == decoder_out.ndim + assert encoder_out.ndim in (2, 4) + assert encoder_out.shape == decoder_out.shape if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 26a8cca440..2b44dc649e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -152,7 +152,8 @@ def __init__( self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels + if not torch.jit.is_tracing(): + assert x.shape[self.channel_dim] == self.num_channels scales = ( torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp() @@ -423,7 +424,7 @@ def __init__( self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x else: return ActivationBalancerFunction.apply( @@ -472,7 +473,7 @@ def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), that we approximate closely with x * sigmoid(x-1). """ - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torch.jit.is_tracing(): return x * torch.sigmoid(x - 1.0) else: return DoubleSwishFunction.apply(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index 53ea306ff9..1485c6d6a4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -19,14 +19,67 @@ # This script converts several saved checkpoints # to a single one using model averaging. """ + Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +It will also generate 3 other files: `encoder_jit_script.pt`, +`decoder_jit_script.pt`, and `joiner_jit_script.pt`. + +(2) Export to torchscript model using torch.jit.trace() + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +It will generates 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + + +(3) Export to ONNX format + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +It will generate the following three files in the given `exp_dir`. +Check `onnx_check.py` for how to use them. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + + +(4) Export `model.state_dict()` + ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 20 \ --avg 10 -It will generate a file exp_dir/pretrained.pt +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. To use the generated file with `pruned_transducer_stateless3/decode.py`, you can do: @@ -42,6 +95,20 @@ --max-duration 600 \ --decoding-method greedy_search \ --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp """ import argparse @@ -50,6 +117,8 @@ import sentencepiece as spm import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -114,6 +183,42 @@ def get_parser(): type=str2bool, default=False, help="""True to save a model after applying torch.jit.script. + It will generate 4 files: + - encoder_jit_script.pt + - decoder_jit_script.pt + - joiner_jit_script.pt + - cpu_jit.pt (which combines the above 3 files) + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit is ignored and it exports the model + to onnx format. Three files will be generated: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + + Check ./onnx_check.py and ./onnx_pretrained.py for how to use them. """, ) @@ -139,6 +244,275 @@ def get_parser(): return parser +def export_encoder_model_jit_script( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.script() + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(encoder_model) + script_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_script( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.script() + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(decoder_model) + script_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_script( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + """ + script_model = torch.jit.script(joiner_model) + script_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + traced_model = torch.jit.trace(encoder_model, (x, x_lens)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + # encoder_model = torch.jit.script(encoder_model) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + warmup = 1.0 + torch.onnx.export( + encoder_model, + (x, x_lens, warmup), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "warmup"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported model has two inputs: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and has one output: + + - joiner_out: a tensor of shape (N, vocab_size) + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + project_input = True + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (encoder_out, decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out", "decoder_out", "project_input"], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) @@ -165,7 +539,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_transducer_model(params, enable_giga=False) model.to(device) @@ -185,7 +559,9 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: @@ -196,14 +572,39 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) model.to("cpu") model.eval() - - if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) + + if params.onnx is True: + opset_version = 11 + logging.info("Exporting to onnx format") + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + elif params.jit is True: + logging.info("Using torch.jit.script()") # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not @@ -214,8 +615,29 @@ def main(): filename = params.exp_dir / "cpu_jit.pt" model.save(str(filename)) logging.info(f"Saved to {filename}") + + # Also export encoder/decoder/joiner separately + encoder_filename = params.exp_dir / "encoder_jit_script.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_script.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_script.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + elif params.jit_trace is True: + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) else: - logging.info("Not using torch.jit.script") + logging.info("Not using torchscript") # Save it using a format so that it can be loaded # by :func:`load_checkpoint` filename = params.exp_dir / "pretrained.pt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py new file mode 100755 index 0000000000..162f8c7dbe --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +or + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless3/jit_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav + +or + +./pruned_transducer_stateless3/jit_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_script.pt \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_script.pt \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_script.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py new file mode 100755 index 0000000000..3da31b7ceb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script checks that exported onnx models produce the same output +with the given torchscript model for the same input. +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + +ort.set_default_logger_severity(3) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit-filename", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx-encoder-filename", + required=True, + type=str, + help="Path to the onnx encoder model", + ) + + parser.add_argument( + "--onnx-decoder-filename", + required=True, + type=str, + help="Path to the onnx decoder model", + ) + + parser.add_argument( + "--onnx-joiner-filename", + required=True, + type=str, + help="Path to the onnx joiner model", + ) + + return parser + + +def test_encoder( + model: torch.jit.ScriptModule, + encoder_session: ort.InferenceSession, +): + encoder_inputs = encoder_session.get_inputs() + assert encoder_inputs[0].name == "x" + assert encoder_inputs[1].name == "x_lens" + assert encoder_inputs[0].shape == ["N", "T", 80] + assert encoder_inputs[1].shape == ["N"] + + for N in [1, 5]: + for T in [12, 25]: + print("N, T", N, T) + x = torch.rand(N, T, 80, dtype=torch.float32) + x_lens = torch.randint(low=10, high=T + 1, size=(N,)) + x_lens[0] = T + + encoder_inputs = { + "x": x.numpy(), + "x_lens": x_lens.numpy(), + } + encoder_out, encoder_out_lens = encoder_session.run( + ["encoder_out", "encoder_out_lens"], + encoder_inputs, + ) + + torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) + + encoder_out = torch.from_numpy(encoder_out) + assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( + (encoder_out - torch_encoder_out).abs().max() + ) + + +def test_decoder( + model: torch.jit.ScriptModule, + decoder_session: ort.InferenceSession, +): + decoder_inputs = decoder_session.get_inputs() + assert decoder_inputs[0].name == "y" + assert decoder_inputs[0].shape == ["N", 2] + for N in [1, 5, 10]: + y = torch.randint(low=1, high=500, size=(10, 2)) + + decoder_inputs = {"y": y.numpy()} + decoder_out = decoder_session.run( + ["decoder_out"], + decoder_inputs, + )[0] + decoder_out = torch.from_numpy(decoder_out) + + torch_decoder_out = model.decoder(y, need_pad=False) + assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( + (decoder_out - torch_decoder_out).abs().max() + ) + + +def test_joiner( + model: torch.jit.ScriptModule, + joiner_session: ort.InferenceSession, +): + joiner_inputs = joiner_session.get_inputs() + assert joiner_inputs[0].name == "encoder_out" + assert joiner_inputs[0].shape == ["N", 512] + + assert joiner_inputs[1].name == "decoder_out" + assert joiner_inputs[1].shape == ["N", 512] + + for N in [1, 5, 10]: + encoder_out = torch.rand(N, 512) + decoder_out = torch.rand(N, 512) + + joiner_inputs = { + "encoder_out": encoder_out.numpy(), + "decoder_out": decoder_out.numpy(), + } + joiner_out = joiner_session.run(["logit"], joiner_inputs)[0] + joiner_out = torch.from_numpy(joiner_out) + + torch_joiner_out = model.joiner( + encoder_out, + decoder_out, + project_input=True, + ) + assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( + (joiner_out - torch_joiner_out).abs().max() + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + model = torch.jit.load(args.jit_filename) + + options = ort.SessionOptions() + options.inter_op_num_threads = 1 + options.intra_op_num_threads = 1 + + logging.info("Test encoder") + encoder_session = ort.InferenceSession( + args.onnx_encoder_filename, + sess_options=options, + ) + test_encoder(model, encoder_session) + + logging.info("Test decoder") + decoder_session = ort.InferenceSession( + args.onnx_decoder_filename, + sess_options=options, + ) + test_decoder(model, decoder_session) + + logging.info("Test joiner") + joiner_session = ort.InferenceSession( + args.onnx_joiner_filename, + sess_options=options, + ) + test_joiner(model, joiner_session) + logging.info("Finished checking ONNX models") + + +if __name__ == "__main__": + torch.manual_seed(20220727) + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py new file mode 100755 index 0000000000..ebfae9d5f0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./pruned_transducer_stateless3/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \ + --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \ + --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import numpy as np +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: ort.InferenceSession, + joiner: ort.InferenceSession, + encoder_out: np.ndarray, + encoder_out_lens: np.ndarray, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + encoder_out = torch.from_numpy(encoder_out) + encoder_out_lens = torch.from_numpy(encoder_out_lens) + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input_nodes = decoder.get_inputs() + decoder_output_nodes = decoder.get_outputs() + + joiner_input_nodes = joiner.get_inputs() + joiner_output_nodes = joiner.get_outputs() + + decoder_input = torch.tensor( + hyps, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: current_encoder_out.numpy(), + joiner_input_nodes[1].name: decoder_out, + }, + )[0] + logits = torch.from_numpy(logits) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + dtype=torch.int64, + ) + decoder_out = decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0].squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=session_opts, + ) + + decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=session_opts, + ) + + joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=session_opts, + ) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) + + encoder_input_nodes = encoder.get_inputs() + encoder_out_nodes = encoder.get_outputs() + encoder_out, encoder_out_lens = encoder.run( + [encoder_out_nodes[0].name, encoder_out_nodes[1].name], + { + encoder_input_nodes[0].name: features.numpy(), + encoder_input_nodes[1].name: feature_lengths.numpy(), + }, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 8b0389bc9c..c15d65ded4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -15,7 +15,16 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Usage: +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py new file mode 100644 index 0000000000..c810e36e63 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -0,0 +1,189 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file provides functions to convert `ScaledLinear`, `ScaledConv1d`, +and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`, +and `nn.Conv2d`. + +The scaled version are required only in the training time. It simplifies our +life by converting them their non-scaled version during inference time. +""" + +import copy +import re + +import torch +import torch.nn as nn +from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear + + +def _get_weight(self: torch.nn.Linear): + return self.weight + + +def _get_bias(self: torch.nn.Linear): + return self.bias + + +def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: + """Convert an instance of ScaledLinear to nn.Linear. + + Args: + scaled_linear: + The layer to be converted. + Returns: + Return a linear layer. It satisfies: + + scaled_linear(x) == linear(x) + + for any given input tensor `x`. + """ + assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear) + + # if not hasattr(torch.nn.Linear, "get_weight"): + # torch.nn.Linear.get_weight = _get_weight + # torch.nn.Linear.get_bias = _get_bias + + weight = scaled_linear.get_weight() + bias = scaled_linear.get_bias() + has_bias = bias is not None + + linear = torch.nn.Linear( + in_features=scaled_linear.in_features, + out_features=scaled_linear.out_features, + bias=True, # otherwise, it throws errors when converting to PNNX format. + device=weight.device, + ) + linear.weight.data.copy_(weight) + + if has_bias: + linear.bias.data.copy_(bias) + else: + linear.bias.data.zero_() + + return linear + + +def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d: + """Convert an instance of ScaledConv1d to nn.Conv1d. + + Args: + scaled_conv1d: + The layer to be converted. + Returns: + Return an instance of nn.Conv1d that has the same `forward()` behavior + of the given `scaled_conv1d`. + """ + assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d) + + weight = scaled_conv1d.get_weight() + bias = scaled_conv1d.get_bias() + has_bias = bias is not None + + conv1d = nn.Conv1d( + in_channels=scaled_conv1d.in_channels, + out_channels=scaled_conv1d.out_channels, + kernel_size=scaled_conv1d.kernel_size, + stride=scaled_conv1d.stride, + padding=scaled_conv1d.padding, + dilation=scaled_conv1d.dilation, + groups=scaled_conv1d.groups, + bias=scaled_conv1d.bias is not None, + padding_mode=scaled_conv1d.padding_mode, + ) + + conv1d.weight.data.copy_(weight) + if has_bias: + conv1d.bias.data.copy_(bias) + + return conv1d + + +def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d: + """Convert an instance of ScaledConv2d to nn.Conv2d. + + Args: + scaled_conv2d: + The layer to be converted. + Returns: + Return an instance of nn.Conv2d that has the same `forward()` behavior + of the given `scaled_conv2d`. + """ + assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d) + + weight = scaled_conv2d.get_weight() + bias = scaled_conv2d.get_bias() + has_bias = bias is not None + + conv2d = nn.Conv2d( + in_channels=scaled_conv2d.in_channels, + out_channels=scaled_conv2d.out_channels, + kernel_size=scaled_conv2d.kernel_size, + stride=scaled_conv2d.stride, + padding=scaled_conv2d.padding, + dilation=scaled_conv2d.dilation, + groups=scaled_conv2d.groups, + bias=scaled_conv2d.bias is not None, + padding_mode=scaled_conv2d.padding_mode, + ) + + conv2d.weight.data.copy_(weight) + if has_bias: + conv2d.bias.data.copy_(bias) + + return conv2d + + +def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): + """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` + in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, + and `nn.Conv2d`. + + Args: + model: + The model to be converted. + inplace: + If True, the input model is modified inplace. + If False, the input model is copied and we modify the copied version. + Return: + Return a model without scaled layers. + """ + if not inplace: + model = copy.deepcopy(model) + + excluded_patterns = r"self_attn\.(in|out)_proj" + p = re.compile(excluded_patterns) + + d = {} + for name, m in model.named_modules(): + if isinstance(m, ScaledLinear): + if p.search(name) is not None: + continue + d[name] = scaled_linear_to_linear(m) + elif isinstance(m, ScaledConv1d): + d[name] = scaled_conv1d_to_conv1d(m) + elif isinstance(m, ScaledConv2d): + d[name] = scaled_conv2d_to_conv2d(m) + + for k, v in d.items(): + if "." in k: + parent, child = k.rsplit(".", maxsplit=1) + setattr(model.get_submodule(parent), child, v) + else: + setattr(model, k, v) + + return model diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py new file mode 100644 index 0000000000..34a9c27f71 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless3/test_scaling_converter.py +""" + +import copy + +import torch +from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear +from scaling_converter import ( + convert_scaled_to_non_scaled, + scaled_conv1d_to_conv1d, + scaled_conv2d_to_conv2d, + scaled_linear_to_linear, +) +from train import get_params, get_transducer_model + + +def get_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + + params.dynamic_chunk_training = False + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = False + + model = get_transducer_model(params, enable_giga=False) + return model + + +def test_scaled_linear_to_linear(): + N = 5 + in_features = 10 + out_features = 20 + for bias in [True, False]: + scaled_linear = ScaledLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + ) + linear = scaled_linear_to_linear(scaled_linear) + x = torch.rand(N, in_features) + + y1 = scaled_linear(x) + y2 = linear(x) + assert torch.allclose(y1, y2) + + jit_scaled_linear = torch.jit.script(scaled_linear) + jit_linear = torch.jit.script(linear) + + y3 = jit_scaled_linear(x) + y4 = jit_linear(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_scaled_conv1d_to_conv1d(): + in_channels = 3 + for bias in [True, False]: + scaled_conv1d = ScaledConv1d( + in_channels, + 6, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + conv1d = scaled_conv1d_to_conv1d(scaled_conv1d) + + x = torch.rand(20, in_channels, 10) + y1 = scaled_conv1d(x) + y2 = conv1d(x) + assert torch.allclose(y1, y2) + + jit_scaled_conv1d = torch.jit.script(scaled_conv1d) + jit_conv1d = torch.jit.script(conv1d) + + y3 = jit_scaled_conv1d(x) + y4 = jit_conv1d(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_scaled_conv2d_to_conv2d(): + in_channels = 1 + for bias in [True, False]: + scaled_conv2d = ScaledConv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + padding=1, + bias=bias, + ) + + conv2d = scaled_conv2d_to_conv2d(scaled_conv2d) + + x = torch.rand(20, in_channels, 10, 20) + y1 = scaled_conv2d(x) + y2 = conv2d(x) + assert torch.allclose(y1, y2) + + jit_scaled_conv2d = torch.jit.script(scaled_conv2d) + jit_conv2d = torch.jit.script(conv2d) + + y3 = jit_scaled_conv2d(x) + y4 = jit_conv2d(x) + + assert torch.allclose(y3, y4) + assert torch.allclose(y1, y4) + + +def test_convert_scaled_to_non_scaled(): + for inplace in [False, True]: + model = get_model() + model.eval() + + orig_model = copy.deepcopy(model) + + converted_model = convert_scaled_to_non_scaled(model, inplace=inplace) + + model = orig_model + + # test encoder + N = 2 + T = 100 + vocab_size = model.decoder.vocab_size + + x = torch.randn(N, T, 80, dtype=torch.float32) + x_lens = torch.full((N,), x.size(1)) + + e1, e1_lens = model.encoder(x, x_lens) + e2, e2_lens = converted_model.encoder(x, x_lens) + + assert torch.all(torch.eq(e1_lens, e2_lens)) + assert torch.allclose(e1, e2), (e1 - e2).abs().max() + + # test decoder + U = 50 + y = torch.randint(low=1, high=vocab_size - 1, size=(N, U)) + + d1 = model.decoder(y) + d2 = model.decoder(y) + + assert torch.allclose(d1, d2) + + # test simple projection + lm1 = model.simple_lm_proj(d1) + am1 = model.simple_am_proj(e1) + + lm2 = converted_model.simple_lm_proj(d2) + am2 = converted_model.simple_am_proj(e2) + + assert torch.allclose(lm1, lm2) + assert torch.allclose(am1, am2) + + # test joiner + e = torch.rand(2, 3, 4, 512) + d = torch.rand(2, 3, 4, 512) + + j1 = model.joiner(e, d) + j2 = converted_model.joiner(e, d) + assert torch.allclose(j1, j2) + + +@torch.no_grad() +def main(): + test_scaled_linear_to_linear() + test_scaled_conv1d_to_conv1d() + test_scaled_conv2d_to_conv2d() + test_convert_scaled_to_non_scaled() + + +if __name__ == "__main__": + torch.manual_seed(20220730) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 371bf21d91..ff9b8d8087 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -436,13 +436,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner -def get_transducer_model(params: AttributeDict) -> nn.Module: +def get_transducer_model( + params: AttributeDict, + enable_giga: bool = True, +) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) - decoder_giga = get_decoder_model(params) - joiner_giga = get_joiner_model(params) + if enable_giga: + logging.info("Use giga") + decoder_giga = get_decoder_model(params) + joiner_giga = get_joiner_model(params) + else: + logging.info("Disable giga") + decoder_giga = None + joiner_giga = None model = Transducer( encoder=encoder, diff --git a/requirements-ci.txt b/requirements-ci.txt index fc17b123a8..48769d61a2 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -20,3 +20,6 @@ sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 multi_quantization + +onnx +onnxruntime diff --git a/requirements.txt b/requirements.txt index 90b1dac698..25b5529f0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ sentencepiece>=0.1.96 tensorboard typeguard multi_quantization +onnx +onnxruntime