forked from microsoft/onnxruntime-extensions
-
Notifications
You must be signed in to change notification settings - Fork 0
/
compose.py
146 lines (126 loc) · 5.89 KB
/
compose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import io
import onnx
import torch
import numpy
from torch.onnx import TrainingMode, export as _export
from ._ortapi2 import OrtPyFunction
from .pnp import ONNXModelUtils, ProcessingModule
def _is_numpy_object(x):
return isinstance(x, (numpy.ndarray, numpy.generic))
def _is_numpy_string_type(arr):
return arr.dtype.kind in {'U', 'S'}
def _export_f(model, args=None,
              export_params=True, verbose=False,
              input_names=None, output_names=None,
              operator_export_type=None, opset_version=None,
              do_constant_folding=True,
              dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None):
    """Export a PyTorch *model* to an in-memory ONNX model.

    Thin wrapper around ``torch.onnx.export`` that always exports in EVAL
    training mode, serializes into an in-memory buffer, and returns the
    parsed ``onnx.ModelProto`` instead of writing a file to disk.

    :param model: the ``torch.nn.Module`` to export.
    :param args: the example input arguments used for tracing.
    :return: the exported model as an ``onnx.ModelProto``.
    All remaining keyword arguments are forwarded to ``torch.onnx.export``.
    """
    with io.BytesIO() as f:
        _export(model, args, f,
                export_params=export_params, verbose=verbose,
                training=TrainingMode.EVAL, input_names=input_names,
                output_names=output_names,
                operator_export_type=operator_export_type, opset_version=opset_version,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=keep_initializers_as_inputs,
                custom_opsets=custom_opsets)
        # Rewind and parse the buffer in place: the original code built a
        # second BytesIO from f.getvalue(), duplicating the whole serialized
        # model in memory for no benefit (onnx.load_model accepts any
        # readable file-like object).
        f.seek(0)
        return onnx.load_model(f)
class ONNXCompose:
    """
    Merge the pre- and post-processing PyTorch subclassing modules with the core model.

    :arg models: the core model; an ONNX model or a PyTorch ONNX-exportable model
    :arg preprocessors: the preprocessing module, or None to feed the inputs
        straight to the core model
    :arg postprocessors: the postprocessing module, or None to return the core
        model outputs directly
    """
    def __init__(self, models=None, preprocessors=None, postprocessors=None):
        # Validate only when a module is actually supplied: the previous check
        # unconditionally required a ProcessingModule, which made the declared
        # default (preprocessors=None) always raise AssertionError.
        assert preprocessors is None or isinstance(preprocessors, ProcessingModule),\
            'preprocessors must be subclassing from ProcessingModule'
        assert postprocessors is None or isinstance(postprocessors, ProcessingModule),\
            'postprocessors must be subclassing from ProcessingModule'
        self.models = models
        self.preprocessors = preprocessors
        self.postprocessors = postprocessors
        # Example arguments captured by predict(); export() replays them as the
        # tracing inputs for each stage, so predict() must run before export().
        self.pre_args = None
        self.models_args = None
        self.post_args = None

    def export(self, opset_version, output_file=None,
               export_params=True,
               verbose=False,
               input_names=None,
               output_names=None,
               operator_export_type=None,
               do_constant_folding=True,
               dynamic_axes=None,
               keep_initializers_as_inputs=None,
               custom_opsets=None,
               io_mapping=None):
        """
        Export all models and modules into a merged ONNX model.

        predict() must have been called first so the example inputs for each
        stage (pre_args / models_args / post_args) are available for tracing.

        :param opset_version: the ONNX opset version for every exported stage.
        :param output_file: when given, the merged model is also saved there.
        :param io_mapping: forwarded to ONNXModelUtils.join_models to wire the
            stage outputs to the next stage's inputs.
        The remaining keyword arguments are forwarded to torch.onnx.export.
        :return: the merged onnx.ModelProto.
        """
        pre_m = None
        post_m = None
        if self.preprocessors is not None:
            pre_m = self.preprocessors.export(opset_version, *tuple(self.pre_args), ofname=output_file)
        if isinstance(self.models, torch.nn.Module):
            core = _export_f(self.models, tuple(self.models_args),
                             export_params=export_params, verbose=verbose, input_names=input_names,
                             output_names=output_names,
                             operator_export_type=operator_export_type, opset_version=opset_version,
                             do_constant_folding=do_constant_folding,
                             dynamic_axes=dynamic_axes,
                             keep_initializers_as_inputs=keep_initializers_as_inputs,
                             custom_opsets=custom_opsets)
        else:
            # Already an ONNX model — use it as-is.
            core = self.models
        if self.postprocessors is not None:
            post_m = self.postprocessors.export(opset_version, *tuple(self.post_args))
        model_l = [core]
        if pre_m is not None:
            model_l.insert(0, pre_m)
        if post_m is not None:
            model_l.append(post_m)
        full_m = ONNXModelUtils.join_models(*model_l, io_mapping=io_mapping)
        if output_file is not None:
            onnx.save_model(full_m, output_file)
        return full_m

    def predict(self, *args, extra_args_post=None):
        """
        Predict the result through all modules/models
        :param args: the input arguments for the first preprocessing module.
        :param extra_args_post: extra args for post-processors; a list of
            (start, stop) slice pairs — [0] slices the converted raw inputs,
            [1] (optional) slices the core-model inputs — prepended to the
            post-processor arguments.
        :return: the result from the last postprocessing module or
                 from the core model if there is no postprocessing module.
        """
        def _is_tensor(x):
            # A (possibly nested) list is tensor-like only if every leaf is.
            if isinstance(x, list):
                return all(_is_tensor(_x) for _x in x)
            return isinstance(x, torch.Tensor)

        def _is_array(x):
            # Numeric NumPy values only: string arrays ('U'/'S' dtypes) stay
            # as NumPy objects because torch has no string tensor type.
            if isinstance(x, list):
                return all(_is_array(_x) for _x in x)
            return isinstance(x, (numpy.ndarray, numpy.generic)) and \
                x.dtype.kind not in ('U', 'S')

        # convert the raw value, and special handling for string:
        # raw -> numpy array, then numeric numpy -> torch tensor.
        n_args = [numpy.array(_arg) if not _is_tensor(_arg) else _arg for _arg in args]
        n_args = [torch.from_numpy(_arg) if
                  _is_array(_arg) else _arg for _arg in n_args]
        self.pre_args = n_args
        if self.preprocessors is not None:
            inputs = [self.preprocessors.forward(*n_args)]
        else:
            # No preprocessing stage: the converted args feed the core model.
            inputs = n_args
        flatten_inputs = []
        for _i in inputs:
            flatten_inputs += list(_i) if isinstance(_i, tuple) else [_i]
        self.models_args = flatten_inputs
        if isinstance(self.models, torch.nn.Module):
            outputs = self.models.forward(*flatten_inputs)
        else:
            # Run the ONNX model via onnxruntime, wrapping the result back
            # into a torch tensor so post-processing sees a uniform type.
            f = OrtPyFunction.from_model(self.models)
            outputs = [torch.from_numpy(f(*[_i.numpy() for _i in flatten_inputs]))]
        self.post_args = outputs
        if extra_args_post:
            extra_args = []
            if extra_args_post[0]:
                extra_args = n_args[extra_args_post[0][0]:extra_args_post[0][1]]
            if len(extra_args_post) > 1:
                extra_args += flatten_inputs[extra_args_post[1][0]:extra_args_post[1][1]]
            self.post_args = extra_args + self.post_args
        if self.postprocessors is None:
            return outputs
        return self.postprocessors.forward(*self.post_args)