Skip to content

Commit

Permalink
[Op][Debugging] Add assert operator (tlc-pack#260)
Browse files Browse the repository at this point in the history
It was brought up that Relay lacks an assert operator, so we may as well have one in Relax for debugging. One issue is that we can't name it "`assert`" because Python will treat it as a syntax error to have it as a field name for the "`relax`" module, i.e., `relax.assert` is a syntax error. Thus the op is named "`assert_op`," which is not ideal but serves its purpose.
  • Loading branch information
slyubomirsky authored and junrushao committed Jan 29, 2023
1 parent fff5148 commit 451abda
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 41 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ struct PrintAttrs : public tvm::AttrsNode<PrintAttrs> {
}
};

struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
std::string format;
TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") {
TVM_ATTR_FIELD(format)
.describe(
"Python-style format string to use for displaying "
"an error message if the assert fails. "
"Ignored if empty.")
.set_default("");
}
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_OP_ATTR_TYPES_H_
15 changes: 15 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ class NameTable {
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype).
*
* \param ty The input type.
* \param permit_unknown_rank If true, it will permit the input type to have unknown rank
* (ndim of -1), which will require a dynamic check.
* \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype
* (namely, void), which will require a dynamic check.
*
* \return True iff the input type is a boolean scalar type (or, depending on options, has unknown
* rank or dtype)
*/
TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
bool permit_unknown_dtype = true);

} // namespace relax
} // namespace tvm

Expand Down
207 changes: 167 additions & 40 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,24 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# pylint: disable=redefined-builtin
"""The base Relax operators."""
from typing import Union, List, Optional
from typing import List, Optional, Union

import tvm
from tvm.runtime.object import Object

from . import _ffi_api
from ..expr import Expr, ShapeExpr, Tuple, Call
from ..ty import DynTensorType, TupleType
from ...ir import Array
from ..expr import Call, Expr, ExternFunc, ShapeExpr, Tuple
from ..ty import DynTensorType, TupleType
from . import _ffi_api

py_print = print # pylint: disable=invalid-name


def call_tir(
func: Expr,
args: Union[Tuple, List[Expr]],
func: Union[str, Expr],
args: Union[Expr, Tuple, List[Expr]],
shape: Union[Tuple, ShapeExpr, List[int]],
dtype: Union[str, List[str]],
tir_vars: Optional[ShapeExpr] = None,
Expand All @@ -37,10 +40,10 @@ def call_tir(
Parameters
----------
func : Expr
func : Union[str, Expr]
The destination-passing-style function, can be ExternFunc or PrimFunc.
args : Union[Tuple, List[Expr]]
args : Union[Expr, Tuple, List[Expr]]
The input arguments.
shape: Union[Tuple, ShapeExpr, List[int]]
Expand All @@ -57,9 +60,15 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(func, str):
func = ExternFunc(func)

if isinstance(shape, (list, tuple, Array)):
shape = ShapeExpr(shape)

if isinstance(args, Expr):
args = Tuple((args,))

if isinstance(args, (list, tuple)):
args = Tuple(args)

Expand Down Expand Up @@ -131,54 +140,172 @@ def invoke_closure(
return _ffi_api.invoke_closure(closure, args)


def render_object(val: tvm.Object) -> str:
"""
Given a TVM Object, renders it in string form. Used for Relax printing and assertions.
Parameters
----------
val: tvm.Object
An object to render
Returns
-------
ret: str
A string representing the value, ideally human-readable
"""
if isinstance(val, tvm.runtime.ndarray.NDArray):
return str(val)
# no pretty-printer by default, so if we don't handle this,
# then we can't look inside tuples
if isinstance(val, tvm.runtime.container.ADT):
# the fields array of an ADT cannot be directly accessed in Python
# so we have to get the length and index into the fields separately
fields = ", ".join([render_object(val[i]) for i in range(len(val))])
# special case: tag = 0 is a tuple
if val.tag == 0:
return f"({fields})"
return f"ADT(tag={val.tag}, fields=[{fields}])"
return str(val)


@tvm.register_func("relax.run.print")
def relax_print(*args: List[any]) -> None:
def relax_print(format_str: str, *format_args: tvm.Object) -> None:
"""
Takes a list of values to print, formats with the given format string.
If the format string is empty, simply prints.
Since this function is called as a PackedFunc from the generated code,
we cannot have it be variadic _and_ have an optional format string attribute
except by taking in all the arguments as a single list. The last argument
should be a format string.
Call from TVM script like this:
`relax.print(value1, value2, ..., valueN, format=format_str)`
or
`relax.print(value1, value2, ..., valueN) # format_str defaults to ""`
Parameters
----------
vals: List[Object]
format_str: str
The last argument is a Python-style format string for printing the value
format_args: List[Object]
The values to print.
"""
val_strs = map(render_object, format_args)
if format_str == "":
py_print(*val_strs)
else:
py_print(format_str.format(*val_strs))


def print(values: Union[Expr, List[Expr]], format: str) -> Expr:
"""Print op to print the values
Parameters
----------
values : List[Expr]
The values to print.
format_str: str
The last argument is a Python-style format string for printing the value
The format string.
Returns
-------
result : Expr
A relax Call, which will print the value during runtime.
"""
if isinstance(values, Expr):
values = [values]
return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member


@tvm.register_func("relax.run.assert_op")
def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None:
"""
A variadic function. The first value serves as the assertion condition:
If the condition is true, then the operator does nothing.
If the condition is false, then the operator raises an assertion error.
Arguments after the first value serve as format arguments for the error message;
the last argument must be a format string for the error message (empty by default).
If the format string is the empty string, then the error message will simply include
a comma-separated list of the format arguments.
The condition argument is not included in the format string.
Parameters
----------
condition: tvm.Object
The assertion condition. Must be a boolean scalar.
# there is no way to have a keyword arg to a packed function,
# so the format string is always the last argument
format_str = args[-1]
format_str: str
The last argument is a Python-style format string for printing the value
format_args: List[tvm.Object]
Values used for formatting the string.
"""
if not isinstance(format_str, str):
raise ValueError("No valid format string given.")

def render(val: tvm.Object) -> str:
if isinstance(val, tvm.runtime.ndarray.NDArray):
return str(val)
# no pretty-printer by default, so if we don't handle this,
# then we can't look inside tuples
if isinstance(val, tvm.runtime.container.ADT):
# the fields array of an ADT cannot be directly accessed in Python
# so we have to get the length and index into the fields separately
fields = ", ".join([render(val[i]) for i in range(len(val))])
# special case: tag = 0 is a tuple
if val.tag == 0:
return f"({fields})"
return f"ADT(tag={val.tag}, fields=[{fields}])"
return str(val)
raise ValueError(
f"The format string argument to assert must be a string, given {type(format_str)})"
)

# should be guaranteed by the type system
if not isinstance(condition, tvm.runtime.ndarray.NDArray):
raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.")

# may happen if the original program had unknown shape or dtype for the tensor's type
dtype = condition.dtype
if dtype != "bool":
raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor")
shape = condition.shape
if len(shape) != 0:
raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}")

val = condition.numpy()
if not val:
error_message = "Assertion Failed"
if format_args or format_str != "":
rendered = map(render_object, format_args)
if format_str != "":
error_message = format_str.format(*rendered)
else:
error_message = ", ".join(rendered)
raise AssertionError(error_message)


def assert_op(condition: Expr, format_args: Optional[List[Expr]] = None, format: str = "") -> Expr:
"""
Create a call to Relax's assert_op operation (`assert` is reserved in Python,
so the name must be distinct).
val_strs = map(render, args[:-1])
if format_str == "":
print(*val_strs)
else:
print(format_str.format(*val_strs))
Parameters
----------
condition: Expr
The assertion condition.
format_args: List[Expr]
Format arguments for the error message if the condition fails.
format_str: str
The format string for the error message.
Returns
-------
result : Expr
A Call to the Relax assert operation.
"""
if format_args is None:
format_args = []
return _ffi_api.assert_op(condition, format_args, format) # type: ignore


def shape_of(expr: Expr) -> Expr:
"""Get shape of a tensor.
Parameters
----------
expr : Expr
The input Expr.
Returns
-------
result : Expr
A relax Call, which gets the shape of the input
"""
return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,8 @@ class UniqueAttrs(Attrs):
@tvm._ffi.register_object("relax.attrs.PrintAttrs")
class PrintAttrs(Attrs):
"""Attributes used for the print operator"""


@tvm._ffi.register_object("relax.attrs.AssertOpAttrs")
class AssertOpAttrs(Attrs):
"""Attributes used for the assert operator"""
10 changes: 9 additions & 1 deletion src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,14 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
}
if (call_node->op == print_op_) {
auto print_attrs = call_node->attrs.as<PrintAttrs>();
args.push_back(EmitConstantFromValue(print_attrs->format));
// format string is the first argument
args.insert(args.begin(), EmitConstantFromValue(print_attrs->format));
return;
}
if (call_node->op == assert_op_) {
auto assert_attrs = call_node->attrs.as<AssertOpAttrs>();
// format string comes before the format args
args.insert(args.begin() + 1, EmitConstantFromValue(assert_attrs->format));
return;
}
LOG(FATAL) << "Support for attributes of Op " << call_node->op
Expand Down Expand Up @@ -520,6 +527,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& unique_op_ = Op::Get("relax.unique");
const Op& print_op_ = Op::Get("relax.print");
const Op& assert_op_ = Op::Get("relax.assert_op");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
};
Expand Down
46 changes: 46 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/attrs/shape.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/utils.h>
#include <tvm/relay/op.h>

#include "op_common.h"
Expand Down Expand Up @@ -118,6 +119,51 @@ Expr MakePrint(Array<Expr> vals, std::string format_str) {

TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint);

// assert_op

// can't actually name it assert or else Python will consider it a syntax error

Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) {
// Ensure that the condition argument is a boolean scalar.
// Also permitted is a tensor with unknown shape and unknown dtype
// (checked dynamically in that case). Returns void.
if (call->args.size() < 1) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
<< "Assert must have at least one argument (the condition).");
}
Type arg_type = call->args[0]->checked_type();
if (!IsBoolScalarType(arg_type)) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
<< "The argument to assert must be a boolean scalar type, but received "
<< arg_type);
}
return VoidType();
}

TVM_REGISTER_NODE_TYPE(AssertOpAttrs);

RELAY_REGISTER_OP("relax.assert_op")
.set_attrs_type<AssertOpAttrs>()
.set_num_inputs(-1)
.add_argument("vals", "Array<Expr>",
"The first value is used as the assertion condition. The others are used as "
"format arguments if there is an error.")
.set_attr<FInferType>("FInferType", InferAssertType)
.set_attr<FCallPacked>("FCallPacked", "relax.run.assert_op");

Expr MakeAssertOp(Expr condition, Array<Expr> vals, std::string format) {
auto attrs = make_object<AssertOpAttrs>();
attrs->format = format;
static const Op& op = Op::Get("relax.assert_op");
Array<Expr> args = {condition};
for (auto val : vals) {
args.push_back(val);
}
return Call(op, args, Attrs(attrs));
}

TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp);

// make_closure

RELAY_REGISTER_OP("relax.make_closure")
Expand Down
Loading

0 comments on commit 451abda

Please sign in to comment.