From c3462126eb11788e5d042afcd7d56349c452a82f Mon Sep 17 00:00:00 2001 From: Florin Blanaru Date: Tue, 14 Mar 2023 15:15:02 +0200 Subject: [PATCH] [TUZ-157] Add span information for all ops used in the ONNX frontend (#34) This PR adds Span information to the IRModule generated by the ONNX frontend --------- Co-authored-by: Josh Fromm --- python/tvm/relax/frontend/__init__.py | 2 +- python/tvm/relax/frontend/common.py | 70 ++- .../tvm/relax/frontend/onnx/onnx_frontend.py | 431 ++++++++++-------- src/relax/ir/block_builder.cc | 2 +- .../relax/frontend/test_onnx_frontend.py | 53 ++- 5 files changed, 367 insertions(+), 191 deletions(-) diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/__init__.py index 4baf3195f032..480200a09802 100644 --- a/python/tvm/relax/frontend/__init__.py +++ b/python/tvm/relax/frontend/__init__.py @@ -17,4 +17,4 @@ """ Frontends for constructing Relax programs, with the model importers """ -from .common import detach_params +from .common import detach_params, SpanContext, attach_span, emit_te_with_span diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py index 9904324df40e..449479f824a1 100644 --- a/python/tvm/relax/frontend/common.py +++ b/python/tvm/relax/frontend/common.py @@ -16,9 +16,10 @@ # under the License. # pylint: disable=invalid-name """Commons for Relax frontend.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union, Callable, Any import tvm +from ...ir import Span, SourceName def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]: @@ -53,3 +54,70 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n else: detached_mod[gv] = func return detached_mod, params_dict + + +def emit_te_with_span(bb, func: Callable, *args: Any, **kwargs: Any) -> tvm.relax.Var: + """Same as block_builder.emit_te, but attaches a span to the generated call. + Uses the current span in the SpanContext. + """ + + call = bb.call_te(func, *args, **kwargs) + call = attach_span(call) + return bb.emit(call) + + +def attach_span(op: tvm.relax.Call): + """Attach a span to a Relax op if it doesn't already have one. + Uses the current span in the SpanContext. + + Parameters + ---------- + op : tvm.relax.Call + The op to attach a span to. + + Returns + ------- + op : tvm.relax.Call + The op with a span attached. + """ + assert isinstance(op, tvm.relax.Call), "Expected a Call node but got: {op}".format( + op=str(type(op)) + ) + if op.span is None: + return tvm.relax.Call(op.op, op.args, op.attrs, op.sinfo_args, SpanContext.current()) + return op + + +class SpanContext: + """A context manager for setting the current Span. + + Parameters + ---------- + span : Union[Span, str] + The span to set as the current span. + """ + + __current_span = None + + def __init__(self, span: Union[Span, str]): + assert isinstance(span, (Span, str)), "span must be a Span or str" + if isinstance(span, str): + span = Span(SourceName(span), 0, 0, 0, 0) + SpanContext.__current_span = span + + def __enter__(self): + return self + + def __exit__(self, ptype, value, trace): + SpanContext.__current_span = None + + @staticmethod + def current(): + """Get the span in the current context. + + Returns + ------- + span : Optional[Span] + The current span. + """ return SpanContext.__current_span
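To make the new helpers concrete, here is a minimal usage sketch (illustrative only, not part of the patch; it assumes a TVM build that includes the common.py changes above):

    import tvm
    from tvm import relax
    from tvm.relax.frontend import SpanContext, attach_span

    # Build a toy relax call and stamp it with a span, the way the ONNX
    # converters below do once per node. A plain string argument is wrapped
    # into Span(SourceName(<str>), 0, 0, 0, 0) by SpanContext.__init__.
    x = relax.Var("x", relax.TensorStructInfo([4, 4], "float32"))
    with SpanContext("Add"):
        call = attach_span(relax.op.add(x, x))
    assert call.span.source_name.name == "Add"

emit_te_with_span(bb, func, ...) follows the same pattern for TE-lowered ops: it is bb.emit_te with the ambient span stamped onto the generated call before emission.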
+ """ + return SpanContext.__current_span diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index e15f605d2441..c2169b425206 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -45,6 +45,7 @@ from tvm.ir import IRModule from tvm.ir.supply import NameSupply from tvm.relax import testing +from tvm.relax.frontend.common import attach_span, emit_te_with_span def get_type(elem_type: Union[str, int]) -> str: @@ -157,7 +158,7 @@ class MatMul(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.matmul(inputs[0], inputs[1]) + return attach_span(relax.op.matmul(inputs[0], inputs[1])) class Div(OnnxOpConverter): @@ -165,7 +166,7 @@ class Div(OnnxOpConverter): @classmethod def _impl_v14(cls, bb, inputs, attr): - return relax.op.divide(inputs[0], inputs[1]) + return attach_span(relax.op.divide(inputs[0], inputs[1])) class Sigmoid(OnnxOpConverter): @@ -173,7 +174,7 @@ class Sigmoid(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.sigmoid(inputs[0]) + return attach_span(relax.op.sigmoid(inputs[0])) class Softmax(OnnxOpConverter): @@ -182,7 +183,7 @@ class Softmax(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): axis = attr.get("axis", -1) - return relax.op.nn.softmax(inputs[0], axis=axis) + return attach_span(relax.op.nn.softmax(inputs[0], axis=axis)) class Transpose(OnnxOpConverter): @@ -191,7 +192,7 @@ class Transpose(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): axes = attr.get("perm", None) - return relax.op.permute_dims(inputs[0], axes) + return attach_span(relax.op.permute_dims(inputs[0], axes)) class Unsqueeze(OnnxOpConverter): @@ -200,7 +201,7 @@ class Unsqueeze(OnnxOpConverter): @classmethod def _impl_v11(cls, bb, inputs, attr): axes = attr.get("axes") - return relax.op.expand_dims(inputs[0], axes) + return attach_span(relax.op.expand_dims(inputs[0], axes)) @classmethod def _impl_v13(cls, bb, inputs, attr): @@ -211,7 +212,7 @@ def _impl_v13(cls, bb, inputs, attr): constant_axes = list(map(int, constant_axes)) constant_axes = sorted(constant_axes) for axis in constant_axes: - data = relax.op.expand_dims(data, axis=axis) + data = attach_span(relax.op.expand_dims(data, axis=axis)) return data raise NotImplementedError("Unsqueeze with dynamic axes is not supported.") @@ -223,7 +224,7 @@ class Concat(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): axis = attr.get("axis", 0) - return relax.op.concat(inputs, axis=axis) + return attach_span(relax.op.concat(inputs, axis=axis)) class Add(OnnxOpConverter): @@ -231,7 +232,7 @@ class Add(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.add(inputs[0], inputs[1]) + return attach_span(relax.op.add(inputs[0], inputs[1])) class Mul(OnnxOpConverter): @@ -239,7 +240,7 @@ class Mul(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.multiply(inputs[0], inputs[1]) + return attach_span(relax.op.multiply(inputs[0], inputs[1])) class Cast(OnnxOpConverter): @@ -248,7 +249,7 @@ class Cast(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): to_type = get_type(attr["to"]) - return relax.op.astype(inputs[0], to_type) + return attach_span(relax.op.astype(inputs[0], to_type)) class Gather(OnnxOpConverter): @@ -263,13 +264,13 @@ def _impl_v13(cls, bb, inputs, attr): scalar_indices = False if len(indices.struct_info.shape) == 0: 
scalar_indices = True - indices = bb.normalize(relax.op.expand_dims(indices, axis=0)) + indices = bb.normalize(attach_span(relax.op.expand_dims(indices, axis=0))) axis = attr.get("axis", 0) - out = relax.op.take(data, indices, axis) + out = attach_span(relax.op.take(data, indices, axis)) # If indices were scalar, output dimension needs to be reduced. if scalar_indices: - out = relax.op.squeeze(out, axis) + out = attach_span(relax.op.squeeze(out, axis)) return out @@ -290,18 +291,18 @@ def _impl_v13(cls, bb, inputs, attr): # Compute Y = alpha * A X B + beta * C if alpha is not None: - A = bb.normalize(relax.op.multiply(A, relax.const(alpha, dtype=dtype))) + A = bb.normalize(attach_span(relax.op.multiply(A, relax.const(alpha, dtype=dtype)))) if transA: - A = relax.op.permute_dims(A, [1, 0]) + A = attach_span(relax.op.permute_dims(A, [1, 0])) if transB: - B = relax.op.permute_dims(B, [1, 0]) - Y = bb.normalize(relax.op.matmul(A, B)) + B = attach_span(relax.op.permute_dims(B, [1, 0])) + Y = bb.normalize(attach_span(relax.op.matmul(A, B))) if C is not None: if beta is not None: - C = bb.normalize(relax.op.multiply(C, relax.const(beta, dtype=dtype))) - Y = relax.op.add(Y, C) + C = bb.normalize(attach_span(relax.op.multiply(C, relax.const(beta, dtype=dtype)))) + Y = attach_span(relax.op.add(Y, C)) return Y @@ -315,7 +316,7 @@ def _impl_v13(cls, bb, inputs, attr): new_shape = inputs[1] if isinstance(inputs[1], relax.Constant): new_shape = inputs[1].data.numpy().tolist() - return relax.op.reshape(data, new_shape) + return attach_span(relax.op.reshape(data, new_shape)) class Gelu(OnnxOpConverter): @@ -326,7 +327,7 @@ class Gelu(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr): - return relax.op.nn.gelu(inputs[0]) + return attach_span(relax.op.nn.gelu(inputs[0])) class BiasGelu(OnnxOpConverter): @@ -337,8 +338,8 @@ class BiasGelu(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr): - inp = relax.op.add(inputs[0], inputs[1]) - return relax.op.nn.gelu(inp) + inp = attach_span(relax.op.add(inputs[0], inputs[1])) + return attach_span(relax.op.nn.gelu(inp)) class Where(OnnxOpConverter): @@ -346,7 +347,7 @@ class Where(OnnxOpConverter): @classmethod def _impl_v16(cls, bb, inputs, attr): - return relax.op.where(inputs[0], inputs[1], inputs[2]) + return attach_span(relax.op.where(inputs[0], inputs[1], inputs[2])) class Clip(OnnxOpConverter): @@ -356,9 +357,9 @@ class Clip(OnnxOpConverter): def _impl_v13(cls, bb, inputs, attr): results = inputs[0] if inputs[1] is not None: - results = bb.emit_te(topi.maximum, results, inputs[1]) + results = emit_te_with_span(bb, topi.maximum, results, inputs[1]) if inputs[2] is not None: - results = bb.emit_te(topi.minimum, results, inputs[2]) + results = emit_te_with_span(bb, topi.minimum, results, inputs[2]) return results @@ -367,7 +368,7 @@ class Equal(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.equal(inputs[0], inputs[1]) + return attach_span(relax.op.equal(inputs[0], inputs[1])) class Shape(OnnxOpConverter): @@ -381,7 +382,7 @@ def _impl_v13(cls, bb, inputs, attr): data_shape = [i.value for i in inputs[0].struct_info.shape] return relax.const(data_shape, "int64") # Otherwise compute it dynamically. 
- return relax.op.shape_of(inputs[0]) + return attach_span(relax.op.shape_of(inputs[0])) class Not(OnnxOpConverter): @@ -389,7 +390,7 @@ class Not(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.bitwise_not, inputs[0]) + return emit_te_with_span(bb, topi.bitwise_not, inputs[0]) class Tanh(OnnxOpConverter): @@ -397,7 +398,7 @@ class Tanh(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.tanh(inputs[0]) + return attach_span(relax.op.tanh(inputs[0])) class Sqrt(OnnxOpConverter): @@ -405,7 +406,7 @@ class Sqrt(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.sqrt(inputs[0]) + return attach_span(relax.op.sqrt(inputs[0])) class Relu(OnnxOpConverter): @@ -413,7 +414,7 @@ class Relu(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.nn.relu(inputs[0]) + return attach_span(relax.op.nn.relu(inputs[0])) class Pow(OnnxOpConverter): @@ -421,7 +422,7 @@ class Pow(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.power(inputs[0], inputs[1]) + return attach_span(relax.op.power(inputs[0], inputs[1])) class Conv(OnnxOpConverter): @@ -430,32 +431,48 @@ class Conv(OnnxOpConverter): @classmethod def _impl_v11(cls, bb, inputs, attr): ndim = len(inputs[0].struct_info.shape) - if ndim == 4: + if ndim == 3: + conv_out = emit_te_with_span( + bb, + topi.nn.conv1d, + inputs[0], + inputs[1], + attr.get("strides", 1), + attr.get("pads", 0), + attr.get("dilation", 1), + "NCW", + "OIW", + ) + elif ndim == 4: conv_out = bb.normalize( - relax.op.nn.conv2d( - data=inputs[0], - weight=inputs[1], - strides=attr.get("strides", 1), - padding=attr.get("pads", 0), - dilation=attr.get("dilation", 1), - groups=attr.get("group", 1), - data_layout="NCHW", - kernel_layout="OIHW", + attach_span( + relax.op.nn.conv2d( + data=inputs[0], + weight=inputs[1], + strides=attr.get("strides", 1), + padding=attr.get("pads", 0), + dilation=attr.get("dilation", 1), + groups=attr.get("group", 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) ) ) else: raise NotImplementedError("Only 1d and 2d conv currently supported.") if inputs[2] is not None: - bias = relax.op.reshape( - inputs[2], - [1, -1] - + [ - 1, - ] - * (ndim - 2), + bias = attach_span( + relax.op.reshape( + inputs[2], + [1, -1] + + [ + 1, + ] + * (ndim - 2), + ) ) - conv_out = relax.op.add(conv_out, bias) + conv_out = attach_span(relax.op.add(conv_out, bias)) return conv_out @@ -466,14 +483,17 @@ class Erf(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): x = inputs[0] - sqrt2 = relax.op.sqrt(relax.const(2, x.struct_info.dtype)) + sqrt2 = attach_span(relax.op.sqrt(relax.const(2, x.struct_info.dtype))) # TODO: replace with erf operator once it is implemented + mul = attach_span(relax.op.multiply(x, sqrt2)) + gelu = attach_span(relax.op.nn.gelu(mul)) + mul_2 = attach_span(relax.op.multiply(gelu, sqrt2)) return bb.normalize( - relax.op.add( - relax.op.divide( - relax.op.multiply(relax.op.nn.gelu(relax.op.multiply(x, sqrt2)), sqrt2), x - ), - relax.const(-1, x.struct_info.dtype), + attach_span( + relax.op.add( + attach_span(relax.op.divide(mul_2, x)), + relax.const(-1, x.struct_info.dtype), + ) ) ) @@ -489,15 +509,16 @@ def _impl_v13(cls, bb, inputs, attr): else: axis = None if attr.get("reverse", 0) != 0: - data = bb.emit_te(topi.flip, data, axis=axis if axis else 0) - data = bb.emit_te( + data = emit_te_with_span(bb, topi.flip, data, axis=axis if axis else 0) + data = 
emit_te_with_span( + bb, topi.cumsum, data=data, axis=axis, exclusive=attr.get("exclusive", None), ) if attr.get("reverse", 0) != 0: - data = bb.emit_te(topi.flip, data, axis=axis if axis else 0) + data = emit_te_with_span(bb, topi.flip, data, axis=axis if axis else 0) return data @@ -509,7 +530,7 @@ def _impl_v13(cls, bb, inputs, attr): axis = inputs[1] if axis is not None: axis = [int(x) for x in inputs[1].data.numpy()] - return relax.op.squeeze(inputs[0], axis) + return attach_span(relax.op.squeeze(inputs[0], axis)) class Constant(OnnxOpConverter): @@ -559,7 +580,7 @@ def _impl_v9(cls, bb, inputs, attr): for i in range(shape_ndim): shape_vars.append(tvm.tir.Var("x_%d" % i, "int64")) bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) - return relax.op.broadcast_to(const_value, relax.ShapeExpr(shape_vars)) + return attach_span(relax.op.broadcast_to(const_value, relax.ShapeExpr(shape_vars))) class Sub(OnnxOpConverter): @@ -567,7 +588,7 @@ class Sub(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.subtract(inputs[0], inputs[1]) + return attach_span(relax.op.subtract(inputs[0], inputs[1])) class Sin(OnnxOpConverter): @@ -575,7 +596,7 @@ class Sin(OnnxOpConverter): @classmethod def _impl_v7(cls, bb, inputs, attr): - return relax.op.sin(inputs[0]) + return attach_span(relax.op.sin(inputs[0])) class Cos(OnnxOpConverter): @@ -583,7 +604,7 @@ class Cos(OnnxOpConverter): @classmethod def _impl_v7(cls, bb, inputs, attr): - return relax.op.cos(inputs[0]) + return attach_span(relax.op.cos(inputs[0])) class Neg(OnnxOpConverter): @@ -591,7 +612,7 @@ class Neg(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.negative(inputs[0]) + return attach_span(relax.op.negative(inputs[0])) class Abs(OnnxOpConverter): @@ -599,7 +620,7 @@ class Abs(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.abs(inputs[0]) + return attach_span(relax.op.abs(inputs[0])) class Min(OnnxOpConverter): @@ -608,9 +629,9 @@ class Min(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): # Expand inputs, stack them, then perform minimum over the new axis. - inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] - stacked_tensor = relax.op.concat(inputs, axis=0) - return relax.op.min(stacked_tensor, axis=0) + inputs = [bb.normalize(attach_span(relax.op.expand_dims(i, axis=0))) for i in inputs] + stacked_tensor = attach_span(relax.op.concat(inputs, axis=0)) + return attach_span(relax.op.min(stacked_tensor, axis=0)) class Max(OnnxOpConverter): @@ -619,9 +640,9 @@ class Max(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): # Expand inputs, stack them, then perform maximum over the new axis. 
- inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] - stacked_tensor = relax.op.concat(inputs, axis=0) - return relax.op.max(stacked_tensor, axis=0) + inputs = [bb.normalize(attach_span(relax.op.expand_dims(i, axis=0))) for i in inputs] + stacked_tensor = attach_span(relax.op.concat(inputs, axis=0)) + return attach_span(relax.op.max(stacked_tensor, axis=0)) class Log(OnnxOpConverter): @@ -629,7 +650,7 @@ class Log(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.log(inputs[0]) + return attach_span(relax.op.log(inputs[0])) class Less(OnnxOpConverter): @@ -637,7 +658,7 @@ class Less(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.less(inputs[0], inputs[1]) + return attach_span(relax.op.less(inputs[0], inputs[1])) class LessOrEqual(OnnxOpConverter): @@ -645,7 +666,7 @@ class LessOrEqual(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return relax.op.less_equal(inputs[0], inputs[1]) + return attach_span(relax.op.less_equal(inputs[0], inputs[1])) class Split(OnnxOpConverter): @@ -663,7 +684,7 @@ def _impl_v1(cls, bb, inputs, attr): # When splits isn't specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - return bb.emit_te(topi.split, inputs[0], indices, attr.get("axis", 0)) + return emit_te_with_span(bb, topi.split, inputs[0], indices, attr.get("axis", 0)) @classmethod def _impl_v13(cls, bb, inputs, attr): @@ -684,7 +705,7 @@ def _impl_v13(cls, bb, inputs, attr): # When splits isn't specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0)) + return emit_te_with_span(bb, topi.split, inputs[0], indices, axis=attr.get("axis", 0)) class Slice(OnnxOpConverter): @@ -716,7 +737,9 @@ def _impl_v13(cls, bb, inputs, attr): steps = steps.data.numpy().tolist() else: steps = [1] * len(axes) - return bb.emit_te(topi.strided_slice, data, starts, ends, strides=steps, axes=axes) + return emit_te_with_span( + bb, topi.strided_slice, data, starts, ends, strides=steps, axes=axes + ) class Pad(OnnxOpConverter): @@ -744,9 +767,13 @@ def _impl_v11(cls, bb, inputs, attr): ) if pad_mode == "constant": - return bb.emit_te(topi.nn.pad, inputs[0], pad_before, pad_after, constant_value) + return emit_te_with_span( + bb, topi.nn.pad, inputs[0], pad_before, pad_after, constant_value + ) elif pad_mode == "reflect": - return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT") + return emit_te_with_span( + bb, topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT" + ) else: # TODO(gigiblender) Support edge mode. 
raise NotImplementedError("Pad mode {} not implemented".format(pad_mode)) @@ -762,7 +789,7 @@ def _impl_v13(cls, bb, inputs, attr): reps = reps.data.numpy().tolist() else: raise ValueError("Dynamic reps for Tile are supported yet.") - return bb.emit_te(topi.tile, inputs[0], reps) + return emit_te_with_span(bb, topi.tile, inputs[0], reps) class Expand(OnnxOpConverter): @@ -785,7 +812,7 @@ def _impl_v13(cls, bb, inputs, attr): for i in range(shape_ndim): shape_vars.append(tvm.tir.Var("x_%d" % i, "int64")) bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) - return bb.normalize(relax.op.broadcast_to(data, relax.ShapeExpr(shape_vars))) + return bb.normalize(attach_span(relax.op.broadcast_to(data, relax.ShapeExpr(shape_vars)))) class Attention(OnnxOpConverter): @@ -846,47 +873,52 @@ def _impl_v1(cls, bb, inputs, attr): assert past is None, "past K, V state is not currently supported" assert extra_add is None, "extra add to QxK not currently supported" - split_1 = bb.normalize(relax.op.split(weight, 3, 1)) + split_1 = bb.normalize(attach_span(relax.op.split(weight, 3, 1))) # split weight and biases and do the matmuls w_Q, w_K, w_V = split_1[0], split_1[1], split_1[2] - split_2 = bb.emit_te(topi.split, bias, 3, 0) - split_2 = bb.normalize(relax.op.split(bias, 3, 0)) + split_2 = emit_te_with_span(bb, topi.split, bias, 3, 0) + split_2 = bb.normalize(attach_span(relax.op.split(bias, 3, 0))) b_Q, b_K, b_V = split_2[0], split_2[1], split_2[2] # need to merge batch dimensions since TVM matmul is 2D # TODO(@yuchen): check reverse_reshape, a hack here input_emb = bb.normalize( - relax.op.reshape( - input_emb, (input_emb_shape[0] * input_emb_shape[1], input_emb_shape[2]) + attach_span( + relax.op.reshape( + input_emb, (input_emb_shape[0] * input_emb_shape[1], input_emb_shape[2]) + ) ) ) - mul = bb.normalize(relax.op.matmul(input_emb, w_Q)) + mul = bb.normalize(attach_span(relax.op.matmul(input_emb, w_Q))) - Q = bb.normalize(relax.op.add(mul, b_Q)) + Q = bb.normalize(attach_span(relax.op.add(mul, b_Q))) - mul2 = bb.normalize(relax.op.matmul(input_emb, w_K)) - K = bb.normalize(relax.op.add(mul2, b_K)) + mul2 = bb.normalize(attach_span(relax.op.matmul(input_emb, w_K))) + K = bb.normalize(attach_span(relax.op.add(mul2, b_K))) - mul3 = bb.normalize(relax.op.matmul(input_emb, w_V)) - V = bb.normalize(relax.op.add(mul3, b_V)) + mul3 = bb.normalize(attach_span(relax.op.matmul(input_emb, w_V))) + V = bb.normalize(attach_span(relax.op.add(mul3, b_V))) # massage tensors in preparation for batched matmul def massage(bb, tensor): tensor = bb.normalize( - relax.op.reshape(tensor, (batch_size, seq_len, num_heads, head_size)) + attach_span(relax.op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))) ) # (batch_size, num_heads, seq_len, head_size) - tensor = bb.normalize(relax.op.permute_dims(tensor, [0, 2, 1, 3])) + tensor = bb.normalize(attach_span(relax.op.permute_dims(tensor, [0, 2, 1, 3]))) tensor_shape = [val.value for val in tensor.struct_info.shape.values] # (batch_size * num_heads, seq_len, head_size) # TODO(@yuchen): check reverse_reshape, hack here return bb.normalize( - relax.op.reshape( - tensor, (tensor_shape[0] * tensor_shape[1], tensor_shape[2], tensor_shape[3]) + attach_span( + relax.op.reshape( + tensor, + (tensor_shape[0] * tensor_shape[1], tensor_shape[2], tensor_shape[3]), + ) ) ) @@ -894,56 +926,68 @@ def massage(bb, tensor): K = massage(bb, K) V = massage(bb, V) - K_present = bb.normalize(relax.op.reshape(K, (batch_size, num_heads, seq_len, head_size))) - V_present = 
bb.normalize(relax.op.reshape(V, (batch_size, num_heads, seq_len, head_size))) - present = bb.emit_te(topi.stack, [K_present, V_present], 0) + K_present = bb.normalize( + attach_span(relax.op.reshape(K, (batch_size, num_heads, seq_len, head_size))) + ) + V_present = bb.normalize( + attach_span(relax.op.reshape(V, (batch_size, num_heads, seq_len, head_size))) + ) + present = emit_te_with_span(bb, topi.stack, [K_present, V_present], 0) - att_scores = bb.normalize(relax.op.matmul(Q, relax.op.permute_dims(K, [0, 2, 1]))) + att_scores = bb.normalize( + attach_span(relax.op.matmul(Q, attach_span(relax.op.permute_dims(K, [0, 2, 1])))) + ) score_dtype = att_scores.checked_type.dtype att_scores = bb.normalize( - relax.op.multiply( - att_scores, - relax.const(1 / _np.sqrt(head_size), dtype=att_scores.checked_type.dtype), + attach_span( + relax.op.multiply( + att_scores, + relax.const(1 / _np.sqrt(head_size), dtype=att_scores.checked_type.dtype), + ) ) ) att_scores = bb.normalize( - relax.op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len)) + attach_span(relax.op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))) ) # build the attention mask - att_mask = bb.normalize(relax.op.astype(mask_index, score_dtype)) - att_mask = bb.emit_te(topi.expand_dims, att_mask, 1, num_newaxis=2) - att_mask = relax.op.subtract(relax.const(1, dtype=score_dtype), att_mask) - att_mask = relax.op.multiply(att_mask, relax.const(-10000, dtype=score_dtype)) + att_mask = bb.normalize(attach_span(relax.op.astype(mask_index, score_dtype))) + att_mask = emit_te_with_span(bb, topi.expand_dims, att_mask, 1, num_newaxis=2) + att_mask = attach_span(relax.op.subtract(relax.const(1, dtype=score_dtype), att_mask)) + att_mask = attach_span(relax.op.multiply(att_mask, relax.const(-10000, dtype=score_dtype))) # apply the mask - att_scores = relax.op.add(att_scores, att_mask) + att_scores = attach_span(relax.op.add(att_scores, att_mask)) att_scores = bb.normalize( - relax.op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len)) + attach_span(relax.op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))) ) - att_probs = relax.op.nn.softmax(att_scores, axis=-1) + att_probs = attach_span(relax.op.nn.softmax(att_scores, axis=-1)) - output = bb.normalize(relax.op.matmul(att_probs, V)) + output = bb.normalize(attach_span(relax.op.matmul(att_probs, V))) # TODO(@yuchen): check reverse_reshape, hack here output_shape = [val.value for val in output.struct_info.shape.values] output = bb.normalize( - relax.op.reshape( - output, - ( - int(output_shape[0]) // num_heads, - num_heads, - int(output_shape[1]), - int(output_shape[2]), - ), + attach_span( + relax.op.reshape( + output, + ( + int(output_shape[0]) // num_heads, + num_heads, + int(output_shape[1]), + int(output_shape[2]), + ), + ) ) ) - output = bb.normalize(relax.op.permute_dims(output, axes=[0, 2, 1, 3])) + output = bb.normalize(attach_span(relax.op.permute_dims(output, axes=[0, 2, 1, 3]))) output_shape = [val.value for val in output.struct_info.shape.values] output = bb.normalize( - relax.op.reshape(output, (int(output_shape[0]), int(output_shape[1]), out_hidden)) + attach_span( + relax.op.reshape(output, (int(output_shape[0]), int(output_shape[1]), out_hidden)) + ) ) return relax.Tuple([output, present]) @@ -987,12 +1031,18 @@ def _impl_v18(cls, bb, inputs, attr): # Define relax implementation. 
if roi is not None: - roi = relax.op.concat( - [ - relax.op.strided_slice(roi, axes=[0], begin=[2], end=[ndims]), - relax.op.strided_slice(roi, axes=[0], begin=[ndims + 2], end=[2 * ndims]), - ], - axis=0, + roi = attach_span( + relax.op.concat( + [ + attach_span(relax.op.strided_slice(roi, axes=[0], begin=[2], end=[ndims])), + attach_span( + relax.op.strided_slice( + roi, axes=[0], begin=[ndims + 2], end=[2 * ndims] + ) + ), + ], + axis=0, + ) ) else: roi = [0.0] * 4 @@ -1010,7 +1060,8 @@ def _impl_v18(cls, bb, inputs, attr): sizes = sizes.data.numpy().astype("int64").tolist()[2:] # TODO(jwfromm) relax.image.resize2d runs into some issues with dynamism. - return bb.emit_te( + return emit_te_with_span( + bb, topi.image.resize2d, x, roi, @@ -1031,7 +1082,7 @@ class Einsum(OnnxOpConverter): @classmethod def _impl_v12(cls, bb, inputs, attr): equation = attr["equation"].decode("utf-8") - return bb.emit_te(topi.einsum, equation, *inputs) + return emit_te_with_span(bb, topi.einsum, equation, *inputs) class Range(OnnxOpConverter): @@ -1065,19 +1116,19 @@ def _impl_v6(cls, bb, inputs, attr): ndim = len(data.struct_info.shape) redux_axes = list(range(2, ndim)) - mean = relax.op.mean(data, axis=redux_axes, keepdims=True) - var = relax.op.variance(data, axis=redux_axes, keepdims=True) - sqrt = relax.op.sqrt(var + epsilon) - out = relax.op.divide(relax.op.subtract(data, mean), sqrt) + mean = attach_span(relax.op.mean(data, axis=redux_axes, keepdims=True)) + var = attach_span(relax.op.variance(data, axis=redux_axes, keepdims=True)) + sqrt = attach_span(relax.op.sqrt(attach_span(relax.op.add(var, epsilon)))) + out = attach_span(relax.op.divide(attach_span(relax.op.subtract(data, mean)), sqrt)) broadcast_shape = [-1] + [ 1, ] * (ndim - 2) if scale is not None: - scale = relax.op.reshape(scale, broadcast_shape) - out = relax.op.multiply(out, scale) + scale = attach_span(relax.op.reshape(scale, broadcast_shape)) + out = attach_span(relax.op.multiply(out, scale)) if B is not None: - B = relax.op.reshape(B, broadcast_shape) - out = relax.op.add(out, B) + B = attach_span(relax.op.reshape(B, broadcast_shape)) + out = attach_span(relax.op.add(out, B)) return out @@ -1093,7 +1144,9 @@ def _impl_v16(cls, bb, inputs, attr): mean = inputs[3] var = inputs[4] epsilon = attr.get("epsilon", 1e-05) - return relax.op.nn.batch_norm(data, scale, bias, mean, var, axis=1, epsilon=epsilon) + return attach_span( + relax.op.nn.batch_norm(data, scale, bias, mean, var, axis=1, epsilon=epsilon) + ) class MaxPool(OnnxOpConverter): @@ -1146,7 +1199,9 @@ def _impl_v12(cls, bb, inputs, attr): flatten_pads = [pads[0][0], pads[1][0], pads[0][1], pads[1][1]] pads = tuple(flatten_pads) - return relax.op.nn.max_pool2d(data, kernel_shape, strides, pads, dilations, ceil_mode) + return attach_span( + relax.op.nn.max_pool2d(data, kernel_shape, strides, pads, dilations, ceil_mode) + ) @classmethod def _get_input_spatial_shape(cls, tensor): @@ -1159,7 +1214,7 @@ class GlobalAveragePool(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr): - return relax.op.nn.adaptive_avg_pool2d(inputs[0], 1) + return attach_span(relax.op.nn.adaptive_avg_pool2d(inputs[0], 1)) class Flatten(OnnxOpConverter): @@ -1170,7 +1225,7 @@ def _impl_v13(cls, bb, inputs, attr): axis = attr.get("axis", 1) data_shape = [i.value for i in inputs[0].struct_info.shape] new_shape = (1, -1) if axis == 0 else (_np.prod(data_shape[0:axis]).astype("int64"), -1) - return relax.op.reshape(inputs[0], new_shape) + return attach_span(relax.op.reshape(inputs[0], new_shape)) 
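# ---------------------------------------------------------------------------
# Note (illustrative, not part of the patch): every converter in this file now
# follows the same two rules -- wrap each relax.op.* call in attach_span(),
# and route TE lowering through emit_te_with_span(bb, ...). A hypothetical
# new converter written against this scheme would therefore look like:
#
#     class Exp(OnnxOpConverter):
#         @classmethod
#         def _impl_v13(cls, bb, inputs, attr):
#             # attach_span reads SpanContext.current(), which
#             # _convert_operator enters per node with the ONNX op type and
#             # node index before dispatching here.
#             return attach_span(relax.op.exp(inputs[0]))
# ---------------------------------------------------------------------------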
class LayerNormalization(OnnxOpConverter): @@ -1184,7 +1239,7 @@ def _impl_v17(cls, bb, inputs, attr): axis = attr.get("axis", -1) epsilon = attr.get("epsilon", 1e-05) - output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon) + output = attach_span(relax.op.nn.layer_norm(data, scale, bias, axis, epsilon)) # Onnx layernorm has 3 outputs but only the first is used. # We construct two empty constants for this. placeholder = relax.const(0, dtype="float32") @@ -1199,7 +1254,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.max(data, axes, keepdims) + return attach_span(relax.op.max(data, axes, keepdims)) class ReduceMin(OnnxOpConverter): @@ -1210,7 +1265,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.min(data, axes, keepdims) + return attach_span(relax.op.min(data, axes, keepdims)) class ReduceSum(OnnxOpConverter): @@ -1221,7 +1276,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.sum(data, axes, keepdims) + return attach_span(relax.op.sum(data, axes, keepdims)) class ReduceMean(OnnxOpConverter): @@ -1232,7 +1287,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.mean(data, axes, keepdims) + return attach_span(relax.op.mean(data, axes, keepdims)) class ReduceProd(OnnxOpConverter): @@ -1243,7 +1298,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.prod(data, axes, keepdims) + return attach_span(relax.op.prod(data, axes, keepdims)) class ReduceLogSumExp(OnnxOpConverter): @@ -1254,12 +1309,12 @@ def _impl_v13(cls, bb, inputs, attr): x = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - max_x = relax.op.max(x, axes, True) - exp_x = relax.op.exp(x - max_x) - sum_x = relax.op.sum(exp_x, axes, True) - out_x = relax.op.log(sum_x) + max_x + max_x = attach_span(relax.op.max(x, axes, True)) + exp_x = attach_span(relax.op.exp(attach_span(relax.op.subtract(x, max_x)))) + sum_x = attach_span(relax.op.sum(exp_x, axes, True)) + out_x = attach_span(relax.op.add(attach_span(relax.op.log(sum_x)), max_x)) if not keepdims: - out_x = relax.op.squeeze(out_x, axes) + out_x = attach_span(relax.op.squeeze(out_x, axes)) return out_x @@ -1271,7 +1326,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.log(relax.op.sum(data, axes, keepdims)) + return attach_span(relax.op.log(attach_span(relax.op.sum(data, axes, keepdims)))) class ReduceSumSquare(OnnxOpConverter): @@ -1282,7 +1337,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.sum(relax.op.multiply(data, data), axes, keepdims) + return attach_span(relax.op.sum(attach_span(relax.op.multiply(data, data)), axes, keepdims)) class ReduceL1(OnnxOpConverter): @@ -1293,7 +1348,7 @@ def _impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.sum(relax.op.abs(data), axes, keepdims) + return attach_span(relax.op.sum(attach_span(relax.op.abs(data)), axes, keepdims)) class ReduceL2(OnnxOpConverter): @@ -1304,7 +1359,13 @@ def 
_impl_v13(cls, bb, inputs, attr): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims)) + return attach_span( + relax.op.sqrt( + attach_span( + relax.op.sum(attach_span(relax.op.multiply(data, data)), axes, keepdims) + ) + ) + ) class SkipLayerNormalization(OnnxOpConverter): @@ -1324,11 +1385,11 @@ def _impl_v1(cls, bb, inputs, attr): epsilon = attr.get("epsilon", 1e-12) - data = relax.op.add(data, skip) + data = attach_span(relax.op.add(data, skip)) if bias is not None: - data = relax.op.add(data, bias) + data = attach_span(relax.op.add(data, bias)) - output = relax.op.nn.layer_norm(data, gamma, beta, axes=-1, epsilon=epsilon) + output = attach_span(relax.op.nn.layer_norm(data, gamma, beta, axes=-1, epsilon=epsilon)) # Expects three outputs though only the first is used. Construct a placeholder for others. placeholder = relax.const(0, dtype="float32") @@ -1360,21 +1421,21 @@ def _impl_v1(cls, bb, inputs, attr): if pos_ids is None: pos_ids = relax.const([list(range(seq_len))] * batch_size, dtype="int64") # TODO(jwfromm) Replace with relax ops once take has better support. - word_vec = bb.emit_te(topi.take, word_emb, input_ids, 0) + word_vec = emit_te_with_span(bb, topi.take, word_emb, input_ids, 0) if segment_ids: - segment_vec = bb.emit_te(topi.take, segment_emb, segment_ids, 0) - pos_vec = bb.emit_te(topi.take, pos_emb, pos_ids, 0) + segment_vec = emit_te_with_span(bb, topi.take, segment_emb, segment_ids, 0) + pos_vec = emit_te_with_span(bb, topi.take, pos_emb, pos_ids, 0) - vec_sum = relax.op.add(word_vec, pos_vec) + vec_sum = attach_span(relax.op.add(word_vec, pos_vec)) if segment_ids: - vec_sum = relax.op.add(vec_sum, segment_vec) + vec_sum = attach_span(relax.op.add(vec_sum, segment_vec)) - ln = relax.op.nn.layer_norm(vec_sum, gamma, beta, axes=-1, epsilon=epsilon) + ln = attach_span(relax.op.nn.layer_norm(vec_sum, gamma, beta, axes=-1, epsilon=epsilon)) mask_index = relax.const(_np.zeros((batch_size,), dtype="int64")) if mask: # Calculate number of words per sentence. - mask_index = relax.op.sum(mask, axis=1) + mask_index = attach_span(relax.op.sum(mask, axis=1)) return relax.Tuple([ln, mask_index]) @@ -1386,7 +1447,7 @@ class Greater(OnnxOpConverter): def _impl_v13(cls, bb, inputs, attr): x = inputs[0] y = inputs[1] - return relax.op.greater(x, y) + return attach_span(relax.op.greater(x, y)) class Reciprocal(OnnxOpConverter): @@ -1395,7 +1456,7 @@ class Reciprocal(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): input_dtype = inputs[0].struct_info.dtype - return relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0]) + return attach_span(relax.op.divide(relax.const(1, dtype=input_dtype), inputs[0])) class OneHot(OnnxOpConverter): @@ -1413,7 +1474,7 @@ def _impl_v11(cls, bb, inputs, attr): assert isinstance(values, relax.Constant), "Only constant values currently supported." 
values = values.data.numpy().tolist() off_value, on_value = values - return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype) + return emit_te_with_span(bb, topi.one_hot, indices, on_value, off_value, depth, axis, dtype) def _get_convert_map(): @@ -1526,15 +1587,10 @@ def __init__( self._sanitize: bool = sanitize self.bb: relax.BlockBuilder = relax.BlockBuilder() # pylint: disable=invalid-name - def from_onnx( - self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int - ) -> Tuple[IRModule, Dict[str, tvm.nd.array]]: + def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule: """Construct Relax expressions from the ONNX graph. Onnx graph is a python protobuf object. - #TODO (gigiblender): Handle model input name sanitization. This has been a problem - in the Relay importer in the past and we should be careful to avoid it here. - Parameters ---------- graph : onnx protobuf object The loaded onnx graph @@ -1544,8 +1600,6 @@ Returns ------- mod : tvm.IRModule The returned relax module - params : dict - A dict of name: tvm.nd.array pairs, used as pretrained weights """ with self.bb.function("main"): with self.bb.dataflow() as df: # pylint: disable=invalid-name, unused-variable @@ -1563,7 +1617,8 @@ def from_onnx( param_list = [v for k, v in self._inputs.items() if isinstance(v, relax.Var)] output_var = self.bb.emit_output(outputs) self.bb.emit_func_output(output_var, params=param_list) - return self.bb.get() + relax_mod = self.bb.get() + return relax_mod def _parse_graph_initializers(self, graph: onnx.onnx_ml_pb2.GraphProto): """Parse network inputs to relax, aka parameters.""" @@ -1651,7 +1706,7 @@ def _check_for_unsupported_ops(self, graph: onnx.onnx_ml_pb2.GraphProto): def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): """Nodes are stored as directed acyclic graph.""" - for node in graph.node: + for node_index, node in enumerate(graph.node): op_name = node.op_type attr = self._parse_attr(node.attribute) # Create and populate input list. @@ -1667,7 +1722,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): attr["tvm_custom"]["name"] = i_name attr["tvm_custom"]["num_outputs"] = len(outputs) - op = self._convert_operator(op_name, inputs, attr, self.opset) + op = self._convert_operator(op_name, node_index, inputs, attr, self.opset) # Create struct information for the new operator. op = self.bb.normalize(op) @@ -1734,7 +1789,7 @@ def _parse_attr(self, attr_proto: onnx.onnx_ml_pb2.AttributeProto) -> Dict[str, return attrs def _convert_operator( - self, op_name: str, inputs: List[relax.Function], attrs: Dict, opset: int + self, op_name: str, node_index: int, inputs: List[relax.Function], attrs: Dict, opset: int ) -> relax.Function: """Convert ONNX operator into a Relax operator. The converter must specify conversions explicitly for incompatible names, and Parameters ---------- op_name : str Operator name, such as Convolution, FullyConnected + node_index : int + Index of the node in the ONNX graph. inputs : list of tvm.relax.function.Function List of inputs. 
attrs : dict @@ -1759,7 +1816,9 @@ def _convert_operator( if op_name in convert_map: convert_class = convert_map[op_name] op_function = convert_class.get_converter(opset) - sym = op_function(self.bb, inputs, attrs) + span = tvm.ir.Span(tvm.ir.SourceName(op_name), node_index, node_index, 0, 0) + with relax.frontend.SpanContext(span): + sym = op_function(self.bb, inputs, attrs) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) return sym diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index ac92114ef9cb..ef0a8cb92ca5 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -574,7 +574,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(op); } else { - call = Call(new_op, new_args, op->attrs, op->sinfo_args); + call = Call(new_op, new_args, op->attrs, op->sinfo_args, op->span); } if (!call->struct_info_.defined()) { diff --git a/tests/python/relax/frontend/test_onnx_frontend.py b/tests/python/relax/frontend/test_onnx_frontend.py index adbd05051b53..e03f04fea69f 100644 --- a/tests/python/relax/frontend/test_onnx_frontend.py +++ b/tests/python/relax/frontend/test_onnx_frontend.py @@ -39,6 +39,30 @@ rg = np.random.Generator(bg) +def from_onnx_wrapper(model: ModelProto, opset: int = None): + """ + Wrapper around the from_onnx method. Asserts that the returned Relax IRModule has + span information attached to all call nodes. + """ + + relax_mod = from_onnx(model, opset=opset) + + @relax.expr_functor.visitor + class SpanValidator(tvm.relax.PyExprVisitor): + def visit_call_(self, call: relax.Call): # pylint: disable=arguments-differ + assert call.span is not None, "Span information not available for call node {}".format( + call.op.name + ) + super().visit_call_(call) + + span_validator = SpanValidator() + for _, func in relax_mod.functions.items(): + if isinstance(func, relax.Function): + span_validator.visit_expr(func) + + return relax_mod + + def generate_random_inputs( model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None ) -> Dict[str, np.ndarray]: @@ -97,7 +121,7 @@ def check_correctness( ort_output = ort_session.run([], inputs) # Convert the onnx model into relax through the onnx importer. - tvm_model = from_onnx(model, opset=opset) + tvm_model = from_onnx_wrapper(model, opset=opset) # Legalize any relax ops into tensorir. tvm_model = relax.transform.LegalizeOps()(tvm_model) # Compile the relax graph into a VM then run. 
@@ -131,6 +155,31 @@ def check_correctness( tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=1e-5) +def test_span_is_added(): + add_node = helper.make_node("Add", inputs=["input_1", "input_2"], outputs=["add_output"]) + div_node = helper.make_node("Div", inputs=["add_output", "input_3"], outputs=["output"]) + + graph = helper.make_graph( + [add_node, div_node], + "test", + inputs=[ + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("input_2", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("input_3", TensorProto.FLOAT, [32, 32]), + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [32, 32]), + ], + ) + + model = helper.make_model(graph, producer_name="test_span") + tvm_model = from_onnx_wrapper(model) + + bindings = tvm_model["main"].body.blocks[0].bindings + assert bindings[-2].value.span.source_name.name == "Add" + assert bindings[-1].value.span.source_name.name == "Div" + + @pytest.mark.parametrize( "input_names, expected_names", [ @@ -154,7 +203,7 @@ def test_sanitize(input_names, expected_names): ) model = helper.make_model(graph, producer_name="test_sanitizer") - tvm_model = from_onnx(model) + tvm_model = from_onnx_wrapper(model) for i, param in enumerate(tvm_model["main"].params): assert param.name_hint == expected_names[i]
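For completeness, here is a sketch of how a downstream tool might consume the new spans after import (illustrative only; "model.onnx" is a placeholder path, and from_onnx is assumed importable from tvm.relax.frontend.onnx as in the tests above):

    import onnx
    import tvm
    from tvm import relax
    from tvm.relax.frontend.onnx import from_onnx

    model = onnx.load("model.onnx")
    relax_mod = from_onnx(model)

    @relax.expr_functor.visitor
    class SpanPrinter(relax.PyExprVisitor):
        def visit_call_(self, call: relax.Call):
            if call.span is not None:
                # source_name holds the ONNX op type; line holds the node index
                # assigned in _convert_operator.
                print(call.span.source_name.name, call.span.line)
            super().visit_call_(call)

    printer = SpanPrinter()
    for _, func in relax_mod.functions.items():
        if isinstance(func, relax.Function):
            printer.visit_expr(func)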