diff --git a/espnet2/bin/asr_inference.py b/espnet2/bin/asr_inference.py
index fc6d75cb488..810f8714a9f 100755
--- a/espnet2/bin/asr_inference.py
+++ b/espnet2/bin/asr_inference.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import argparse
+from distutils.version import LooseVersion
 import logging
 from pathlib import Path
 import sys
@@ -11,6 +12,7 @@
 import numpy as np
 import torch
+import torch.quantization
 from typeguard import check_argument_types
 from typeguard import check_return_type
 from typing import List
 
@@ -79,11 +81,27 @@ def __init__(
         nbest: int = 1,
         streaming: bool = False,
         enh_s2t_task: bool = False,
+        quantize_asr_model: bool = False,
+        quantize_lm: bool = False,
+        quantize_modules: List[str] = ["Linear"],
+        quantize_dtype: str = "qint8",
     ):
         assert check_argument_types()
 
         task = ASRTask if not enh_s2t_task else EnhS2TTask
 
+        if quantize_asr_model or quantize_lm:
+            if quantize_dtype == "float16" and torch.__version__ < LooseVersion(
+                "1.5.0"
+            ):
+                raise ValueError(
+                    "float16 dtype for dynamic quantization is not supported with torch "
+                    "version < 1.5.0. Switch to qint8 dtype instead."
+                )
+
+            quantize_modules = set([getattr(torch.nn, q) for q in quantize_modules])
+            quantize_dtype = getattr(torch, quantize_dtype)
+
         # 1. Build ASR model
         scorers = {}
         asr_model, asr_train_args = task.build_model_from_file(
@@ -103,6 +121,15 @@ def __init__(
         )
         asr_model.to(dtype=getattr(torch, dtype)).eval()
 
+        if quantize_asr_model:
+            logging.info("Use quantized asr model for decoding.")
+
+            asr_model = torch.quantization.quantize_dynamic(
+                asr_model,
+                qconfig_spec=quantize_modules,
+                dtype=quantize_dtype,
+            )
+
         decoder = asr_model.decoder
         ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
 
@@ -118,6 +145,16 @@ def __init__(
             lm, lm_train_args = LMTask.build_model_from_file(
                 lm_train_config, lm_file, device
            )
+
+            if quantize_lm:
+                logging.info("Use quantized lm for decoding.")
+
+                lm = torch.quantization.quantize_dynamic(
+                    lm,
+                    qconfig_spec=quantize_modules,
+                    dtype=quantize_dtype,
+                )
+
             scorers["lm"] = lm.lm
 
         # 3. Build ngram model
@@ -365,6 +402,10 @@ def inference(
     transducer_conf: Optional[dict],
     streaming: bool,
     enh_s2t_task: bool,
+    quantize_asr_model: bool,
+    quantize_lm: bool,
+    quantize_modules: List[str],
+    quantize_dtype: str,
 ):
     assert check_argument_types()
     if batch_size > 1:
@@ -409,6 +450,10 @@ def inference(
         nbest=nbest,
         streaming=streaming,
         enh_s2t_task=enh_s2t_task,
+        quantize_asr_model=quantize_asr_model,
+        quantize_lm=quantize_lm,
+        quantize_modules=quantize_modules,
+        quantize_dtype=quantize_dtype,
     )
     speech2text = Speech2Text.from_pretrained(
         model_tag=model_tag,
@@ -557,6 +602,37 @@ def get_parser():
         help="enhancement and asr joint model",
     )
 
+    group = parser.add_argument_group("Quantization related")
+    group.add_argument(
+        "--quantize_asr_model",
+        type=str2bool,
+        default=False,
+        help="Apply dynamic quantization to ASR model.",
+    )
+    group.add_argument(
+        "--quantize_lm",
+        type=str2bool,
+        default=False,
+        help="Apply dynamic quantization to LM.",
+    )
+    group.add_argument(
+        "--quantize_modules",
+        type=str,
+        nargs="*",
+        default=["Linear"],
+        help="""List of modules to be dynamically quantized.
+        E.g.: --quantize_modules Linear LSTM GRU.
+        Each specified module should be an attribute of 'torch.nn', e.g.:
+        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+    )
+    group.add_argument(
+        "--quantize_dtype",
+        type=str,
+        default="qint8",
+        choices=["float16", "qint8"],
+        help="Dtype for dynamic quantization.",
+    )
+
     group = parser.add_argument_group("Beam-search related")
     group.add_argument(
         "--batch_size",
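For reference, here is a minimal, self-contained sketch of the dynamic-quantization call this patch wires into `Speech2Text.__init__`. The toy `nn.Sequential` model, its layer sizes, and the variable names are made up for illustration; only `torch.quantization.quantize_dynamic` with `qconfig_spec` and `dtype` reflects the actual API used above.

```python
import torch
import torch.nn as nn
import torch.quantization

# Toy stand-in for an ASR model (hypothetical; not part of ESPnet).
model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 32)).eval()

# Mirror the patch: resolve module-name strings to torch.nn classes and the
# dtype string to a torch dtype, as __init__ does with quantize_modules
# and quantize_dtype.
quantize_modules = {getattr(nn, q) for q in ["Linear"]}
quantize_dtype = getattr(torch, "qint8")

# quantize_dynamic swaps each matching module for a dynamically quantized
# equivalent: weights are stored in reduced precision (int8 here) and
# activations are quantized on the fly at inference time.
qmodel = torch.quantization.quantize_dynamic(
    model, qconfig_spec=quantize_modules, dtype=quantize_dtype
)

with torch.no_grad():
    out = qmodel(torch.randn(1, 80))  # inference runs as usual
```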
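On the command-line side, the new options would be combined along these lines (all other required arguments elided): `--quantize_asr_model true --quantize_lm true --quantize_modules Linear LSTM --quantize_dtype qint8`. Since `--quantize_modules` is declared with `nargs="*"`, module names are passed space-separated and each must name a class under `torch.nn`; `float16` is only accepted with torch >= 1.5.0 because of the version guard in `__init__`.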