add quantization to asr_inference
pyf98 committed May 10, 2022
1 parent beb3360 commit acb24c8
Showing 1 changed file with 76 additions and 0 deletions.
espnet2/bin/asr_inference.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import argparse
from distutils.version import LooseVersion
import logging
from pathlib import Path
import sys
@@ -11,6 +12,7 @@

import numpy as np
import torch
import torch.quantization
from typeguard import check_argument_types
from typeguard import check_return_type
from typing import List
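The new LooseVersion import backs the torch version guard added below. As a side note, distutils is deprecated (PEP 632); an equivalent guard using the third-party packaging library would look like this sketch, which is not part of the commit:

import torch
from packaging.version import Version

# Equivalent guard without the deprecated distutils.version.LooseVersion.
if Version(torch.__version__) < Version("1.5.0"):
    raise ValueError(
        "float16 dtype for dynamic quantization requires torch >= 1.5.0."
    )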
@@ -79,11 +81,27 @@ def __init__(
        nbest: int = 1,
        streaming: bool = False,
        enh_s2t_task: bool = False,
        quantize_asr_model: bool = False,
        quantize_lm: bool = False,
        quantize_modules: List[str] = ["Linear"],
        quantize_dtype: str = "qint8",
    ):
        assert check_argument_types()

        task = ASRTask if not enh_s2t_task else EnhS2TTask

        if quantize_asr_model or quantize_lm:
            if quantize_dtype == "float16" and LooseVersion(
                torch.__version__
            ) < LooseVersion("1.5.0"):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with torch "
                    "version < 1.5.0. Switch to qint8 dtype instead."
                )

            quantize_modules = {getattr(torch.nn, q) for q in quantize_modules}
            quantize_dtype = getattr(torch, quantize_dtype)

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = task.build_model_from_file(
@@ -103,6 +121,15 @@ def __init__(
        )
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        if quantize_asr_model:
            logging.info("Use quantized asr model for decoding.")

            asr_model = torch.quantization.quantize_dynamic(
                asr_model,
                qconfig_spec=quantize_modules,
                dtype=quantize_dtype,
            )

        decoder = asr_model.decoder

        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
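For reference, here is a minimal standalone sketch of what the quantize_dynamic call above does. The toy model and layer sizes are hypothetical; the name-to-attribute resolution and the call itself mirror the commit:

import torch
import torch.quantization

# Resolve user-supplied strings to torch attributes, as in __init__ above.
quantize_modules = {getattr(torch.nn, q) for q in ["Linear"]}
quantize_dtype = getattr(torch, "qint8")

# Hypothetical stand-in for the full ASR model.
model = torch.nn.Sequential(
    torch.nn.Linear(80, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 32),
).eval()

# Replace every matching module type with its dynamically quantized variant.
qmodel = torch.quantization.quantize_dynamic(
    model, qconfig_spec=quantize_modules, dtype=quantize_dtype
)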
@@ -118,6 +145,16 @@ def __init__(
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )

            if quantize_lm:
                logging.info("Use quantized lm for decoding.")

                lm = torch.quantization.quantize_dynamic(
                    lm,
                    qconfig_spec=quantize_modules,
                    dtype=quantize_dtype,
                )

            scorers["lm"] = lm.lm

        # 3. Build ngram model
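Either quantize_dynamic call swaps matching modules for their quantized counterparts. One quick way to sanity-check the result, continuing the sketch above (the size helper is illustrative, not from this commit):

# Printing the quantized model shows the swapped-in modules: the nn.Linear
# layers now appear as DynamicQuantizedLinear.
print(qmodel)

# Rough size check: qint8 weights shrink the serialized state dict.
import io

def serialized_mb(m: torch.nn.Module) -> float:
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes / 1e6

print(f"float32: {serialized_mb(model):.2f} MB, "
      f"quantized: {serialized_mb(qmodel):.2f} MB")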
@@ -365,6 +402,10 @@ def inference(
    transducer_conf: Optional[dict],
    streaming: bool,
    enh_s2t_task: bool,
    quantize_asr_model: bool,
    quantize_lm: bool,
    quantize_modules: List[str],
    quantize_dtype: str,
):
    assert check_argument_types()
    if batch_size > 1:
@@ -409,6 +450,10 @@ def inference(
        nbest=nbest,
        streaming=streaming,
        enh_s2t_task=enh_s2t_task,
        quantize_asr_model=quantize_asr_model,
        quantize_lm=quantize_lm,
        quantize_modules=quantize_modules,
        quantize_dtype=quantize_dtype,
    )
    speech2text = Speech2Text.from_pretrained(
        model_tag=model_tag,
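As a usage sketch, the same options can be passed when constructing Speech2Text directly in Python; the config and model paths here are placeholders, not files from this commit:

speech2text = Speech2Text(
    asr_train_config="exp/asr_train/config.yaml",      # placeholder path
    asr_model_file="exp/asr_train/valid.acc.ave.pth",  # placeholder path
    quantize_asr_model=True,
    quantize_modules=["Linear"],
    quantize_dtype="qint8",
)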
@@ -557,6 +602,37 @@ def get_parser():
help="enhancement and asr joint model",
)

group = parser.add_argument_group("Quantization related")
group.add_argument(
"--quantize_asr_model",
type=str2bool,
default=False,
help="Apply dynamic quantization to ASR model.",
)
group.add_argument(
"--quantize_lm",
type=str2bool,
default=False,
help="Apply dynamic quantization to LM.",
)
group.add_argument(
"--quantize_modules",
type=str,
nargs="*",
default=["Linear"],
help="""List of modules to be dynamically quantized.
E.g.: --quantize_modules=[Linear,LSTM,GRU].
Each specified module should be an attribute of 'torch.nn', e.g.:
torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
)
group.add_argument(
"--quantize_dtype",
type=str,
default="qint8",
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
)

group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
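Because --quantize_modules is declared with nargs="*", module names are passed space-separated on the command line. A minimal stand-in parser (not ESPnet's full get_parser()) shows the parsing behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--quantize_modules", type=str, nargs="*", default=["Linear"])

args = parser.parse_args("--quantize_modules Linear LSTM GRU".split())
print(args.quantize_modules)  # ['Linear', 'LSTM', 'GRU']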
