Skip to content

Commit

Permalink
[TUZ-157] Add span information for all ops used in the ONNX frontend (a…
Browse files Browse the repository at this point in the history
…pache#34)

This PR adds Span information to the IRModule generated by the ONXN
frontend

---------

Co-authored-by: Josh Fromm <jwfromm@octoml.ai>
  • Loading branch information
Florin Blanaru and Josh Fromm authored Mar 14, 2023
1 parent 87d008b commit c346212
Show file tree
Hide file tree
Showing 5 changed files with 367 additions and 191 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relax/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"""
Frontends for constructing Relax programs, with the model importers
"""
from .common import detach_params
from .common import detach_params, SpanContext, attach_span, emit_te_with_span
70 changes: 69 additions & 1 deletion python/tvm/relax/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
# pylint: disable=invalid-name
"""Commons for Relax frontend."""
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union, Callable, Any

import tvm
from ...ir import Span, SourceName


def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
Expand Down Expand Up @@ -53,3 +54,70 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n
else:
detached_mod[gv] = func
return detached_mod, params_dict


def emit_te_with_span(bb, func: Callable, *args: Any, **kwargs: Any) -> tvm.relax.Var:
"""Same as block_builder.emit_te, but attaches a span to the generated call.
Uses the current span in the SpanContext.
"""

call = bb.call_te(func, *args, **kwargs)
call = attach_span(call)
return bb.emit(call)


def attach_span(op: tvm.relax.Call):
"""Attach a span to a Relax op if it doesn't already have one.
Uses the current span in the SpanContext.
Parameters
----------
op : tvm.relax.Expr
The op to attach a span to.
Returns
-------
op : tvm.relax.Expr
The op with a span attached.
"""
assert isinstance(op, tvm.relax.Call), "Expected a Call node but got: {op}".format(
op=str(type(op))
)
if op.span is None:
return tvm.relax.Call(op.op, op.args, op.attrs, op.sinfo_args, SpanContext.current())
return op


class SpanContext:
"""A context manager for setting the current Span.
Parameters
----------
span : Union[Span, str]
The span to set as the current span.
"""

__current_span = None

def __init__(self, span: Union[Span, str]):
assert isinstance(span, (Span, str)), "span must be a Span or str"
if isinstance(span, str):
span = Span(SourceName(span), 0, 0, 0, 0)
SpanContext.__current_span = span

def __enter__(self):
return self

def __exit__(self, ptype, value, trace):
SpanContext.__current_span = None

@staticmethod
def current():
"""Get the span in the current context.
Returns
-------
span : Optional[Span]
The current span.
"""
return SpanContext.__current_span
Loading

0 comments on commit c346212

Please sign in to comment.