From 06e2a7a16a06cda326035d03c84734d18c852cd3 Mon Sep 17 00:00:00 2001 From: Yifan Peng Date: Mon, 9 May 2022 23:10:14 -0400 Subject: [PATCH] apply black --- espnet2/bin/asr_inference.py | 10 +++------- test/espnet2/bin/test_asr_inference.py | 6 +++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/espnet2/bin/asr_inference.py b/espnet2/bin/asr_inference.py index 810f8714a9f..b6f86e6aa43 100755 --- a/espnet2/bin/asr_inference.py +++ b/espnet2/bin/asr_inference.py @@ -123,11 +123,9 @@ def __init__( if quantize_asr_model: logging.info("Use quantized asr model for decoding.") - + asr_model = torch.quantization.quantize_dynamic( - asr_model, - qconfig_spec=quantize_modules, - dtype=quantize_dtype + asr_model, qconfig_spec=quantize_modules, dtype=quantize_dtype ) decoder = asr_model.decoder @@ -150,9 +148,7 @@ def __init__( logging.info("Use quantized lm for decoding.") lm = torch.quantization.quantize_dynamic( - lm, - qconfig_spec=quantize_modules, - dtype=quantize_dtype + lm, qconfig_spec=quantize_modules, dtype=quantize_dtype ) scorers["lm"] = lm.lm diff --git a/test/espnet2/bin/test_asr_inference.py b/test/espnet2/bin/test_asr_inference.py index 1946e6ce475..c64a0c774eb 100644 --- a/test/espnet2/bin/test_asr_inference.py +++ b/test/espnet2/bin/test_asr_inference.py @@ -87,10 +87,10 @@ def test_Speech2Text(asr_config_file, lm_config_file): @pytest.mark.execution_timeout(5) def test_Speech2Text_quantized(asr_config_file, lm_config_file): speech2text = Speech2Text( - asr_train_config=asr_config_file, - lm_train_config=lm_config_file, + asr_train_config=asr_config_file, + lm_train_config=lm_config_file, beam_size=1, - quantize_asr_model=True, + quantize_asr_model=True, quantize_lm=True, ) speech = np.random.randn(100000)