Zipformer Onnx FP16 #1671

Merged · 5 commits · Jun 27, 2024
32 changes: 29 additions & 3 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -48,7 +48,8 @@
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
- --left-context-frames 128
+ --left-context-frames 128 \
+ --fp16 True

The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context`,
@@ -73,6 +74,7 @@
import torch
import torch.nn as nn
from decoder import Decoder
+from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@@ -154,6 +156,13 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)

parser.add_argument(
"--fp16",
type=str2bool,
default=False,
help="Whether to export models in fp16",
)

add_model_arguments(parser)

return parser
@@ -479,7 +488,6 @@ def build_inputs_outputs(tensors, i):

    add_meta_data(filename=encoder_filename, meta_data=meta_data)

-
def export_decoder_model_onnx(
    decoder_model: OnnxDecoder,
    decoder_filename: str,
@@ -747,11 +755,29 @@ def main():
    )
    logging.info(f"Exported joiner to {joiner_filename}")

+    if params.fp16:
+        logging.info("Generate fp16 models")
+
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+        onnx.save(encoder_fp16, encoder_filename_fp16)
+
+        decoder = onnx.load(decoder_filename)
+        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
+        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
+        onnx.save(decoder_fp16, decoder_filename_fp16)
+
+        joiner = onnx.load(joiner_filename)
+        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
+        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
+        onnx.save(joiner_fp16, joiner_filename_fp16)
+
    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=encoder_filename,
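Because the conversion passes `keep_io_types=True`, the fp16 models keep float32 inputs and outputs: existing feature-extraction and decoding code can feed them unchanged, and only the weights and internal activations are stored in float16. A minimal sketch of loading one of the exported fp16 models with onnxruntime (the file name is illustrative):

```python
# A minimal sketch: load an exported fp16 encoder with onnxruntime
# and inspect its I/O. The file name below is illustrative.
import onnxruntime as ort

session = ort.InferenceSession(
    "encoder-epoch-99-avg-1.fp16.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# keep_io_types=True preserves float32 I/O, so these still report
# tensor(float) rather than tensor(float16).
for inp in session.get_inputs():
    print(inp.name, inp.type, inp.shape)
```

fp16 models mainly pay off on GPU execution providers; the CPU provider typically casts back to float32 kernels internally, so for CPU deployment the int8 models the script also exports are usually the better choice.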
30 changes: 28 additions & 2 deletions egs/librispeech/ASR/zipformer/export-onnx.py
@@ -48,8 +48,8 @@
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"

--left-context-frames "64,128,256,-1" \
--fp16 True
It will generate the following 3 files inside $repo/exp:

- encoder-epoch-99-avg-1.onnx
@@ -70,6 +70,7 @@
import torch
import torch.nn as nn
from decoder import Decoder
+from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@@ -151,6 +152,13 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)

parser.add_argument(
"--fp16",
type=str2bool,
default=False,
help="Whether to export models in fp16",
)

add_model_arguments(parser)

return parser
@@ -584,6 +592,24 @@ def main():
    )
    logging.info(f"Exported joiner to {joiner_filename}")

+    if params.fp16:
+        logging.info("Generate fp16 models")
+
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+        onnx.save(encoder_fp16, encoder_filename_fp16)
+
+        decoder = onnx.load(decoder_filename)
+        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
+        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
+        onnx.save(decoder_fp16, decoder_filename_fp16)
+
+        joiner = onnx.load(joiner_filename)
+        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
+        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
+        onnx.save(joiner_fp16, joiner_filename_fp16)
+
    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

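The same `--fp16` flag is wired into the non-streaming exporter above. A quick sanity check on any converted pair is to run the fp32 and fp16 models on identical input and compare outputs; the sketch below builds random feeds from the session's input metadata rather than hard-coding input names. The fixed size used for dynamic axes is an assumption and may need adjusting (e.g. the encoder needs enough frames for its subsampling):

```python
# Hedged sketch: compare fp32 vs fp16 outputs on random input.
import numpy as np
import onnxruntime as ort


def random_feeds(session, dyn=100):
    """Build random inputs, filling dynamic axes with a fixed size."""
    feeds = {}
    for inp in session.get_inputs():
        shape = [d if isinstance(d, int) else dyn for d in inp.shape]
        if "int64" in inp.type:
            feeds[inp.name] = np.full(shape, dyn, dtype=np.int64)
        else:
            feeds[inp.name] = np.random.randn(*shape).astype(np.float32)
    return feeds


fp32_sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
fp16_sess = ort.InferenceSession("encoder-epoch-99-avg-1.fp16.onnx")

feeds = random_feeds(fp32_sess)
# keep_io_types=True means the fp16 model accepts the same float32 feeds.
for a, b in zip(fp32_sess.run(None, feeds), fp16_sess.run(None, feeds)):
    print("max abs diff:", np.abs(a - b).max())
```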
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,6 +12,7 @@
onnx>=1.15.0
onnxruntime>=1.16.3
onnxoptimizer
onnxsim
+onnxconverter_common

# style check session:
black==22.3.0
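The new dependency is published on PyPI as onnxconverter-common (imported as `onnxconverter_common`). Since it is only needed when `--fp16` is requested, a script could also guard the import and fail with a clearer message; a small sketch, not part of this PR:

```python
# Optional guard (a sketch, not part of the PR): fail early with a
# clear message when the fp16 converter is missing.
try:
    from onnxconverter_common import float16  # PyPI package: onnxconverter-common
except ImportError as e:
    raise SystemExit(
        "onnxconverter-common is required for --fp16 export; "
        "install it with: pip install onnxconverter-common"
    ) from e
```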