Merge pull request espnet#4349 from pyf98/quantization
Add quantization in ESPnet2 for asr inference
sw005320 authored May 13, 2022
2 parents fffb344 + cd77501 commit afa8f8e
Showing 3 changed files with 94 additions and 2 deletions.
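
For context: this PR relies on PyTorch's post-training dynamic quantization, which swaps selected module types for counterparts that store weights in qint8 (or float16) and quantize activations on the fly, shrinking the model and typically speeding up CPU decoding. A minimal sketch of the mechanism on a toy model (the layer sizes are illustrative, not from ESPnet):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())
    qmodel = torch.quantization.quantize_dynamic(
        model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
    )
    # The Linear layer is replaced by a dynamically quantized version;
    # the weight-free ReLU is left untouched.
    print(qmodel)
    print(qmodel(torch.randn(1, 256)).shape)  # torch.Size([1, 256])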
6 changes: 4 additions & 2 deletions espnet/nets/pytorch_backend/rnn/encoders.py
@@ -69,7 +69,8 @@ def forward(self, xs_pad, ilens, prev_state=None):
                 ilens = torch.tensor(ilens)
             xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
             rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
-            rnn.flatten_parameters()
+            if self.training:
+                rnn.flatten_parameters()
             if prev_state is not None and rnn.bidirectional:
                 prev_state = reset_backward_rnn_state(prev_state)
             ys, states = rnn(
@@ -144,7 +145,8 @@ def forward(self, xs_pad, ilens, prev_state=None):
         if not isinstance(ilens, torch.Tensor):
             ilens = torch.tensor(ilens)
         xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
-        self.nbrnn.flatten_parameters()
+        if self.training:
+            self.nbrnn.flatten_parameters()
         if prev_state is not None and self.nbrnn.bidirectional:
             # We assume that when previous state is passed,
             # it means that we're streaming the input
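
The guard added above matters because dynamic quantization replaces nn.LSTM/nn.GRU with quantized counterparts that, as far as I can tell, do not expose flatten_parameters(); restricting the call to training mode keeps quantized inference from crashing. A minimal sketch of the same pattern on a toy encoder (names and sizes are illustrative):

    import torch

    class TinyEncoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

        def forward(self, x):
            if self.training:
                # Only safe on a float nn.LSTM; skipped at inference time,
                # where self.rnn may be a dynamically quantized module.
                self.rnn.flatten_parameters()
            return self.rnn(x)[0]

    enc = TinyEncoder().eval()
    qenc = torch.quantization.quantize_dynamic(enc, {torch.nn.LSTM}, dtype=torch.qint8)
    print(qenc(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 16])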
72 changes: 72 additions & 0 deletions espnet2/bin/asr_inference.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import argparse
+from distutils.version import LooseVersion
 import logging
 from pathlib import Path
 import sys
@@ -11,6 +12,7 @@

 import numpy as np
 import torch
+import torch.quantization
 from typeguard import check_argument_types
 from typeguard import check_return_type
 from typing import List
@@ -79,11 +81,27 @@ def __init__(
         nbest: int = 1,
         streaming: bool = False,
         enh_s2t_task: bool = False,
+        quantize_asr_model: bool = False,
+        quantize_lm: bool = False,
+        quantize_modules: List[str] = ["Linear"],
+        quantize_dtype: str = "qint8",
     ):
         assert check_argument_types()

         task = ASRTask if not enh_s2t_task else EnhS2TTask

+        if quantize_asr_model or quantize_lm:
+            if quantize_dtype == "float16" and torch.__version__ < LooseVersion(
+                "1.5.0"
+            ):
+                raise ValueError(
+                    "float16 dtype for dynamic quantization is not supported with "
+                    "torch version < 1.5.0. Switch to qint8 dtype instead."
+                )
+
+            quantize_modules = set([getattr(torch.nn, q) for q in quantize_modules])
+            quantize_dtype = getattr(torch, quantize_dtype)
+
         # 1. Build ASR model
         scorers = {}
         asr_model, asr_train_args = task.build_model_from_file(
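
The two getattr calls above resolve the user-facing strings into the objects quantize_dynamic expects; a quick standalone check of the same idiom:

    import torch

    quantize_modules = set(getattr(torch.nn, q) for q in ["Linear", "LSTM"])
    quantize_dtype = getattr(torch, "qint8")
    print(quantize_modules)  # {torch.nn.Linear, torch.nn.LSTM} as classes
    print(quantize_dtype)    # torch.qint8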
@@ -103,6 +121,13 @@ def __init__(
         )
         asr_model.to(dtype=getattr(torch, dtype)).eval()

+        if quantize_asr_model:
+            logging.info("Use quantized asr model for decoding.")
+
+            asr_model = torch.quantization.quantize_dynamic(
+                asr_model, qconfig_spec=quantize_modules, dtype=quantize_dtype
+            )
+
         decoder = asr_model.decoder

         ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
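
To see what this buys, one can compare the serialized size of a float module against its dynamically quantized version; a rough sketch (the 1024x1024 layer is a stand-in, not an ESPnet model, and exact savings vary):

    import io
    import torch

    def serialized_size(model: torch.nn.Module) -> int:
        # Bytes of the saved state_dict, a rough proxy for on-disk model size.
        buf = io.BytesIO()
        torch.save(model.state_dict(), buf)
        return buf.getbuffer().nbytes

    model = torch.nn.Linear(1024, 1024)
    qmodel = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    print(serialized_size(model), serialized_size(qmodel))  # qint8 weights: roughly 4x smaller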
@@ -118,6 +143,14 @@ def __init__(
             lm, lm_train_args = LMTask.build_model_from_file(
                 lm_train_config, lm_file, device
             )
+
+            if quantize_lm:
+                logging.info("Use quantized lm for decoding.")
+
+                lm = torch.quantization.quantize_dynamic(
+                    lm, qconfig_spec=quantize_modules, dtype=quantize_dtype
+                )
+
             scorers["lm"] = lm.lm

         # 3. Build ngram model
@@ -365,6 +398,10 @@ def inference(
     transducer_conf: Optional[dict],
     streaming: bool,
     enh_s2t_task: bool,
+    quantize_asr_model: bool,
+    quantize_lm: bool,
+    quantize_modules: List[str],
+    quantize_dtype: str,
 ):
     assert check_argument_types()
     if batch_size > 1:
@@ -409,6 +446,10 @@ def inference(
         nbest=nbest,
         streaming=streaming,
         enh_s2t_task=enh_s2t_task,
+        quantize_asr_model=quantize_asr_model,
+        quantize_lm=quantize_lm,
+        quantize_modules=quantize_modules,
+        quantize_dtype=quantize_dtype,
     )
     speech2text = Speech2Text.from_pretrained(
         model_tag=model_tag,
@@ -557,6 +598,37 @@ def get_parser():
         help="enhancement and asr joint model",
     )

+    group = parser.add_argument_group("Quantization related")
+    group.add_argument(
+        "--quantize_asr_model",
+        type=str2bool,
+        default=False,
+        help="Apply dynamic quantization to ASR model.",
+    )
+    group.add_argument(
+        "--quantize_lm",
+        type=str2bool,
+        default=False,
+        help="Apply dynamic quantization to LM.",
+    )
+    group.add_argument(
+        "--quantize_modules",
+        type=str,
+        nargs="*",
+        default=["Linear"],
+        help="""List of modules to be dynamically quantized.
+        E.g.: --quantize_modules Linear LSTM GRU.
+        Each specified module should be an attribute of 'torch.nn', e.g.:
+        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+    )
+    group.add_argument(
+        "--quantize_dtype",
+        type=str,
+        default="qint8",
+        choices=["float16", "qint8"],
+        help="Dtype for dynamic quantization.",
+    )
+
     group = parser.add_argument_group("Beam-search related")
     group.add_argument(
         "--batch_size",
18 changes: 18 additions & 0 deletions test/espnet2/bin/test_asr_inference.py
@@ -86,6 +86,24 @@ def test_Speech2Text(asr_config_file, lm_config_file):
         assert isinstance(hyp, Hypothesis)


+@pytest.mark.execution_timeout(5)
+def test_Speech2Text_quantized(asr_config_file, lm_config_file):
+    speech2text = Speech2Text(
+        asr_train_config=asr_config_file,
+        lm_train_config=lm_config_file,
+        beam_size=1,
+        quantize_asr_model=True,
+        quantize_lm=True,
+    )
+    speech = np.random.randn(100000)
+    results = speech2text(speech)
+    for text, token, token_int, hyp in results:
+        assert isinstance(text, str)
+        assert isinstance(token[0], str)
+        assert isinstance(token_int[0], int)
+        assert isinstance(hyp, Hypothesis)
+
+
 @pytest.fixture()
 def asr_config_file_streaming(tmp_path: Path, token_list):
     # Write default configuration file