Skip to content

Commit

Permalink
[Relay][Frontend] Span filling common API (#13402)
Browse files Browse the repository at this point in the history
- Expose and add span attribute of Expr-derived types from C++ to Python
- Add common API of span filling
- Add test cases of span filling
- Add function to control whether to fill span via environment variable
- Modify the way of pretty-print to print span

Co-authored-by: Joey Tsai <chunit@qti.qualcomm.com>
  • Loading branch information
chunit-quic and Joey Tsai authored Dec 27, 2022
1 parent 7a38477 commit 520f2c5
Show file tree
Hide file tree
Showing 10 changed files with 750 additions and 46 deletions.
202 changes: 180 additions & 22 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,28 @@ class Constant(ExprWithOp):
----------
data : tvm.nd.NDArray
The data content of the constant expression.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, data):
self.__init_handle_by_constructor__(_ffi_api.Constant, data)
def __init__(self, data, span=None):
self.__init_handle_by_constructor__(_ffi_api.Constant, data, span)


@tvm._ffi.register_func("relay.ConstantWithFields")
def ConstantWithFields(
constant,
data=None,
virtual_device=None,
span=None,
):
"""
Returns constant with the given properties. A None property denotes 'no change'.
Returns constant if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.ConstantWithFields(constant, data, virtual_device, span)


@tvm._ffi.register_object("relay.Tuple")
Expand All @@ -187,7 +205,7 @@ class Tuple(ExprWithOp):
The fields in the tuple.
span: Optional[tvm.relay.Span]
Span that points to original source code
Span that points to original source code.
"""

def __init__(self, fields, span=None):
Expand All @@ -205,6 +223,16 @@ def astype(self, _):
raise TypeError("astype cannot be used on tuple")


@tvm._ffi.register_func("relay.TupleWithFields")
def TupleWithFields(tup, fields=None, virtual_device=None, span=None):
"""
Returns tuple with the given properties. A None property denotes 'no change'.
Returns tuple if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.TupleWithFields(tup, fields, virtual_device, span)


@tvm._ffi.register_object("relay.Var")
class Var(ExprWithOp):
"""A local variable in Relay.
Expand All @@ -221,10 +249,13 @@ class Var(ExprWithOp):
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, name_hint, type_annotation=None):
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation)
def __init__(self, name_hint, type_annotation=None, span=None):
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, type_annotation, span)

@property
def name_hint(self):
Expand All @@ -233,6 +264,16 @@ def name_hint(self):
return name


@tvm._ffi.register_func("relay.VarWithFields")
def VarWithFields(variable, vid=None, type_annotation=None, virtual_device=None, span=None):
"""
Returns var with the given properties. A None property denotes 'no change'.
Returns var if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.VarWithFields(variable, vid, type_annotation, virtual_device, span)


@tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp):
"""Function call node in Relay.
Expand All @@ -256,7 +297,7 @@ class Call(ExprWithOp):
used in advanced usecase of template functions.
span: Optional[tvm.relay.Span]
Span that points to original source code
Span that points to original source code.
"""

def __init__(self, op, args, attrs=None, type_args=None, span=None):
Expand All @@ -265,6 +306,18 @@ def __init__(self, op, args, attrs=None, type_args=None, span=None):
self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, type_args, span)


@tvm._ffi.register_func("relay.CallWithFields")
def CallWithFields(
call, op=None, args=None, attrs=None, type_args=None, virtual_device=None, span=None
):
"""
Returns call with the given properties. A None property denotes 'no change'.
Returns call if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.CallWithFields(call, op, args, attrs, type_args, virtual_device, span)


@tvm._ffi.register_object("relay.Let")
class Let(ExprWithOp):
"""Let variable binding expression.
Expand All @@ -279,10 +332,23 @@ class Let(ExprWithOp):
body: tvm.relay.Expr
The body of the let binding.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, variable, value, body):
self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body)
def __init__(self, variable, value, body, span=None):
self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, body, span)


@tvm._ffi.register_func("relay.LetWithFields")
def LetWithFields(let, variable=None, value=None, body=None, virtual_device=None, span=None):
"""
Returns let with the given properties. A None property denotes 'no change'.
Returns let if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.LetWithFields(let, variable, value, body, virtual_device, span)


@tvm._ffi.register_object("relay.If")
Expand All @@ -299,10 +365,25 @@ class If(ExprWithOp):
false_branch: tvm.relay.Expr
The expression evaluated when condition is false.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch)
def __init__(self, cond, true_branch, false_branch, span=None):
self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span)


@tvm._ffi.register_func("relay.IfWithFields")
def IfWithFields(
if_expr, cond=None, true_branch=None, false_branch=None, virtual_device=None, span=None
):
"""
Returns if with the given properties. A None property denotes 'no change'.
Returns if if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.IfWithFields(if_expr, cond, true_branch, false_branch, virtual_device, span)


@tvm._ffi.register_object("relay.TupleGetItem")
Expand All @@ -316,10 +397,25 @@ class TupleGetItem(ExprWithOp):
index: int
The index.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)
def __init__(self, tuple_value, index, span=None):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)


@tvm._ffi.register_func("relay.TupleGetItemWithFields")
def TupleGetItemWithFields(
tuple_get_item, tuple_value=None, index=None, virtual_device=None, span=None
):
"""
Returns tuple_get_item with the given properties. A None property denotes 'no change'.
Returns tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.TupleGetItemWithFields(tuple_get_item, tuple_value, index, virtual_device, span)


@tvm._ffi.register_object("relay.RefCreate")
Expand All @@ -329,10 +425,28 @@ class RefCreate(ExprWithOp):
----------
value: tvm.relay.Expr
The initial value.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, value):
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
def __init__(self, value, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span)


@tvm._ffi.register_func("relay.RefCreateWithFields")
def RefCreateWithFields(
ref_create,
value=None,
virtual_device=None,
span=None,
):
"""
Returns ref_create with the given properties. A None property denotes 'no change'.
Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.RefCreateWithFields(ref_create, value, virtual_device, span)


@tvm._ffi.register_object("relay.RefRead")
Expand All @@ -342,10 +456,28 @@ class RefRead(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, ref):
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
def __init__(self, ref, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span)


@tvm._ffi.register_func("relay.RefReadWithFields")
def RefReadWithFields(
ref_read,
ref=None,
virtual_device=None,
span=None,
):
"""
Returns ref_read with the given properties. A None property denotes 'no change'.
Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.RefReadWithFields(ref_read, ref, virtual_device, span)


@tvm._ffi.register_object("relay.RefWrite")
Expand All @@ -357,12 +489,32 @@ class RefWrite(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.
value: tvm.relay.Expr
The new value.
span: Optional[tvm.relay.Span]
Span that points to original source code.
"""

def __init__(self, ref, value):
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
def __init__(self, ref, value, span=None):
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, span)


@tvm._ffi.register_func("relay.RefWriteWithFields")
def RefWriteWithFields(
ref_write,
ref=None,
value=None,
virtual_device=None,
span=None,
):
"""
Returns ref_write with the given properties. A None property denotes 'no change'.
Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new
fields.
"""
return _ffi_api.RefWriteWithFields(ref_write, ref, value, virtual_device, span)


class TempExpr(ExprWithOp):
Expand Down Expand Up @@ -433,7 +585,7 @@ def astype(self, _):
raise TypeError("astype cannot be used on tuple")


def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
def var(name_hint, type_annotation=None, shape=None, dtype="float32", span=None):
"""Create a new tvm.relay.Var.
This is a simple wrapper function that allows specify
Expand All @@ -456,6 +608,9 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
dtype: str, optional
The data type of the tensor.
span: Optional[tvm.relay.Span]
Span that points to original source code.
Examples
--------
.. code-block:: python
Expand All @@ -476,10 +631,10 @@ def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
type_annotation = _ty.TensorType(shape, dtype)
elif isinstance(type_annotation, str):
type_annotation = _ty.TensorType((), type_annotation)
return Var(name_hint, type_annotation)
return Var(name_hint, type_annotation, span)


def const(value, dtype=None):
def const(value, dtype=None, span=None):
"""Create a constant value.
Parameters
Expand All @@ -490,6 +645,9 @@ def const(value, dtype=None):
dtype: str, optional
The data type of the resulting constant.
span: Optional[tvm.relay.Span]
Span that points to original source code.
Note
----
When dtype is None, we use the following rule:
Expand All @@ -516,7 +674,7 @@ def const(value, dtype=None):
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")

return Constant(value)
return Constant(value, span)


def bind(expr, binds):
Expand Down
Loading

0 comments on commit 520f2c5

Please sign in to comment.