diff --git a/espnet2/bin/asr_inference.py b/espnet2/bin/asr_inference.py index 810f8714a9f..b6f86e6aa43 100755 --- a/espnet2/bin/asr_inference.py +++ b/espnet2/bin/asr_inference.py @@ -123,11 +123,9 @@ def __init__( if quantize_asr_model: logging.info("Use quantized asr model for decoding.") - + asr_model = torch.quantization.quantize_dynamic( - asr_model, - qconfig_spec=quantize_modules, - dtype=quantize_dtype + asr_model, qconfig_spec=quantize_modules, dtype=quantize_dtype ) decoder = asr_model.decoder @@ -150,9 +148,7 @@ def __init__( logging.info("Use quantized lm for decoding.") lm = torch.quantization.quantize_dynamic( - lm, - qconfig_spec=quantize_modules, - dtype=quantize_dtype + lm, qconfig_spec=quantize_modules, dtype=quantize_dtype ) scorers["lm"] = lm.lm diff --git a/test/espnet2/bin/test_asr_inference.py b/test/espnet2/bin/test_asr_inference.py index 1946e6ce475..c64a0c774eb 100644 --- a/test/espnet2/bin/test_asr_inference.py +++ b/test/espnet2/bin/test_asr_inference.py @@ -87,10 +87,10 @@ def test_Speech2Text(asr_config_file, lm_config_file): @pytest.mark.execution_timeout(5) def test_Speech2Text_quantized(asr_config_file, lm_config_file): speech2text = Speech2Text( - asr_train_config=asr_config_file, - lm_train_config=lm_config_file, + asr_train_config=asr_config_file, + lm_train_config=lm_config_file, beam_size=1, - quantize_asr_model=True, + quantize_asr_model=True, quantize_lm=True, ) speech = np.random.randn(100000)