Support exporting to ONNX format #501

Merged (10 commits) on Aug 3, 2022
@@ -22,8 +22,76 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
popd

log "Test exporting to ONNX format"

./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--onnx 1

log "Export to torchscript model"
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit 1

./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit-trace 1

ls -lh $repo/exp/*.onnx
ls -lh $repo/exp/*.pt

log "Decode with ONNX models"

./pruned_transducer_stateless3/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx

./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
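
The exported encoder.onnx, decoder.onnx, and joiner.onnx can also be loaded directly with onnxruntime. A minimal sketch, assuming the encoder graph takes 80-dim fbank features through inputs named "x" and "x_lens" (inspect the graph, or onnx_pretrained.py above, for the authoritative names):

    # Sketch: run the exported ONNX encoder with onnxruntime.
    # The input names "x"/"x_lens" and the 80-dim features are assumptions;
    # print session.get_inputs() to see the real ones.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(
        "exp/encoder.onnx", providers=["CPUExecutionProvider"]
    )
    print([inp.name for inp in session.get_inputs()])

    x = np.random.randn(1, 1000, 80).astype(np.float32)  # (N, T, feature_dim)
    x_lens = np.array([1000], dtype=np.int64)
    outputs = session.run(None, {"x": x, "x_lens": x_lens})
    print([out.shape for out in outputs])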

log "Decode with models exported by torch.jit.trace()"

./pruned_transducer_stateless3/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

log "Decode with models exported by torch.jit.script()"

./pruned_transducer_stateless3/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav


for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"

@@ -35,7 +35,7 @@ on:

jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
91 changes: 62 additions & 29 deletions egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -155,7 +155,8 @@ def forward(
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1

assert x.size(0) == lengths.max().item()
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()

src_key_padding_mask = make_pad_mask(lengths)

@@ -787,6 +788,14 @@ def __init__(
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
if torch.jit.is_tracing():
# 10k frames correspond to ~100k ms, i.e., ~100 seconds of audio;
# we assume the maximum input will not have more than 10k frames.
#
# TODO(fangjun): Use torch.jit.script() for this module
max_len = 10000

self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
@@ -992,7 +1001,7 @@ def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding.

Args:
x: Input tensor (batch, head, time1, 2*time1-1).
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
@@ -1006,20 +1015,32 @@ def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
(batch_size, num_heads, time1, n) = x.shape

time2 = time1 + left_context
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"

# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
if not torch.jit.is_tracing():
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"

if torch.jit.is_tracing():
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols

x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, time1, time2)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
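
The torch.gather() branch above selects exactly the elements that the as_strided() trick selects; it exists because the ONNX exporter cannot handle stride tricks. A quick standalone equivalence check (a sketch, not part of the PR; left_context is fixed to 0 for brevity):

    import torch

    def rel_shift_strided(x: torch.Tensor) -> torch.Tensor:
        # The original as_strided implementation, left_context == 0.
        batch_size, num_heads, time1, n = x.shape
        return x.as_strided(
            (batch_size, num_heads, time1, time1),
            (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
            storage_offset=x.stride(3) * (time1 - 1),
        )

    def rel_shift_gather(x: torch.Tensor) -> torch.Tensor:
        # The tracing-friendly implementation from this PR, left_context == 0.
        batch_size, num_heads, time1, n = x.shape
        rows = torch.arange(start=time1 - 1, end=-1, step=-1)
        cols = torch.arange(time1)
        indexes = rows.repeat(batch_size * num_heads).unsqueeze(-1) + cols
        x = x.reshape(-1, n)
        x = torch.gather(x, dim=1, index=indexes)
        return x.reshape(batch_size, num_heads, time1, time1)

    x = torch.randn(2, 4, 5, 2 * 5 - 1)  # (batch, head, time1, 2*time1-1)
    assert torch.equal(rel_shift_strided(x), rel_shift_gather(x))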

def multi_head_attention_forward(
self,
@@ -1090,13 +1111,15 @@ def multi_head_attention_forward(
"""

tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
if not torch.jit.is_tracing():
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
if not torch.jit.is_tracing():
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"

scaling = float(head_dim) ** -0.5

@@ -1209,7 +1232,7 @@ def multi_head_attention_forward(

src_len = k.size(0)

if key_padding_mask is not None:
if key_padding_mask is not None and not torch.jit.is_tracing():
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
@@ -1220,7 +1243,9 @@ def multi_head_attention_forward(
q = q.transpose(0, 1) # (batch, time1, head, d_k)

pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
if not torch.jit.is_tracing():
assert pos_emb_bsz in (1, bsz) # actually it is 1

p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
# (batch, 2*time1-1, head, d_k) --> (batch, head, d_k, 2*time1-1)
p = p.permute(0, 2, 3, 1)
@@ -1255,11 +1280,12 @@ def multi_head_attention_forward(
bsz * num_heads, tgt_len, -1
)

assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if not torch.jit.is_tracing():
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
@@ -1318,7 +1344,14 @@ def multi_head_attention_forward(
)

attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]

if not torch.jit.is_tracing():
assert list(attn_output.size()) == [
bsz * num_heads,
tgt_len,
head_dim,
]

attn_output = (
attn_output.transpose(0, 1)
.contiguous()
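
The common thread in the conformer.py changes: torch.jit.trace() records tensor operations only, and Python-level values such as lengths.max().item() are frozen to whatever the example input happened to be, so the shape asserts are now skipped while tracing. A toy illustration of the guard (not from the PR):

    import torch

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if not torch.jit.is_tracing():
                assert x.size(0) == 4  # checked in eager mode only
            return x + 1

    m = torch.jit.trace(M(), torch.zeros(4, 2))
    print(m(torch.zeros(8, 2)).shape)  # the traced module accepts other sizes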
16 changes: 13 additions & 3 deletions egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -77,7 +79,9 @@ def __init__(
# It is to support torch script
self.conv = nn.Identity()

def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
def forward(
self, y: torch.Tensor, need_pad: Union[bool, torch.Tensor] = True
) -> torch.Tensor:
"""
Args:
y:
@@ -88,18 +92,24 @@ def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
if isinstance(need_pad, torch.Tensor):
# This is for torch.jit.trace(), which cannot handle the case
# when the input argument is not a tensor.
need_pad = bool(need_pad)

y = y.to(torch.int64)
embedding_out = self.embedding(y)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
if need_pad:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
if not torch.jit.is_tracing():
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
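
torch.jit.trace() only accepts tensors (or tuples of tensors) as example inputs, which is why need_pad now also accepts a tensor and is converted back to a Python bool up front. A hedged sketch of how the decoder might be traced for inference (the constructor arguments are assumptions; the recipe's export.py is authoritative):

    import torch
    from decoder import Decoder  # the module modified in this PR

    decoder = Decoder(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)
    y = torch.zeros(1, 2, dtype=torch.int64)  # (N, context_size)
    need_pad = torch.tensor([False])  # a plain False would be rejected by trace()
    traced = torch.jit.trace(decoder, (y, need_pad))
    traced.save("decoder_jit_trace.pt")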
8 changes: 4 additions & 4 deletions egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
@@ -52,10 +52,10 @@ def forward(
Returns:
Return a tensor of shape (N, T, s_range, C).
"""

assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape
if not torch.jit.is_tracing():
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape

if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
7 changes: 4 additions & 3 deletions egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -152,7 +152,8 @@ def __init__(
self.register_buffer("eps", torch.tensor(eps).log().detach())

def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
if not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+ self.eps.exp()
@@ -423,7 +424,7 @@ def __init__(
self.max_abs = max_abs

def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x
else:
return ActivationBalancerFunction.apply(
@@ -472,7 +473,7 @@ def forward(self, x: Tensor) -> Tensor:
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
else:
return DoubleSwishFunction.apply(x)
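
For context on the scaling.py change: DoubleSwish approximates Swish(Swish(x)) with the cheaper closed form x * sigmoid(x - 1), and the scripting/tracing path now uses that expression directly instead of the custom autograd function. A quick numeric check of the approximation (a sketch using only the formula quoted in the docstring above):

    import torch

    def swish(x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)

    x = torch.linspace(-8.0, 8.0, steps=1001)
    approx = x * torch.sigmoid(x - 1.0)  # DoubleSwish's closed form
    exact = swish(swish(x))              # the function it approximates
    print((approx - exact).abs().max())  # stays below ~0.1 on this range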