Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[TVMScript] Update Type Annotation Behavior of the Parser (#269)
Browse files Browse the repository at this point in the history
This commit changes the behavior of the parser to allow type annotations, as suggested by the community.
The previous behavior:
- Use the more refined of the user-annotated and the deduced type/shape.
The updated behavior:
- Always use the user annotation.
- Only check that the annotated type/shape is compatible with the deduced one.
  • Loading branch information
Hzfengsy authored Oct 20, 2022
1 parent c6d6a06 commit 32a03f8
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 53 deletions.
10 changes: 5 additions & 5 deletions include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value,
/*!
 * \brief Annotate and check the type and shape of relax var.
 * \param var The input var to be annotated.
 * \param anno_type The annotated type.
 * \param anno_shape The annotated shape, which can be undefined.
 * \note This function will check if the type of var is compatible with the annotated type,
 * and the var is always annotated with the user-provided type and shape.
 */
TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type,
                               const Optional<tvm::relax::ShapeExpr>& anno_shape);

///////////////////////////// If Then Else /////////////////////////////

Expand Down
15 changes: 9 additions & 6 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,18 +287,21 @@ def emit_match_shape(
############################# Type Deduce ##############################


def annotate_type_shape(var: Var, anno_type: Type, anno_shape: ShapeExpr) -> None:
    """Annotate and check the type and shape of a relax var.

    The user annotation always takes precedence: after checking that the
    annotation is compatible with the deduced type/shape, the var is
    annotated with exactly `anno_type` and `anno_shape`.

    Parameters
    ----------
    var: Var
        The input var to be annotated.

    anno_type: Type
        The annotated type.

    anno_shape: ShapeExpr
        The annotated shape, which can be undefined.
    """
    _ffi_api.AnnotateTypeShape(var, anno_type, anno_shape)  # pylint: disable=no-member # type: ignore


def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name
Expand Down
33 changes: 12 additions & 21 deletions src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,35 +256,26 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchShape").set_body_typed(Emi

///////////////////////////// Type Deduce //////////////////////////////

void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type,
                       const Optional<tvm::relax::ShapeExpr>& anno_shape) {
  using tvm::relax::IsBaseOf;
  // If a type was already deduced for the var, only verify that the
  // annotation is compatible (one is a base of the other, in either
  // direction); the annotation itself wins either way.
  if (var->checked_type_.defined()) {
    const Type& var_type = var->checked_type();
    CHECK(IsBaseOf(anno_type, var_type) || IsBaseOf(var_type, anno_type))
        << "TypeError: The annotated type and value type are not compatible. "
        << "The Type is expected to be " << var_type << " but got annotation: " << anno_type;
  }

  // Likewise for the shape: if both a deduced shape and an annotated shape
  // exist, they must be provably equal; otherwise nothing to check.
  if (var->shape_.defined() && anno_shape.defined()) {
    const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
    tvm::relax::Expr var_shape = Downcast<tvm::relax::Expr>(var->shape_.value());
    CHECK(block_builder->CanProveShapeEqual(var_shape, anno_shape.value()))
        << " The shape of var " << var->name_hint() << " is expected to be " << var_shape
        << " but got annotation: " << anno_shape.value();
  }

  // Always use the user annotation (see commit rationale): the var ends up
  // with exactly the annotated type/shape, even if the deduced one was more
  // refined.
  var->checked_type_ = anno_type;
  var->shape_ = anno_shape;
}

TVM_REGISTER_GLOBAL("script.ir_builder.relax.AnnotateTypeShape").set_body_typed(AnnotateTypeShape);
Expand Down
51 changes: 30 additions & 21 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,28 +427,37 @@ def foo(
o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, type_args=R.Object)
return o

m = tir.Var("m", "int64")
x = relax.Var("x", (32, m), relax.DynTensorType(2, "float32"))
y = relax.Var("y", (m,), relax.DynTensorType(1, "float32"))
r = relax.Var("r", None, relax.DynTensorType(-1, "int64"))
bb = relax.BlockBuilder()
with bb.function("foo", (x, y, r)):
z = bb.emit(R.multiply(x, y))
w = bb.emit(R.multiply(z, z))
q = bb.emit(R.add(w, w))
t = bb.emit(R.add(w, z))
sh = bb.emit(R.shape_of(t))
o = bb.emit(
relax.Call(
relax.ExternFunc("contrib.tensor_array_stack"),
[x, y],
None,
type_args=[relax.ObjectType()],
)
)
bb.emit_func_output(o)
def _check_type_shape(binding, expected_type, expected_shape):
tvm.ir.assert_structural_equal(binding.var.checked_type, expected_type)
tvm.ir.assert_structural_equal(binding.var.shape_, expected_shape)

# Cannot use block builder here because we need to check the annotated type,
# which may be inconsistent with deduced type.
assert isinstance(foo.ret_type, relax.ObjectType)
m = foo.params[0].shape[1]
bindings = foo.body.blocks[0].bindings
_check_type_shape(
bindings[0], relax.DynTensorType(ndim=2, dtype="float32"), relax.ShapeExpr([32, m])
)
_check_type_shape(bindings[1], relax.DynTensorType(dtype=""), None)
_check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), None)
_check_type_shape(bindings[3], relax.DynTensorType(dtype=""), None)
_check_type_shape(bindings[4], relax.ShapeType(), None)
_check_type_shape(bindings[5], relax.ObjectType(), None)


def test_annotate_override():
    """A user annotation overrides the deduced type: `z` is annotated as
    R.Object and must stay an ObjectType even though its value is a tensor."""

    @R.function
    def foo(x: R.Tensor):
        y = x
        # z will be treated as object type even though it's a tensor
        z: R.Object = y
        return z

    assert isinstance(foo.ret_type, relax.ObjectType)
    y_bind, z_bind = foo.body.blocks[0].bindings
    # y keeps the deduced tensor type; z takes the (coarser) annotation.
    assert isinstance(y_bind.var.checked_type, relax.DynTensorType)
    assert isinstance(z_bind.var.checked_type, relax.ObjectType)


def test_empty_shape():
Expand Down

0 comments on commit 32a03f8

Please sign in to comment.