Merging onnx models (#518)
* add export function of onnx-all-in-one to export.py

* add onnx_check script for all-in-one onnx model

* minor fix

* remove unused arguments

* add onnx-all-in-one test

* fix style

* fix style

* fix requirements

* fix input/output names

* fix installing onnx_graphsurgeon

* fix installing onnx_graphsurgeon

* revert to previous requirements.txt

* minor fix
EmreOzkose authored Aug 4, 2022
1 parent a4dd273 commit 7157f62
Showing 5 changed files with 324 additions and 0 deletions.
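In short: the ONNX export path in export.py now also merges the exported encoder, decoder, and joiner into a single exp/all_in_one.onnx, and the new onnx_check_all_in_one.py script verifies that file against the torchscript model.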
@@ -60,6 +60,10 @@ log "Decode with ONNX models"
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx

./pruned_transducer_stateless3/onnx_check_all_in_one.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx

./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
33 changes: 33 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -115,6 +115,7 @@
import logging
from pathlib import Path

import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
@@ -512,6 +513,30 @@ def export_joiner_model_onnx(
logging.info(f"Saved to {joiner_filename}")


def export_all_in_one_onnx(
    encoder_filename: str,
    decoder_filename: str,
    joiner_filename: str,
    all_in_one_filename: str,
):
    """Merge the exported encoder, decoder, and joiner ONNX models into
    a single file. Every name is prefixed (encoder/, decoder/, joiner/)
    so the three graphs remain disjoint in the combined model.
    """
encoder_onnx = onnx.load(encoder_filename)
decoder_onnx = onnx.load(decoder_filename)
joiner_onnx = onnx.load(joiner_filename)

encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")

    # An empty io_map wires no outputs to inputs: the combined model
    # simply carries the three independent subgraphs side by side.
    combined_model = onnx.compose.merge_models(
        encoder_onnx, decoder_onnx, io_map={}
    )
combined_model = onnx.compose.merge_models(
combined_model, joiner_onnx, io_map={}
)
onnx.save(combined_model, all_in_one_filename)
logging.info(f"Saved to {all_in_one_filename}")


@torch.no_grad()
def main():
args = get_parser().parse_args()
@@ -603,6 +628,14 @@ def main():
joiner_filename,
opset_version=opset_version,
)

        # Additionally merge the three exported models into a single file.
        all_in_one_filename = params.exp_dir / "all_in_one.onnx"
export_all_in_one_onnx(
encoder_filename,
decoder_filename,
joiner_filename,
all_in_one_filename,
)
elif params.jit is True:
logging.info("Using torch.jit.script()")
# We won't use the forward() method of the model in C++, so just ignore
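For context, the merge pattern used by export_all_in_one_onnx above reduces to the following minimal sketch (the file names a.onnx and b.onnx are hypothetical):

import onnx

# Load two independently exported models (hypothetical files).
a = onnx.load("a.onnx")
b = onnx.load("b.onnx")

# add_prefix renames every node, initializer, input, and output,
# so the two graphs cannot collide once merged.
a = onnx.compose.add_prefix(a, prefix="a/")
b = onnx.compose.add_prefix(b, prefix="b/")

# An empty io_map connects no outputs to inputs; the result simply
# carries both disjoint graphs in a single ModelProto.
merged = onnx.compose.merge_models(a, b, io_map={})
onnx.save(merged, "merged.onnx")

# Inputs keep their prefixed names, e.g. "a/x" and "b/y".
for t in merged.graph.input:
    print(t.name)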
284 changes: 284 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check_all_in_one.py
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script checks that the exported ONNX models produce the same output
as the given torchscript model for the same input.
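
Usage (paths as in the librispeech test script):

  ./pruned_transducer_stateless3/onnx_check_all_in_one.py \
    --jit-filename /path/to/cpu_jit.pt \
    --onnx-all-in-one-filename /path/to/all_in_one.onnx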
"""

import argparse
import logging
import os

import onnx
import onnx_graphsurgeon as gs
import onnxruntime as ort
import torch

# Show only onnxruntime errors (severity 3) and above.
ort.set_default_logger_severity(3)


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--jit-filename",
required=True,
type=str,
help="Path to the torchscript model",
)

parser.add_argument(
"--onnx-all-in-one-filename",
required=True,
type=str,
help="Path to the onnx all in one model",
)

return parser


def test_encoder(
model: torch.jit.ScriptModule,
encoder_session: ort.InferenceSession,
):
encoder_inputs = encoder_session.get_inputs()
    # Dynamic axes are exported with symbolic dimension names,
    # hence the string entries in the expected shapes.
    assert encoder_inputs[0].shape == ["N", "T", 80]
    assert encoder_inputs[1].shape == ["N"]
encoder_input_names = [i.name for i in encoder_inputs]
encoder_output_names = [i.name for i in encoder_session.get_outputs()]

for N in [1, 5]:
for T in [12, 25]:
print("N, T", N, T)
x = torch.rand(N, T, 80, dtype=torch.float32)
x_lens = torch.randint(low=10, high=T + 1, size=(N,))
x_lens[0] = T

encoder_inputs = {
encoder_input_names[0]: x.numpy(),
encoder_input_names[1]: x_lens.numpy(),
}
encoder_out, encoder_out_lens = encoder_session.run(
[encoder_output_names[1], encoder_output_names[0]],
encoder_inputs,
)

torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens)

encoder_out = torch.from_numpy(encoder_out)
assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), (
(encoder_out - torch_encoder_out).abs().max()
)


def test_decoder(
model: torch.jit.ScriptModule,
decoder_session: ort.InferenceSession,
):
decoder_inputs = decoder_session.get_inputs()
assert decoder_inputs[0].shape == ["N", 2]
decoder_input_names = [i.name for i in decoder_inputs]
decoder_output_names = [i.name for i in decoder_session.get_outputs()]

    for N in [1, 5, 10]:
        y = torch.randint(low=1, high=500, size=(N, 2))

decoder_inputs = {decoder_input_names[0]: y.numpy()}
decoder_out = decoder_session.run(
[decoder_output_names[0]],
decoder_inputs,
)[0]
decoder_out = torch.from_numpy(decoder_out)

torch_decoder_out = model.decoder(y, need_pad=False)
assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
(decoder_out - torch_decoder_out).abs().max()
)


def test_joiner(
model: torch.jit.ScriptModule,
joiner_session: ort.InferenceSession,
):
joiner_inputs = joiner_session.get_inputs()
assert joiner_inputs[0].shape == ["N", 512]
assert joiner_inputs[1].shape == ["N", 512]
joiner_input_names = [i.name for i in joiner_inputs]
joiner_output_names = [i.name for i in joiner_session.get_outputs()]

for N in [1, 5, 10]:
encoder_out = torch.rand(N, 512)
decoder_out = torch.rand(N, 512)

joiner_inputs = {
joiner_input_names[0]: encoder_out.numpy(),
joiner_input_names[1]: decoder_out.numpy(),
}
joiner_out = joiner_session.run(
[joiner_output_names[0]], joiner_inputs
)[0]
joiner_out = torch.from_numpy(joiner_out)

torch_joiner_out = model.joiner(
encoder_out,
decoder_out,
project_input=True,
)
assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), (
(joiner_out - torch_joiner_out).abs().max()
)


def extract_sub_model(
    onnx_graph: onnx.ModelProto,
    input_op_names: list,
    output_op_names: list,
    non_verbose: bool = False,
):
    """Extract the subgraph spanning the given input and output tensor
    names from onnx_graph, using onnx-graphsurgeon.
    """
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort()

# Extraction of input OP and output OP
graph_node_inputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_input in graph_nodes.inputs
if graph_nodes_input.name in input_op_names
]
graph_node_outputs = [
graph_nodes
for graph_nodes in graph.nodes
for graph_nodes_output in graph_nodes.outputs
if graph_nodes_output.name in output_op_names
]

# Init graph INPUT/OUTPUT
graph.inputs.clear()
graph.outputs.clear()

# Update graph INPUT/OUTPUT
graph.inputs = [
graph_node_input
for graph_node in graph_node_inputs
for graph_node_input in graph_node.inputs
if graph_node_input.shape
]
graph.outputs = [
graph_node_output
for graph_node in graph_node_outputs
for graph_node_output in graph_node.outputs
]

# Cleanup
graph.cleanup().toposort()

# Shape Estimation
extracted_graph = None
try:
extracted_graph = onnx.shape_inference.infer_shapes(
gs.export_onnx(graph)
)
    except Exception:
        extracted_graph = gs.export_onnx(graph)
        if not non_verbose:
            print(
                "WARNING: shape inference failed for the extracted graph. "
                "Open the .onnx file to verify the input/output shapes "
                "manually."
            )
return extracted_graph


def extract_encoder(onnx_model: onnx.ModelProto):
encoder_ = extract_sub_model(
onnx_model,
["encoder/x", "encoder/x_lens"],
["encoder/encoder_out", "encoder/encoder_out_lens"],
False,
)
onnx.save(encoder_, "tmp_encoder.onnx")
onnx.checker.check_model(encoder_)
    sess = ort.InferenceSession("tmp_encoder.onnx")
os.remove("tmp_encoder.onnx")
return sess


def extract_decoder(onnx_model: onnx.ModelProto):
decoder_ = extract_sub_model(
onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
)
onnx.save(decoder_, "tmp_decoder.onnx")
onnx.checker.check_model(decoder_)
    sess = ort.InferenceSession("tmp_decoder.onnx")
os.remove("tmp_decoder.onnx")
return sess


def extract_joiner(onnx_model: onnx.ModelProto):
joiner_ = extract_sub_model(
onnx_model,
["joiner/encoder_out", "joiner/decoder_out"],
["joiner/logit"],
False,
)
onnx.save(joiner_, "tmp_joiner.onnx")
onnx.checker.check_model(joiner_)
    sess = ort.InferenceSession("tmp_joiner.onnx")
os.remove("tmp_joiner.onnx")
return sess


@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))

model = torch.jit.load(args.jit_filename)
onnx_model = onnx.load(args.onnx_all_in_one_filename)

    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1
    # NOTE: these options are not passed to the sessions created in the
    # extract_* helpers above.

logging.info("Test encoder")
encoder_session = extract_encoder(onnx_model)
test_encoder(model, encoder_session)

logging.info("Test decoder")
decoder_session = extract_decoder(onnx_model)
test_decoder(model, decoder_session)

logging.info("Test joiner")
joiner_session = extract_joiner(onnx_model)
test_joiner(model, joiner_session)
logging.info("Finished checking ONNX models")


if __name__ == "__main__":
torch.manual_seed(20220727)
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)

logging.basicConfig(format=formatter, level=logging.INFO)
main()
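Side note: ort.InferenceSession also accepts a serialized model in bytes, so the tmp_*.onnx round trips in the extract_* helpers could be avoided with something like

    sess = ort.InferenceSession(extracted.SerializeToString())

where extracted is the ModelProto returned by extract_sub_model.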
1 change: 1 addition & 0 deletions requirements-ci.txt
@@ -23,3 +23,4 @@ multi_quantization

onnx
onnxruntime
onnx_graphsurgeon -i https://pypi.ngc.nvidia.com
2 changes: 2 additions & 0 deletions requirements.txt
@@ -6,3 +6,5 @@ typeguard
multi_quantization
onnx
onnxruntime
--extra-index-url https://pypi.ngc.nvidia.com
onnx_graphsurgeon
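Note: onnx_graphsurgeon is distributed via NVIDIA's package index (pypi.ngc.nvidia.com), hence the extra index URL.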
