Skip to content

Commit

Permalink
Edit onnx parser to infer values in post order (#5755)
Browse files Browse the repository at this point in the history
* edit onnx parser to infer values in post order to speed up onnx imports with many calls to infer_value

* fix pylint
  • Loading branch information
Matthew Brookhart authored Jun 12, 2020
1 parent 456ecc6 commit 995b9ff
Showing 1 changed file with 116 additions and 3 deletions.
119 changes: 116 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,29 @@
from .. import function as _function
from .. import op as _op
from .. import vision as _vision

from ..function import Function
from ..expr import Call, Let
from ..expr import If, Tuple, TupleGetItem
from ..expr import RefCreate, RefRead, RefWrite
from ..expr_functor import ExprFunctor
from ..adt import Match, Clause

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import infer_type, infer_value, infer_value_simulated, get_name
from .common import infer_type, get_name
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated

__all__ = ['from_onnx']

g = None

def infer_value(input_val, params, mod=None):
return g.infer_value(input_val, params, mod)

def infer_value_simulated(input_val, params):
return g.infer_value_simulated(input_val, params)

class onnx_input():
""" Dual purpose list or dictionary access object."""
Expand Down Expand Up @@ -1891,8 +1908,7 @@ def _get_convert_map(opset):
'NonZero': NonZero.get_converter(opset),
}


class GraphProto(object):
class GraphProto(ExprFunctor):
"""A helper class for handling Relay expression copying from pb2.GraphProto.
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
Expand All @@ -1914,6 +1930,101 @@ def __init__(self, shape, dtype):
self._shape = shape if shape else {}
self._dtype = dtype

#For infering Values
self._tmp_params = {}
self._infer_simulated = True
self._mod = None
super(GraphProto, self).__init__()

def infer_value(self, input_val, params, mod=None):
self._tmp_params = params
self._infer_simulated = False
self._mod = mod
return self.visit(input_val).data
#return _infer_value(input_val, params, mod)

def infer_value_simulated(self, input_val, params):
self._tmp_params = params
self._infer_simulated = True
return self.visit(input_val).data
#return _infer_value_simulated(input_val, params)

def infer(self, expr):
if self._infer_simulated:
out = _infer_value_simulated(expr, self._tmp_params)
else:
out = _infer_value(expr, self._tmp_params)
return _expr.const(out.asnumpy())

def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
return self.infer(Function(
list(new_params),
new_body,
fn.ret_type,
fn.type_params,
fn.attrs))

def visit_let(self, let):
newvar = self.visit(let.var)
newval = self.visit(let.value)
newbody = self.visit(let.body)
return self.infer(Let(newvar, newval, newbody))

def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return self.infer(Call(new_fn, new_args, call.attrs))

def visit_var(self, var):
return self.infer(var)

def visit_global_id(self, global_var):
return self.infer(global_var)

def visit_if(self, ite):
return self.infer(If(
self.visit(ite.cond),
self.visit(ite.true_branch),
self.visit(ite.false_branch)))

def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])

def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return self.infer(TupleGetItem(tuple_value, op.index))
return self.infer(op)

def visit_global_var(self, gvar):
return self.infer(gvar)

def visit_op(self, op):
return op

def visit_constant(self, const):
return const

def visit_constructor(self, con):
return con

def visit_match(self, m):
return self.infer(Match(
self.visit(m.data),
[Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
complete=m.complete))

def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))

def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))

def visit_ref_read(self, r):
return RefRead(self.visit(r.ref))

def from_onnx(self, graph, opset):
"""Construct Relay expression from ONNX graph.
Expand Down Expand Up @@ -2172,6 +2283,7 @@ def from_onnx(model,
warnings.warn(str(e))
except ImportError:
pass
global g
g = GraphProto(shape, dtype)
graph = model.graph
if opset is None:
Expand All @@ -2180,4 +2292,5 @@ def from_onnx(model,
except AttributeError:
opset = 1
mod, params = g.from_onnx(graph, opset)
g = None
return mod, params

0 comments on commit 995b9ff

Please sign in to comment.